├── README.md
├── checkpoint
│   ├── best_model.pth
│   └── eval.txt
├── data_utils
│   ├── ModelNetDataLoader.py
│   ├── S3DISDataLoader.py
│   ├── ShapeNetDataLoader.py
│   ├── __pycache__
│   │   ├── ModelNetDataLoader.cpython-37.pyc
│   │   └── ShapeNetDataLoader.cpython-37.pyc
│   ├── collect_indoor3d_data.py
│   ├── indoor3d_util.py
│   └── meta
│       ├── anno_paths.txt
│       └── class_names.txt
├── images
│   ├── example.jpg
│   └── model.jpg
├── models
│   ├── SCNet.py
│   ├── utils.py
│   └── z_order.py
├── paint.py
├── provider.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # PointSCNet: Point Cloud Structure and Correlation Learning based on Space Filling Curve guided Sampling
2 |
3 | ## Description
4 | [PapersWithCode: 3D Point Cloud Classification on ModelNet40 leaderboard](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=pointscnet-point-cloud-structure-and)
5 |
6 | This repository contains the code for our paper: [PointSCNet: Point Cloud Structure and Correlation Learning based on Space Filling Curve guided Sampling](https://doi.org/10.3390/sym14010008)
7 |
8 |
9 |
10 | ![PointSCNet model overview](images/model.jpg)
11 |
12 |
13 | ## Environment setup
14 |
15 | The current code has been tested on Ubuntu 18.04 with CUDA 11, Python 3.6.9, PyTorch 1.10.0, and torchvision 0.11.3.
16 | We use a [PyTorch implementation of PointNet++](https://github.com/yanx27/Pointnet_Pointnet2_pytorch) in our pipeline.
17 |
18 |
19 | ## Classification (ModelNet10/40)
20 | ### Data Preparation
21 | Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `data/modelnet40_normal_resampled/`.
22 |
25 |
26 | ### Run
27 |
28 | ```
29 | python train.py --model SCNet --log_dir SCNet_log --use_normals --process_data
30 | ```
31 |
32 | * --model: model name (here, SCNet)
33 | * --log_dir: path to the experiment log directory
34 | * --use_normals: use point normals as additional input channels
35 | * --process_data: preprocess the dataset and cache it offline on the first run
36 |
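These flags map onto the fields that `ModelNetDataLoader` reads from its `args` object. A minimal sketch of the equivalent loader setup (a hypothetical standalone snippet; the field names come from `data_utils/ModelNetDataLoader.py`, and `batch_size=24` matches the run recorded in `checkpoint/eval.txt`):

```
import argparse
import torch
from data_utils.ModelNetDataLoader import ModelNetDataLoader

# Fields mirror the CLI flags above; values here are illustrative defaults.
args = argparse.Namespace(num_point=1024, use_uniform_sample=False,
                          use_normals=True, num_category=40)
train_set = ModelNetDataLoader('data/modelnet40_normal_resampled/', args,
                               split='train', process_data=True)
loader = torch.utils.data.DataLoader(train_set, batch_size=24, shuffle=True)
for points, labels in loader:
    print(points.shape, labels.shape)  # (24, 1024, 6), (24,)
    break
```
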
37 | ## Test
38 |
39 | ```
40 | python test.py --log_dir SCNet_log --use_normals
41 | ```
42 |
43 |
44 | ## Performance
45 | | Model | Overall Accuracy (%) |
46 | |--|--|
47 | | PointNet (Official) | 89.2 |
48 | | PointNet++ (Official) | 91.9 |
49 | | PointSCNet | **93.7** |
50 |
51 |
52 | ## Citation
53 | Please cite our paper if you find it useful in your research:
54 |
55 | ```
56 | @article{chen2022pointscnet,
57 | title={PointSCNet: Point Cloud Structure and Correlation Learning Based on Space-Filling Curve-Guided Sampling},
58 | author={Chen, Xingye and Wu, Yiqi and Xu, Wenjie and Li, Jin and Dong, Huaiyi and Chen, Yilin},
59 | journal={Symmetry},
60 | volume={14},
61 | number={1},
62 | pages={8},
63 | year={2022},
64 | publisher={Multidisciplinary Digital Publishing Institute}
65 | }
66 | ```
67 |
68 | ## Contact
69 | If you have any questions, please contact cxy@cug.edu.cn.
70 |
--------------------------------------------------------------------------------
/checkpoint/best_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/checkpoint/best_model.pth
--------------------------------------------------------------------------------
/checkpoint/eval.txt:
--------------------------------------------------------------------------------
1 | 2021-11-30 18:50:58,111 - Model - INFO - PARAMETER ...
2 | 2021-11-30 18:50:58,111 - Model - INFO - Namespace(batch_size=24, gpu='0', log_dir='att_dot_test', num_category=40, num_point=1024, num_votes=2, use_cpu=False, use_normals=True, use_uniform_sample=False)
3 | 2021-11-30 18:50:58,111 - Model - INFO - Load dataset ...
4 | 2021-11-30 18:50:58,117 - Model - INFO - PointSCNet
5 | 2021-11-30 18:51:38,716 - Model - INFO - Test Instance Accuracy: 0.937217, Class Accuracy: 0.914043
6 | 2021-11-30 18:51:38,716 - Model - INFO - Class airplane Accuracy: 1.000000
7 | 2021-11-30 18:51:38,716 - Model - INFO - Class bathtub Accuracy: 0.986111
8 | 2021-11-30 18:51:38,716 - Model - INFO - Class bed Accuracy: 1.000000
9 | 2021-11-30 18:51:38,716 - Model - INFO - Class bench Accuracy: 0.845238
10 | 2021-11-30 18:51:38,716 - Model - INFO - Class bookshelf Accuracy: 0.988889
11 | 2021-11-30 18:51:38,716 - Model - INFO - Class bottle Accuracy: 0.960714
12 | 2021-11-30 18:51:38,716 - Model - INFO - Class bowl Accuracy: 1.000000
13 | 2021-11-30 18:51:38,716 - Model - INFO - Class car Accuracy: 1.000000
14 | 2021-11-30 18:51:38,716 - Model - INFO - Class chair Accuracy: 0.991667
15 | 2021-11-30 18:51:38,716 - Model - INFO - Class cone Accuracy: 1.000000
16 | 2021-11-30 18:51:38,717 - Model - INFO - Class cup Accuracy: 0.750000
17 | 2021-11-30 18:51:38,717 - Model - INFO - Class curtain Accuracy: 0.880952
18 | 2021-11-30 18:51:38,717 - Model - INFO - Class desk Accuracy: 0.903472
19 | 2021-11-30 18:51:38,717 - Model - INFO - Class door Accuracy: 0.968750
20 | 2021-11-30 18:51:38,717 - Model - INFO - Class dresser Accuracy: 0.883333
21 | 2021-11-30 18:51:38,717 - Model - INFO - Class flower_pot Accuracy: 0.111111
22 | 2021-11-30 18:51:38,717 - Model - INFO - Class glass_box Accuracy: 0.949242
23 | 2021-11-30 18:51:38,717 - Model - INFO - Class guitar Accuracy: 1.000000
24 | 2021-11-30 18:51:38,717 - Model - INFO - Class keyboard Accuracy: 1.000000
25 | 2021-11-30 18:51:38,717 - Model - INFO - Class lamp Accuracy: 1.000000
26 | 2021-11-30 18:51:38,717 - Model - INFO - Class laptop Accuracy: 1.000000
27 | 2021-11-30 18:51:38,717 - Model - INFO - Class mantel Accuracy: 0.986111
28 | 2021-11-30 18:51:38,717 - Model - INFO - Class monitor Accuracy: 0.991667
29 | 2021-11-30 18:51:38,717 - Model - INFO - Class night_stand Accuracy: 0.798611
30 | 2021-11-30 18:51:38,717 - Model - INFO - Class person Accuracy: 1.000000
31 | 2021-11-30 18:51:38,717 - Model - INFO - Class piano Accuracy: 0.940000
32 | 2021-11-30 18:51:38,717 - Model - INFO - Class plant Accuracy: 0.833333
33 | 2021-11-30 18:51:38,717 - Model - INFO - Class radio Accuracy: 0.750000
34 | 2021-11-30 18:51:38,717 - Model - INFO - Class range_hood Accuracy: 0.925000
35 | 2021-11-30 18:51:38,717 - Model - INFO - Class sink Accuracy: 0.900000
36 | 2021-11-30 18:51:38,717 - Model - INFO - Class sofa Accuracy: 1.000000
37 | 2021-11-30 18:51:38,717 - Model - INFO - Class stairs Accuracy: 1.000000
38 | 2021-11-30 18:51:38,717 - Model - INFO - Class stool Accuracy: 0.968750
39 | 2021-11-30 18:51:38,717 - Model - INFO - Class table Accuracy: 0.880000
40 | 2021-11-30 18:51:38,717 - Model - INFO - Class tent Accuracy: 0.968750
41 | 2021-11-30 18:51:38,717 - Model - INFO - Class toilet Accuracy: 1.000000
42 | 2021-11-30 18:51:38,717 - Model - INFO - Class tv_stand Accuracy: 0.900000
43 | 2021-11-30 18:51:38,717 - Model - INFO - Class vase Accuracy: 0.850000
44 | 2021-11-30 18:51:38,717 - Model - INFO - Class wardrobe Accuracy: 0.700000
45 | 2021-11-30 18:51:38,717 - Model - INFO - Class xbox Accuracy: 0.950000
46 |
--------------------------------------------------------------------------------
/data_utils/ModelNetDataLoader.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author: Xu Yan
3 | @file: ModelNet.py
4 | @time: 2021/3/19 15:51
5 | '''
6 | import os
7 | import numpy as np
8 | import warnings
9 | import pickle
10 |
11 | from tqdm import tqdm
12 | from torch.utils.data import Dataset
13 |
14 | warnings.filterwarnings('ignore')
15 |
16 |
17 | def pc_normalize(pc):
18 | centroid = np.mean(pc, axis=0)
19 | pc = pc - centroid
20 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
21 | pc = pc / m
22 | return pc
23 |
24 |
25 | def farthest_point_sample(point, npoint):
26 | """
27 | Input:
28 | xyz: pointcloud data, [N, D]
29 | npoint: number of samples
30 | Return:
31 | centroids: sampled pointcloud index, [npoint, D]
32 | """
33 | N, D = point.shape
34 | xyz = point[:,:3]
35 | centroids = np.zeros((npoint,))
36 | distance = np.ones((N,)) * 1e10
37 | farthest = np.random.randint(0, N)
38 | for i in range(npoint):
39 | centroids[i] = farthest
40 | centroid = xyz[farthest, :]
41 | dist = np.sum((xyz - centroid) ** 2, -1)
42 | mask = dist < distance
43 | distance[mask] = dist[mask]
44 | farthest = np.argmax(distance, -1)
45 | point = point[centroids.astype(np.int32)]
46 | return point
47 |
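# Example (sketch): greedily pick 1024 well-spread points from an (N, 6) cloud.
#   pts = np.random.rand(4096, 6).astype(np.float32)
#   sampled = farthest_point_sample(pts, 1024)  # -> (1024, 6)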
48 |
49 | class ModelNetDataLoader(Dataset):
50 | def __init__(self, root, args, split='train', process_data=False):
51 | self.root = root
52 | self.npoints = args.num_point
53 | self.process_data = process_data
54 | self.uniform = args.use_uniform_sample
55 | self.use_normals = args.use_normals
56 | self.num_category = args.num_category
57 |
58 | if self.num_category == 10:
59 | self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
60 | else:
61 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
62 |
63 | self.cat = [line.rstrip() for line in open(self.catfile)]
64 | self.classes = dict(zip(self.cat, range(len(self.cat))))
65 |
66 | shape_ids = {}
67 | if self.num_category == 10:
68 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
69 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
70 | else:
71 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
72 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
73 |
74 | assert (split == 'train' or split == 'test')
75 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
76 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
77 | in range(len(shape_ids[split]))]
78 | print('The size of %s data is %d' % (split, len(self.datapath)))
79 |
80 | if self.uniform:
81 | self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
82 | else:
83 | self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))
84 |
85 | if self.process_data:
86 | if not os.path.exists(self.save_path):
87 |                 print('Processing data %s (only runs the first time)...' % self.save_path)
88 | self.list_of_points = [None] * len(self.datapath)
89 | self.list_of_labels = [None] * len(self.datapath)
90 |
91 | for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
92 | fn = self.datapath[index]
93 | cls = self.classes[self.datapath[index][0]]
94 | cls = np.array([cls]).astype(np.int32)
95 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
96 |
97 | if self.uniform:
98 | point_set = farthest_point_sample(point_set, self.npoints)
99 | else:
100 | point_set = point_set[0:self.npoints, :]
101 |
102 | self.list_of_points[index] = point_set
103 | self.list_of_labels[index] = cls
104 |
105 | with open(self.save_path, 'wb') as f:
106 | pickle.dump([self.list_of_points, self.list_of_labels], f)
107 | else:
108 | print('Load processed data from %s...' % self.save_path)
109 | with open(self.save_path, 'rb') as f:
110 | self.list_of_points, self.list_of_labels = pickle.load(f)
111 |
112 | def __len__(self):
113 | return len(self.datapath)
114 |
115 | def _get_item(self, index):
116 | if self.process_data:
117 | point_set, label = self.list_of_points[index], self.list_of_labels[index]
118 | else:
119 | fn = self.datapath[index]
120 | cls = self.classes[self.datapath[index][0]]
121 | label = np.array([cls]).astype(np.int32)
122 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
123 |
124 | if self.uniform:
125 | point_set = farthest_point_sample(point_set, self.npoints)
126 | else:
127 | point_set = point_set[0:self.npoints, :]
128 |
129 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
130 | if not self.use_normals:
131 | point_set = point_set[:, 0:3]
132 |
133 | return point_set, label[0]
134 |
135 | def __getitem__(self, index):
136 | return self._get_item(index)
137 |
138 |
139 | if __name__ == '__main__':
140 |     import torch
141 |     import argparse
142 |
143 |     # The loader reads its hyperparameters from an args object.
144 |     args = argparse.Namespace(num_point=1024, use_uniform_sample=False, use_normals=True, num_category=40)
145 |     data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', args, split='train')
146 |
147 |     DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True, num_workers=10)
148 |     for point, label in DataLoader:
149 |         print(point.shape)
150 |         print(label.shape)
151 |
--------------------------------------------------------------------------------
/data_utils/S3DISDataLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | from tqdm import tqdm
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class S3DISDataset(Dataset):
9 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None):
10 | super().__init__()
11 | self.num_point = num_point
12 | self.block_size = block_size
13 | self.transform = transform
14 | rooms = sorted(os.listdir(data_root))
15 | rooms = [room for room in rooms if 'Area_' in room]
16 | if split == 'train':
17 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room]
18 | else:
19 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room]
20 |
21 | self.room_points, self.room_labels = [], []
22 | self.room_coord_min, self.room_coord_max = [], []
23 | num_point_all = []
24 | labelweights = np.zeros(13)
25 |
26 | for room_name in tqdm(rooms_split, total=len(rooms_split)):
27 | room_path = os.path.join(data_root, room_name)
28 | room_data = np.load(room_path) # xyzrgbl, N*7
29 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N
30 | tmp, _ = np.histogram(labels, range(14))
31 | labelweights += tmp
32 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
33 | self.room_points.append(points), self.room_labels.append(labels)
34 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
35 | num_point_all.append(labels.size)
36 | labelweights = labelweights.astype(np.float32)
37 | labelweights = labelweights / np.sum(labelweights)
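        # Inverse-frequency weights: rarer classes get larger values; the cube root tempers extreme ratios.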
38 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
39 | print(self.labelweights)
40 | sample_prob = num_point_all / np.sum(num_point_all)
41 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
42 | room_idxs = []
43 | for index in range(len(rooms_split)):
44 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
45 | self.room_idxs = np.array(room_idxs)
46 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split))
47 |
48 | def __getitem__(self, idx):
49 | room_idx = self.room_idxs[idx]
50 | points = self.room_points[room_idx] # N * 6
51 | labels = self.room_labels[room_idx] # N
52 | N_points = points.shape[0]
53 |
54 | while (True):
55 | center = points[np.random.choice(N_points)][:3]
56 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
57 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
58 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0]
59 | if point_idxs.size > 1024:
60 | break
61 |
62 | if point_idxs.size >= self.num_point:
63 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False)
64 | else:
65 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True)
66 |
67 | # normalize
68 | selected_points = points[selected_point_idxs, :] # num_point * 6
69 | current_points = np.zeros((self.num_point, 9)) # num_point * 9
70 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0]
71 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1]
72 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2]
73 | selected_points[:, 0] = selected_points[:, 0] - center[0]
74 | selected_points[:, 1] = selected_points[:, 1] - center[1]
75 | selected_points[:, 3:6] /= 255.0
76 | current_points[:, 0:6] = selected_points
77 | current_labels = labels[selected_point_idxs]
78 | if self.transform is not None:
79 | current_points, current_labels = self.transform(current_points, current_labels)
80 | return current_points, current_labels
81 |
82 | def __len__(self):
83 | return len(self.room_idxs)
84 |
85 | class ScannetDatasetWholeScene():
86 |     # prepare to give predictions on each point
87 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001):
88 | self.block_points = block_points
89 | self.block_size = block_size
90 | self.padding = padding
91 | self.root = root
92 | self.split = split
93 | self.stride = stride
94 | self.scene_points_num = []
95 | assert split in ['train', 'test']
96 |         if self.split == 'train':
97 |             self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) == -1]
98 |         else:
99 |             self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) != -1]
100 | self.scene_points_list = []
101 | self.semantic_labels_list = []
102 | self.room_coord_min, self.room_coord_max = [], []
103 | for file in self.file_list:
104 | data = np.load(root + file)
105 | points = data[:, :3]
106 | self.scene_points_list.append(data[:, :6])
107 | self.semantic_labels_list.append(data[:, 6])
108 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
109 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
110 | assert len(self.scene_points_list) == len(self.semantic_labels_list)
111 |
112 | labelweights = np.zeros(13)
113 | for seg in self.semantic_labels_list:
114 | tmp, _ = np.histogram(seg, range(14))
115 | self.scene_points_num.append(seg.shape[0])
116 | labelweights += tmp
117 | labelweights = labelweights.astype(np.float32)
118 | labelweights = labelweights / np.sum(labelweights)
119 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
120 |
121 | def __getitem__(self, index):
122 | point_set_ini = self.scene_points_list[index]
123 | points = point_set_ini[:,:6]
124 | labels = self.semantic_labels_list[index]
125 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
126 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1)
127 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1)
128 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([])
129 | for index_y in range(0, grid_y):
130 | for index_x in range(0, grid_x):
131 | s_x = coord_min[0] + index_x * self.stride
132 | e_x = min(s_x + self.block_size, coord_max[0])
133 | s_x = e_x - self.block_size
134 | s_y = coord_min[1] + index_y * self.stride
135 | e_y = min(s_y + self.block_size, coord_max[1])
136 | s_y = e_y - self.block_size
137 | point_idxs = np.where(
138 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & (
139 | points[:, 1] <= e_y + self.padding))[0]
140 | if point_idxs.size == 0:
141 | continue
142 | num_batch = int(np.ceil(point_idxs.size / self.block_points))
143 | point_size = int(num_batch * self.block_points)
144 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True
145 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace)
146 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat))
147 | np.random.shuffle(point_idxs)
148 | data_batch = points[point_idxs, :]
149 |                 normalized_xyz = np.zeros((point_size, 3))
150 |                 normalized_xyz[:, 0] = data_batch[:, 0] / coord_max[0]
151 |                 normalized_xyz[:, 1] = data_batch[:, 1] / coord_max[1]
152 |                 normalized_xyz[:, 2] = data_batch[:, 2] / coord_max[2]
153 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0)
154 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0)
155 | data_batch[:, 3:6] /= 255.0
156 |                 data_batch = np.concatenate((data_batch, normalized_xyz), axis=1)
157 | label_batch = labels[point_idxs].astype(int)
158 | batch_weight = self.labelweights[label_batch]
159 |
160 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch
161 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch
162 |                 sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight
163 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs
164 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1]))
165 | label_room = label_room.reshape((-1, self.block_points))
166 | sample_weight = sample_weight.reshape((-1, self.block_points))
167 | index_room = index_room.reshape((-1, self.block_points))
168 | return data_room, label_room, sample_weight, index_room
169 |
170 | def __len__(self):
171 | return len(self.scene_points_list)
172 |
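# Example (sketch): sweep one whole room in fixed-stride blocks for evaluation.
#   scene_set = ScannetDatasetWholeScene('data/stanford_indoor3d/', split='test', test_area=5)
#   data, labels, weights, idxs = scene_set[0]  # each is (num_blocks, 4096, ...)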
173 | if __name__ == '__main__':
174 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/'
175 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01
176 |
177 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None)
178 | print('point data size:', point_data.__len__())
179 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape)
180 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape)
181 | import torch, time, random
182 | manual_seed = 123
183 | random.seed(manual_seed)
184 | np.random.seed(manual_seed)
185 | torch.manual_seed(manual_seed)
186 | torch.cuda.manual_seed_all(manual_seed)
187 | def worker_init_fn(worker_id):
188 | random.seed(manual_seed + worker_id)
189 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn)
190 | for idx in range(4):
191 | end = time.time()
192 | for i, (input, target) in enumerate(train_loader):
193 | print('time: {}/{}--{}'.format(i+1, len(train_loader), time.time() - end))
194 | end = time.time()
--------------------------------------------------------------------------------
/data_utils/ShapeNetDataLoader.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import json
4 | import warnings
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | warnings.filterwarnings('ignore')
8 |
9 | def pc_normalize(pc):
10 | centroid = np.mean(pc, axis=0)
11 | pc = pc - centroid
12 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
13 | pc = pc / m
14 | return pc
15 |
16 | class PartNormalDataset(Dataset):
17 | def __init__(self,root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):
18 | self.npoints = npoints
19 | self.root = root
20 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
21 | self.cat = {}
22 | self.normal_channel = normal_channel
23 |
24 |
25 | with open(self.catfile, 'r') as f:
26 | for line in f:
27 | ls = line.strip().split()
28 | self.cat[ls[0]] = ls[1]
29 | self.cat = {k: v for k, v in self.cat.items()}
30 | self.classes_original = dict(zip(self.cat, range(len(self.cat))))
31 |
32 |         if class_choice is not None:
33 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
34 | # print(self.cat)
35 |
36 | self.meta = {}
37 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
38 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
39 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
40 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
41 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
42 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
43 | for item in self.cat:
44 | # print('category', item)
45 | self.meta[item] = []
46 | dir_point = os.path.join(self.root, self.cat[item])
47 | fns = sorted(os.listdir(dir_point))
48 | # print(fns[0][0:-4])
49 | if split == 'trainval':
50 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
51 | elif split == 'train':
52 | fns = [fn for fn in fns if fn[0:-4] in train_ids]
53 | elif split == 'val':
54 | fns = [fn for fn in fns if fn[0:-4] in val_ids]
55 | elif split == 'test':
56 | fns = [fn for fn in fns if fn[0:-4] in test_ids]
57 | else:
58 | print('Unknown split: %s. Exiting..' % (split))
59 | exit(-1)
60 |
61 | # print(os.path.basename(fns))
62 | for fn in fns:
63 | token = (os.path.splitext(os.path.basename(fn))[0])
64 | self.meta[item].append(os.path.join(dir_point, token + '.txt'))
65 |
66 | self.datapath = []
67 | for item in self.cat:
68 | for fn in self.meta[item]:
69 | self.datapath.append((item, fn))
70 |
71 | self.classes = {}
72 | for i in self.cat.keys():
73 | self.classes[i] = self.classes_original[i]
74 |
75 | # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
76 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
77 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
78 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
79 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
80 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
81 |
82 | # for cat in sorted(self.seg_classes.keys()):
83 | # print(cat, self.seg_classes[cat])
84 |
85 | self.cache = {} # from index to (point_set, cls, seg) tuple
86 | self.cache_size = 20000
87 |
88 |
89 | def __getitem__(self, index):
90 | if index in self.cache:
91 | point_set, cls, seg = self.cache[index]
92 | else:
93 | fn = self.datapath[index]
94 | cat = self.datapath[index][0]
95 | cls = self.classes[cat]
96 | cls = np.array([cls]).astype(np.int32)
97 | data = np.loadtxt(fn[1]).astype(np.float32)
98 | if not self.normal_channel:
99 | point_set = data[:, 0:3]
100 | else:
101 | point_set = data[:, 0:6]
102 | seg = data[:, -1].astype(np.int32)
103 | if len(self.cache) < self.cache_size:
104 | self.cache[index] = (point_set, cls, seg)
105 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
106 |
107 | choice = np.random.choice(len(seg), self.npoints, replace=True)
108 | # resample
109 | point_set = point_set[choice, :]
110 | seg = seg[choice]
111 |
112 | return point_set, cls, seg
113 |
114 | def __len__(self):
115 | return len(self.datapath)
116 |
117 |
118 |
119 |
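# Example (sketch): load the part-segmentation 'trainval' split with normals.
#   ds = PartNormalDataset(npoints=2048, split='trainval', normal_channel=True)
#   points, cls, seg = ds[0]  # points: (2048, 6), cls: (1,), seg: (2048,)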
--------------------------------------------------------------------------------
/data_utils/__pycache__/ModelNetDataLoader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/data_utils/__pycache__/ModelNetDataLoader.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/__pycache__/ShapeNetDataLoader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/data_utils/__pycache__/ShapeNetDataLoader.cpython-37.pyc
--------------------------------------------------------------------------------
/data_utils/collect_indoor3d_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
4 | ROOT_DIR = os.path.dirname(BASE_DIR)
5 | sys.path.append(BASE_DIR)
6 |
7 | from indoor3d_util import DATA_PATH, collect_point_label
8 |
9 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))]
10 | anno_paths = [os.path.join(DATA_PATH, p) for p in anno_paths]
11 |
12 | output_folder = os.path.join(ROOT_DIR, 'data/stanford_indoor3d')
13 | if not os.path.exists(output_folder):
14 | os.mkdir(output_folder)
15 |
16 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually.
17 | for anno_path in anno_paths:
18 | print(anno_path)
19 | try:
20 | elements = anno_path.split('/')
21 | out_filename = elements[-3]+'_'+elements[-2]+'.npy' # Area_1_hallway_1.npy
22 | collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy')
23 |     except Exception as e:
24 |         print(anno_path, 'ERROR!!', e)
25 |
--------------------------------------------------------------------------------
/data_utils/indoor3d_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 | import sys
5 |
6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7 | ROOT_DIR = os.path.dirname(BASE_DIR)
8 | sys.path.append(BASE_DIR)
9 |
10 | DATA_PATH = os.path.join(ROOT_DIR, 'data','s3dis', 'Stanford3dDataset_v1.2_Aligned_Version')
11 | g_classes = [x.rstrip() for x in open(os.path.join(BASE_DIR, 'meta/class_names.txt'))]
12 | g_class2label = {cls: i for i,cls in enumerate(g_classes)}
13 | g_class2color = {'ceiling': [0,255,0],
14 | 'floor': [0,0,255],
15 | 'wall': [0,255,255],
16 | 'beam': [255,255,0],
17 | 'column': [255,0,255],
18 | 'window': [100,100,255],
19 | 'door': [200,200,100],
20 | 'table': [170,120,200],
21 | 'chair': [255,0,0],
22 | 'sofa': [200,100,100],
23 | 'bookcase': [10,200,100],
24 | 'board': [200,200,200],
25 | 'clutter': [50,50,50]}
26 | g_easy_view_labels = [7,8,9,10,11,1]
27 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes}
28 |
29 |
30 | # -----------------------------------------------------------------------------
31 | # CONVERT ORIGINAL DATA TO OUR DATA_LABEL FILES
32 | # -----------------------------------------------------------------------------
33 |
34 | def collect_point_label(anno_path, out_filename, file_format='txt'):
35 | """ Convert original dataset files to data_label file (each line is XYZRGBL).
36 | We aggregated all the points from each instance in the room.
37 |
38 | Args:
39 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
40 | out_filename: path to save collected points and labels (each line is XYZRGBL)
41 | file_format: txt or numpy, determines what file format to save.
42 | Returns:
43 | None
44 |     Note:
45 |         the points are shifted before saving; the most negative point is now at the origin.
46 | """
47 | points_list = []
48 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
49 | cls = os.path.basename(f).split('_')[0]
50 | print(f)
51 | if cls not in g_classes: # note: in some room there is 'staris' class..
52 | cls = 'clutter'
53 |
54 | points = np.loadtxt(f)
55 | labels = np.ones((points.shape[0],1)) * g_class2label[cls]
56 | points_list.append(np.concatenate([points, labels], 1)) # Nx7
57 |
58 | data_label = np.concatenate(points_list, 0)
59 | xyz_min = np.amin(data_label, axis=0)[0:3]
60 | data_label[:, 0:3] -= xyz_min
61 |
62 | if file_format=='txt':
63 | fout = open(out_filename, 'w')
64 | for i in range(data_label.shape[0]):
65 | fout.write('%f %f %f %d %d %d %d\n' % \
66 | (data_label[i,0], data_label[i,1], data_label[i,2],
67 | data_label[i,3], data_label[i,4], data_label[i,5],
68 | data_label[i,6]))
69 | fout.close()
70 | elif file_format=='numpy':
71 | np.save(out_filename, data_label)
72 | else:
73 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \
74 | (file_format))
75 | exit()
76 |
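# Example (sketch): aggregate one room's annotation files into a single N x 7 .npy,
# mirroring the per-room call in collect_indoor3d_data.py:
#   collect_point_label(os.path.join(DATA_PATH, 'Area_1/office_2/Annotations'),
#                       'Area_1_office_2.npy', 'numpy')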
77 | def data_to_obj(data,name='example.obj',no_wall=True):
78 | fout = open(name, 'w')
79 | label = data[:, -1].astype(int)
80 | for i in range(data.shape[0]):
81 | if no_wall and ((label[i] == 2) or (label[i]==0)):
82 | continue
83 | fout.write('v %f %f %f %d %d %d\n' % \
84 | (data[i, 0], data[i, 1], data[i, 2], data[i, 3], data[i, 4], data[i, 5]))
85 | fout.close()
86 |
87 | def point_label_to_obj(input_filename, out_filename, label_color=True, easy_view=False, no_wall=False):
88 | """ For visualization of a room from data_label file,
89 | input_filename: each line is X Y Z R G B L
90 | out_filename: OBJ filename,
91 | visualize input file by coloring point with label color
92 | easy_view: only visualize furnitures and floor
93 | """
94 | data_label = np.loadtxt(input_filename)
95 | data = data_label[:, 0:6]
96 | label = data_label[:, -1].astype(int)
97 | fout = open(out_filename, 'w')
98 | for i in range(data.shape[0]):
99 | color = g_label2color[label[i]]
100 | if easy_view and (label[i] not in g_easy_view_labels):
101 | continue
102 | if no_wall and ((label[i] == 2) or (label[i]==0)):
103 | continue
104 | if label_color:
105 | fout.write('v %f %f %f %d %d %d\n' % \
106 | (data[i,0], data[i,1], data[i,2], color[0], color[1], color[2]))
107 | else:
108 | fout.write('v %f %f %f %d %d %d\n' % \
109 | (data[i,0], data[i,1], data[i,2], data[i,3], data[i,4], data[i,5]))
110 | fout.close()
111 |
112 |
113 |
114 | # -----------------------------------------------------------------------------
115 | # PREPARE BLOCK DATA FOR DEEPNETS TRAINING/TESTING
116 | # -----------------------------------------------------------------------------
117 |
118 | def sample_data(data, num_sample):
119 | """ data is in N x ...
120 |     we want to keep num_sample of them.
121 | if N > num_sample, we will randomly keep num_sample of them.
122 | if N < num_sample, we will randomly duplicate samples.
123 | """
124 | N = data.shape[0]
125 | if (N == num_sample):
126 | return data, range(N)
127 | elif (N > num_sample):
128 | sample = np.random.choice(N, num_sample)
129 | return data[sample, ...], sample
130 | else:
131 | sample = np.random.choice(N, num_sample-N)
132 | dup_data = data[sample, ...]
133 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample)
134 |
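# Example (sketch): fix a block to exactly num_sample points.
#   block, idxs = sample_data(np.random.rand(350, 6), 4096)
#   block.shape == (4096, 6)  # the 350 originals plus 3746 random duplicates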
135 | def sample_data_label(data, label, num_sample):
136 | new_data, sample_indices = sample_data(data, num_sample)
137 | new_label = label[sample_indices]
138 | return new_data, new_label
139 |
140 | def room2blocks(data, label, num_point, block_size=1.0, stride=1.0,
141 | random_sample=False, sample_num=None, sample_aug=1):
142 | """ Prepare block training data.
143 | Args:
144 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1]
145 | assumes the data is shifted (min point is origin) and aligned
146 | (aligned with XYZ axis)
147 | label: N size uint8 numpy array from 0-12
148 | num_point: int, how many points to sample in each block
149 | block_size: float, physical size of the block in meters
150 | stride: float, stride for block sweeping
151 | random_sample: bool, if True, we will randomly sample blocks in the room
152 | sample_num: int, if random sample, how many blocks to sample
153 | [default: room area]
154 | sample_aug: if random sample, how much aug
155 | Returns:
156 | block_datas: K x num_point x 6 np array of XYZRGB, RGB is in [0,1]
157 | block_labels: K x num_point x 1 np array of uint8 labels
158 |
159 |     TODO: for this version, blocking is in a fixed, non-overlapping pattern.
160 | """
161 | assert(stride<=block_size)
162 |
163 | limit = np.amax(data, 0)[0:3]
164 |
165 | # Get the corner location for our sampling blocks
166 | xbeg_list = []
167 | ybeg_list = []
168 | if not random_sample:
169 | num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1
170 |         num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1
171 | for i in range(num_block_x):
172 | for j in range(num_block_y):
173 | xbeg_list.append(i*stride)
174 | ybeg_list.append(j*stride)
175 | else:
176 | num_block_x = int(np.ceil(limit[0] / block_size))
177 | num_block_y = int(np.ceil(limit[1] / block_size))
178 | if sample_num is None:
179 | sample_num = num_block_x * num_block_y * sample_aug
180 | for _ in range(sample_num):
181 | xbeg = np.random.uniform(-block_size, limit[0])
182 | ybeg = np.random.uniform(-block_size, limit[1])
183 | xbeg_list.append(xbeg)
184 | ybeg_list.append(ybeg)
185 |
186 | # Collect blocks
187 | block_data_list = []
188 | block_label_list = []
189 | idx = 0
190 | for idx in range(len(xbeg_list)):
191 | xbeg = xbeg_list[idx]
192 | ybeg = ybeg_list[idx]
193 | xcond = (data[:,0]<=xbeg+block_size) & (data[:,0]>=xbeg)
194 | ycond = (data[:,1]<=ybeg+block_size) & (data[:,1]>=ybeg)
195 | cond = xcond & ycond
196 | if np.sum(cond) < 100: # discard block if there are less than 100 pts.
197 | continue
198 |
199 | block_data = data[cond, :]
200 | block_label = label[cond]
201 |
202 | # randomly subsample data
203 | block_data_sampled, block_label_sampled = \
204 | sample_data_label(block_data, block_label, num_point)
205 | block_data_list.append(np.expand_dims(block_data_sampled, 0))
206 | block_label_list.append(np.expand_dims(block_label_sampled, 0))
207 |
208 | return np.concatenate(block_data_list, 0), \
209 | np.concatenate(block_label_list, 0)
210 |
211 |
212 | def room2blocks_plus(data_label, num_point, block_size, stride,
213 | random_sample, sample_num, sample_aug):
214 | """ room2block with input filename and RGB preprocessing.
215 | """
216 | data = data_label[:,0:6]
217 | data[:,3:6] /= 255.0
218 | label = data_label[:,-1].astype(np.uint8)
219 |
220 | return room2blocks(data, label, num_point, block_size, stride,
221 | random_sample, sample_num, sample_aug)
222 |
223 | def room2blocks_wrapper(data_label_filename, num_point, block_size=1.0, stride=1.0,
224 | random_sample=False, sample_num=None, sample_aug=1):
225 | if data_label_filename[-3:] == 'txt':
226 | data_label = np.loadtxt(data_label_filename)
227 | elif data_label_filename[-3:] == 'npy':
228 | data_label = np.load(data_label_filename)
229 | else:
230 | print('Unknown file type! exiting.')
231 | exit()
232 | return room2blocks_plus(data_label, num_point, block_size, stride,
233 | random_sample, sample_num, sample_aug)
234 |
235 | def room2blocks_plus_normalized(data_label, num_point, block_size, stride,
236 | random_sample, sample_num, sample_aug):
237 | """ room2block, with input filename and RGB preprocessing.
238 | for each block centralize XYZ, add normalized XYZ as 678 channels
239 | """
240 | data = data_label[:,0:6]
241 | data[:,3:6] /= 255.0
242 | label = data_label[:,-1].astype(np.uint8)
243 | max_room_x = max(data[:,0])
244 | max_room_y = max(data[:,1])
245 | max_room_z = max(data[:,2])
246 |
247 | data_batch, label_batch = room2blocks(data, label, num_point, block_size, stride,
248 | random_sample, sample_num, sample_aug)
249 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9))
250 | for b in range(data_batch.shape[0]):
251 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x
252 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y
253 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z
254 | minx = min(data_batch[b, :, 0])
255 | miny = min(data_batch[b, :, 1])
256 | data_batch[b, :, 0] -= (minx+block_size/2)
257 | data_batch[b, :, 1] -= (miny+block_size/2)
258 | new_data_batch[:, :, 0:6] = data_batch
259 | return new_data_batch, label_batch
260 |
261 |
262 | def room2blocks_wrapper_normalized(data_label_filename, num_point, block_size=1.0, stride=1.0,
263 | random_sample=False, sample_num=None, sample_aug=1):
264 | if data_label_filename[-3:] == 'txt':
265 | data_label = np.loadtxt(data_label_filename)
266 | elif data_label_filename[-3:] == 'npy':
267 | data_label = np.load(data_label_filename)
268 | else:
269 | print('Unknown file type! exiting.')
270 | exit()
271 | return room2blocks_plus_normalized(data_label, num_point, block_size, stride,
272 | random_sample, sample_num, sample_aug)
273 |
274 | def room2samples(data, label, sample_num_point):
275 | """ Prepare whole room samples.
276 |
277 | Args:
278 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1]
279 | assumes the data is shifted (min point is origin) and
280 | aligned (aligned with XYZ axis)
281 | label: N size uint8 numpy array from 0-12
282 | sample_num_point: int, how many points to sample in each sample
283 | Returns:
284 | sample_datas: K x sample_num_point x 9
285 | numpy array of XYZRGBX'Y'Z', RGB is in [0,1]
286 | sample_labels: K x sample_num_point x 1 np array of uint8 labels
287 | """
288 | N = data.shape[0]
289 | order = np.arange(N)
290 | np.random.shuffle(order)
291 | data = data[order, :]
292 | label = label[order]
293 |
294 | batch_num = int(np.ceil(N / float(sample_num_point)))
295 | sample_datas = np.zeros((batch_num, sample_num_point, 6))
296 | sample_labels = np.zeros((batch_num, sample_num_point, 1))
297 |
298 | for i in range(batch_num):
299 | beg_idx = i*sample_num_point
300 | end_idx = min((i+1)*sample_num_point, N)
301 | num = end_idx - beg_idx
302 | sample_datas[i,0:num,:] = data[beg_idx:end_idx, :]
303 | sample_labels[i,0:num,0] = label[beg_idx:end_idx]
304 | if num < sample_num_point:
305 | makeup_indices = np.random.choice(N, sample_num_point - num)
306 | sample_datas[i,num:,:] = data[makeup_indices, :]
307 | sample_labels[i,num:,0] = label[makeup_indices]
308 | return sample_datas, sample_labels
309 |
310 | def room2samples_plus_normalized(data_label, num_point):
311 | """ room2sample, with input filename and RGB preprocessing.
312 | for each block centralize XYZ, add normalized XYZ as 678 channels
313 | """
314 | data = data_label[:,0:6]
315 | data[:,3:6] /= 255.0
316 | label = data_label[:,-1].astype(np.uint8)
317 | max_room_x = max(data[:,0])
318 | max_room_y = max(data[:,1])
319 | max_room_z = max(data[:,2])
320 | #print(max_room_x, max_room_y, max_room_z)
321 |
322 | data_batch, label_batch = room2samples(data, label, num_point)
323 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9))
324 | for b in range(data_batch.shape[0]):
325 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x
326 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y
327 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z
328 | #minx = min(data_batch[b, :, 0])
329 | #miny = min(data_batch[b, :, 1])
330 | #data_batch[b, :, 0] -= (minx+block_size/2)
331 | #data_batch[b, :, 1] -= (miny+block_size/2)
332 | new_data_batch[:, :, 0:6] = data_batch
333 | return new_data_batch, label_batch
334 |
335 |
336 | def room2samples_wrapper_normalized(data_label_filename, num_point):
337 | if data_label_filename[-3:] == 'txt':
338 | data_label = np.loadtxt(data_label_filename)
339 | elif data_label_filename[-3:] == 'npy':
340 | data_label = np.load(data_label_filename)
341 | else:
342 | print('Unknown file type! exiting.')
343 | exit()
344 | return room2samples_plus_normalized(data_label, num_point)
345 |
346 |
347 | # -----------------------------------------------------------------------------
348 | # EXTRACT INSTANCE BBOX FROM ORIGINAL DATA (for detection evaluation)
349 | # -----------------------------------------------------------------------------
350 |
351 | def collect_bounding_box(anno_path, out_filename):
352 | """ Compute bounding boxes from each instance in original dataset files on
353 | one room. **We assume the bbox is aligned with XYZ coordinate.**
354 |
355 | Args:
356 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
357 | out_filename: path to save instance bounding boxes for that room.
358 | each line is x1 y1 z1 x2 y2 z2 label,
359 | where (x1,y1,z1) is the point on the diagonal closer to origin
360 | Returns:
361 | None
362 | Note:
363 | room points are shifted, the most negative point is now at origin.
364 | """
365 | bbox_label_list = []
366 |
367 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
368 | cls = os.path.basename(f).split('_')[0]
369 | if cls not in g_classes: # note: in some room there is 'staris' class..
370 | cls = 'clutter'
371 | points = np.loadtxt(f)
372 | label = g_class2label[cls]
373 | # Compute tightest axis aligned bounding box
374 | xyz_min = np.amin(points[:, 0:3], axis=0)
375 | xyz_max = np.amax(points[:, 0:3], axis=0)
376 | ins_bbox_label = np.expand_dims(
377 | np.concatenate([xyz_min, xyz_max, np.array([label])], 0), 0)
378 | bbox_label_list.append(ins_bbox_label)
379 |
380 | bbox_label = np.concatenate(bbox_label_list, 0)
381 | room_xyz_min = np.amin(bbox_label[:, 0:3], axis=0)
382 | bbox_label[:, 0:3] -= room_xyz_min
383 | bbox_label[:, 3:6] -= room_xyz_min
384 |
385 | fout = open(out_filename, 'w')
386 | for i in range(bbox_label.shape[0]):
387 | fout.write('%f %f %f %f %f %f %d\n' % \
388 | (bbox_label[i,0], bbox_label[i,1], bbox_label[i,2],
389 | bbox_label[i,3], bbox_label[i,4], bbox_label[i,5],
390 | bbox_label[i,6]))
391 | fout.close()
392 |
393 | def bbox_label_to_obj(input_filename, out_filename_prefix, easy_view=False):
394 | """ Visualization of bounding boxes.
395 |
396 | Args:
397 | input_filename: each line is x1 y1 z1 x2 y2 z2 label
398 | out_filename_prefix: OBJ filename prefix,
399 | visualize object by g_label2color
400 | easy_view: if True, only visualize furniture and floor
401 | Returns:
402 | output a list of OBJ file and MTL files with the same prefix
403 | """
404 | bbox_label = np.loadtxt(input_filename)
405 | bbox = bbox_label[:, 0:6]
406 | label = bbox_label[:, -1].astype(int)
407 | v_cnt = 0 # count vertex
408 | ins_cnt = 0 # count instance
409 | for i in range(bbox.shape[0]):
410 | if easy_view and (label[i] not in g_easy_view_labels):
411 | continue
412 | obj_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.obj'
413 | mtl_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.mtl'
414 | fout_obj = open(obj_filename, 'w')
415 | fout_mtl = open(mtl_filename, 'w')
416 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename)))
417 |
418 | length = bbox[i, 3:6] - bbox[i, 0:3]
419 | a = length[0]
420 | b = length[1]
421 | c = length[2]
422 | x = bbox[i, 0]
423 | y = bbox[i, 1]
424 | z = bbox[i, 2]
425 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0
426 |
427 | material = 'material%d' % (ins_cnt)
428 | fout_obj.write('usemtl %s\n' % (material))
429 | fout_obj.write('v %f %f %f\n' % (x,y,z+c))
430 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c))
431 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c))
432 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c))
433 | fout_obj.write('v %f %f %f\n' % (x,y,z))
434 | fout_obj.write('v %f %f %f\n' % (x,y+b,z))
435 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z))
436 | fout_obj.write('v %f %f %f\n' % (x+a,y,z))
437 | fout_obj.write('g default\n')
438 | v_cnt = 0 # for individual box
439 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt))
440 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt))
441 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt))
442 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt))
443 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt))
444 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt))
445 | fout_obj.write('\n')
446 |
447 | fout_mtl.write('newmtl %s\n' % (material))
448 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2]))
449 | fout_mtl.write('\n')
450 | fout_obj.close()
451 | fout_mtl.close()
452 |
453 | v_cnt += 8
454 | ins_cnt += 1
455 |
456 | def bbox_label_to_obj_room(input_filename, out_filename_prefix, easy_view=False, permute=None, center=False, exclude_table=False):
457 | """ Visualization of bounding boxes.
458 |
459 | Args:
460 | input_filename: each line is x1 y1 z1 x2 y2 z2 label
461 | out_filename_prefix: OBJ filename prefix,
462 | visualize object by g_label2color
463 | easy_view: if True, only visualize furniture and floor
464 | permute: if not None, permute XYZ for rendering, e.g. [0 2 1]
465 | center: if True, move obj to have zero origin
466 | Returns:
467 | output a list of OBJ file and MTL files with the same prefix
468 | """
469 | bbox_label = np.loadtxt(input_filename)
470 | bbox = bbox_label[:, 0:6]
471 | if permute is not None:
472 | assert(len(permute)==3)
473 | permute = np.array(permute)
474 | bbox[:,0:3] = bbox[:,permute]
475 | bbox[:,3:6] = bbox[:,permute+3]
476 | if center:
477 | xyz_max = np.amax(bbox[:,3:6], 0)
478 | bbox[:,0:3] -= (xyz_max/2.0)
479 | bbox[:,3:6] -= (xyz_max/2.0)
480 | bbox /= np.max(xyz_max/2.0)
481 | label = bbox_label[:, -1].astype(int)
482 | obj_filename = out_filename_prefix+'.obj'
483 | mtl_filename = out_filename_prefix+'.mtl'
484 |
485 | fout_obj = open(obj_filename, 'w')
486 | fout_mtl = open(mtl_filename, 'w')
487 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename)))
488 | v_cnt = 0 # count vertex
489 | ins_cnt = 0 # count instance
490 | for i in range(bbox.shape[0]):
491 | if easy_view and (label[i] not in g_easy_view_labels):
492 | continue
493 | if exclude_table and label[i] == g_classes.index('table'):
494 | continue
495 |
496 | length = bbox[i, 3:6] - bbox[i, 0:3]
497 | a = length[0]
498 | b = length[1]
499 | c = length[2]
500 | x = bbox[i, 0]
501 | y = bbox[i, 1]
502 | z = bbox[i, 2]
503 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0
504 |
505 | material = 'material%d' % (ins_cnt)
506 | fout_obj.write('usemtl %s\n' % (material))
507 | fout_obj.write('v %f %f %f\n' % (x,y,z+c))
508 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c))
509 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c))
510 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c))
511 | fout_obj.write('v %f %f %f\n' % (x,y,z))
512 | fout_obj.write('v %f %f %f\n' % (x,y+b,z))
513 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z))
514 | fout_obj.write('v %f %f %f\n' % (x+a,y,z))
515 | fout_obj.write('g default\n')
516 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt))
517 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt))
518 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt))
519 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt))
520 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt))
521 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt))
522 | fout_obj.write('\n')
523 |
524 | fout_mtl.write('newmtl %s\n' % (material))
525 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2]))
526 | fout_mtl.write('\n')
527 |
528 | v_cnt += 8
529 | ins_cnt += 1
530 |
531 | fout_obj.close()
532 | fout_mtl.close()
533 |
534 |
535 | def collect_point_bounding_box(anno_path, out_filename, file_format):
536 | """ Compute bounding boxes from each instance in original dataset files on
537 | one room. **We assume the bbox is aligned with XYZ coordinate.**
538 | Save both the point XYZRGB and the bounding box for the point's
539 | parent element.
540 |
541 | Args:
542 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
543 | out_filename: path to save instance bounding boxes for each point,
544 | plus the point's XYZRGBL
545 | each line is XYZRGBL offsetX offsetY offsetZ a b c,
546 |         where cx = X+offsetX, cy = Y+offsetY, cz = Z+offsetZ
547 | where (cx,cy,cz) is center of the box, a,b,c are distances from center
548 | to the surfaces of the box, i.e. x1 = cx-a, x2 = cx+a, y1=cy-b etc.
549 | file_format: output file format, txt or numpy
550 | Returns:
551 | None
552 |
553 | Note:
554 | room points are shifted, the most negative point is now at origin.
555 | """
556 | point_bbox_list = []
557 |
558 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
559 | cls = os.path.basename(f).split('_')[0]
560 | if cls not in g_classes: # note: in some room there is 'staris' class..
561 | cls = 'clutter'
562 | points = np.loadtxt(f) # Nx6
563 |         label = g_class2label[cls]  # scalar class id
564 | # Compute tightest axis aligned bounding box
565 | xyz_min = np.amin(points[:, 0:3], axis=0) # 3,
566 | xyz_max = np.amax(points[:, 0:3], axis=0) # 3,
567 | xyz_center = (xyz_min + xyz_max) / 2
568 | dimension = (xyz_max - xyz_min) / 2
569 |
570 | xyz_offsets = xyz_center - points[:,0:3] # Nx3
571 | dimensions = np.ones((points.shape[0],3)) * dimension # Nx3
572 | labels = np.ones((points.shape[0],1)) * label # N
573 | point_bbox_list.append(np.concatenate([points, labels,
574 | xyz_offsets, dimensions], 1)) # Nx13
575 |
576 |     point_bbox = np.concatenate(point_bbox_list, 0)  # (sum of N over instances) x 13
577 | room_xyz_min = np.amin(point_bbox[:, 0:3], axis=0)
578 | point_bbox[:, 0:3] -= room_xyz_min
579 |
580 | if file_format == 'txt':
581 | fout = open(out_filename, 'w')
582 | for i in range(point_bbox.shape[0]):
583 | fout.write('%f %f %f %d %d %d %d %f %f %f %f %f %f\n' % \
584 | (point_bbox[i,0], point_bbox[i,1], point_bbox[i,2],
585 | point_bbox[i,3], point_bbox[i,4], point_bbox[i,5],
586 | point_bbox[i,6],
587 | point_bbox[i,7], point_bbox[i,8], point_bbox[i,9],
588 | point_bbox[i,10], point_bbox[i,11], point_bbox[i,12]))
589 |
590 | fout.close()
591 | elif file_format == 'numpy':
592 | np.save(out_filename, point_bbox)
593 | else:
594 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \
595 | (file_format))
596 | exit()
597 |
598 |
599 |
--------------------------------------------------------------------------------
/data_utils/meta/anno_paths.txt:
--------------------------------------------------------------------------------
1 | Area_1/conferenceRoom_1/Annotations
2 | Area_1/conferenceRoom_2/Annotations
3 | Area_1/copyRoom_1/Annotations
4 | Area_1/hallway_1/Annotations
5 | Area_1/hallway_2/Annotations
6 | Area_1/hallway_3/Annotations
7 | Area_1/hallway_4/Annotations
8 | Area_1/hallway_5/Annotations
9 | Area_1/hallway_6/Annotations
10 | Area_1/hallway_7/Annotations
11 | Area_1/hallway_8/Annotations
12 | Area_1/office_10/Annotations
13 | Area_1/office_11/Annotations
14 | Area_1/office_12/Annotations
15 | Area_1/office_13/Annotations
16 | Area_1/office_14/Annotations
17 | Area_1/office_15/Annotations
18 | Area_1/office_16/Annotations
19 | Area_1/office_17/Annotations
20 | Area_1/office_18/Annotations
21 | Area_1/office_19/Annotations
22 | Area_1/office_1/Annotations
23 | Area_1/office_20/Annotations
24 | Area_1/office_21/Annotations
25 | Area_1/office_22/Annotations
26 | Area_1/office_23/Annotations
27 | Area_1/office_24/Annotations
28 | Area_1/office_25/Annotations
29 | Area_1/office_26/Annotations
30 | Area_1/office_27/Annotations
31 | Area_1/office_28/Annotations
32 | Area_1/office_29/Annotations
33 | Area_1/office_2/Annotations
34 | Area_1/office_30/Annotations
35 | Area_1/office_31/Annotations
36 | Area_1/office_3/Annotations
37 | Area_1/office_4/Annotations
38 | Area_1/office_5/Annotations
39 | Area_1/office_6/Annotations
40 | Area_1/office_7/Annotations
41 | Area_1/office_8/Annotations
42 | Area_1/office_9/Annotations
43 | Area_1/pantry_1/Annotations
44 | Area_1/WC_1/Annotations
45 | Area_2/auditorium_1/Annotations
46 | Area_2/auditorium_2/Annotations
47 | Area_2/conferenceRoom_1/Annotations
48 | Area_2/hallway_10/Annotations
49 | Area_2/hallway_11/Annotations
50 | Area_2/hallway_12/Annotations
51 | Area_2/hallway_1/Annotations
52 | Area_2/hallway_2/Annotations
53 | Area_2/hallway_3/Annotations
54 | Area_2/hallway_4/Annotations
55 | Area_2/hallway_5/Annotations
56 | Area_2/hallway_6/Annotations
57 | Area_2/hallway_7/Annotations
58 | Area_2/hallway_8/Annotations
59 | Area_2/hallway_9/Annotations
60 | Area_2/office_10/Annotations
61 | Area_2/office_11/Annotations
62 | Area_2/office_12/Annotations
63 | Area_2/office_13/Annotations
64 | Area_2/office_14/Annotations
65 | Area_2/office_1/Annotations
66 | Area_2/office_2/Annotations
67 | Area_2/office_3/Annotations
68 | Area_2/office_4/Annotations
69 | Area_2/office_5/Annotations
70 | Area_2/office_6/Annotations
71 | Area_2/office_7/Annotations
72 | Area_2/office_8/Annotations
73 | Area_2/office_9/Annotations
74 | Area_2/storage_1/Annotations
75 | Area_2/storage_2/Annotations
76 | Area_2/storage_3/Annotations
77 | Area_2/storage_4/Annotations
78 | Area_2/storage_5/Annotations
79 | Area_2/storage_6/Annotations
80 | Area_2/storage_7/Annotations
81 | Area_2/storage_8/Annotations
82 | Area_2/storage_9/Annotations
83 | Area_2/WC_1/Annotations
84 | Area_2/WC_2/Annotations
85 | Area_3/conferenceRoom_1/Annotations
86 | Area_3/hallway_1/Annotations
87 | Area_3/hallway_2/Annotations
88 | Area_3/hallway_3/Annotations
89 | Area_3/hallway_4/Annotations
90 | Area_3/hallway_5/Annotations
91 | Area_3/hallway_6/Annotations
92 | Area_3/lounge_1/Annotations
93 | Area_3/lounge_2/Annotations
94 | Area_3/office_10/Annotations
95 | Area_3/office_1/Annotations
96 | Area_3/office_2/Annotations
97 | Area_3/office_3/Annotations
98 | Area_3/office_4/Annotations
99 | Area_3/office_5/Annotations
100 | Area_3/office_6/Annotations
101 | Area_3/office_7/Annotations
102 | Area_3/office_8/Annotations
103 | Area_3/office_9/Annotations
104 | Area_3/storage_1/Annotations
105 | Area_3/storage_2/Annotations
106 | Area_3/WC_1/Annotations
107 | Area_3/WC_2/Annotations
108 | Area_4/conferenceRoom_1/Annotations
109 | Area_4/conferenceRoom_2/Annotations
110 | Area_4/conferenceRoom_3/Annotations
111 | Area_4/hallway_10/Annotations
112 | Area_4/hallway_11/Annotations
113 | Area_4/hallway_12/Annotations
114 | Area_4/hallway_13/Annotations
115 | Area_4/hallway_14/Annotations
116 | Area_4/hallway_1/Annotations
117 | Area_4/hallway_2/Annotations
118 | Area_4/hallway_3/Annotations
119 | Area_4/hallway_4/Annotations
120 | Area_4/hallway_5/Annotations
121 | Area_4/hallway_6/Annotations
122 | Area_4/hallway_7/Annotations
123 | Area_4/hallway_8/Annotations
124 | Area_4/hallway_9/Annotations
125 | Area_4/lobby_1/Annotations
126 | Area_4/lobby_2/Annotations
127 | Area_4/office_10/Annotations
128 | Area_4/office_11/Annotations
129 | Area_4/office_12/Annotations
130 | Area_4/office_13/Annotations
131 | Area_4/office_14/Annotations
132 | Area_4/office_15/Annotations
133 | Area_4/office_16/Annotations
134 | Area_4/office_17/Annotations
135 | Area_4/office_18/Annotations
136 | Area_4/office_19/Annotations
137 | Area_4/office_1/Annotations
138 | Area_4/office_20/Annotations
139 | Area_4/office_21/Annotations
140 | Area_4/office_22/Annotations
141 | Area_4/office_2/Annotations
142 | Area_4/office_3/Annotations
143 | Area_4/office_4/Annotations
144 | Area_4/office_5/Annotations
145 | Area_4/office_6/Annotations
146 | Area_4/office_7/Annotations
147 | Area_4/office_8/Annotations
148 | Area_4/office_9/Annotations
149 | Area_4/storage_1/Annotations
150 | Area_4/storage_2/Annotations
151 | Area_4/storage_3/Annotations
152 | Area_4/storage_4/Annotations
153 | Area_4/WC_1/Annotations
154 | Area_4/WC_2/Annotations
155 | Area_4/WC_3/Annotations
156 | Area_4/WC_4/Annotations
157 | Area_5/conferenceRoom_1/Annotations
158 | Area_5/conferenceRoom_2/Annotations
159 | Area_5/conferenceRoom_3/Annotations
160 | Area_5/hallway_10/Annotations
161 | Area_5/hallway_11/Annotations
162 | Area_5/hallway_12/Annotations
163 | Area_5/hallway_13/Annotations
164 | Area_5/hallway_14/Annotations
165 | Area_5/hallway_15/Annotations
166 | Area_5/hallway_1/Annotations
167 | Area_5/hallway_2/Annotations
168 | Area_5/hallway_3/Annotations
169 | Area_5/hallway_4/Annotations
170 | Area_5/hallway_5/Annotations
171 | Area_5/hallway_6/Annotations
172 | Area_5/hallway_7/Annotations
173 | Area_5/hallway_8/Annotations
174 | Area_5/hallway_9/Annotations
175 | Area_5/lobby_1/Annotations
176 | Area_5/office_10/Annotations
177 | Area_5/office_11/Annotations
178 | Area_5/office_12/Annotations
179 | Area_5/office_13/Annotations
180 | Area_5/office_14/Annotations
181 | Area_5/office_15/Annotations
182 | Area_5/office_16/Annotations
183 | Area_5/office_17/Annotations
184 | Area_5/office_18/Annotations
185 | Area_5/office_19/Annotations
186 | Area_5/office_1/Annotations
187 | Area_5/office_20/Annotations
188 | Area_5/office_21/Annotations
189 | Area_5/office_22/Annotations
190 | Area_5/office_23/Annotations
191 | Area_5/office_24/Annotations
192 | Area_5/office_25/Annotations
193 | Area_5/office_26/Annotations
194 | Area_5/office_27/Annotations
195 | Area_5/office_28/Annotations
196 | Area_5/office_29/Annotations
197 | Area_5/office_2/Annotations
198 | Area_5/office_30/Annotations
199 | Area_5/office_31/Annotations
200 | Area_5/office_32/Annotations
201 | Area_5/office_33/Annotations
202 | Area_5/office_34/Annotations
203 | Area_5/office_35/Annotations
204 | Area_5/office_36/Annotations
205 | Area_5/office_37/Annotations
206 | Area_5/office_38/Annotations
207 | Area_5/office_39/Annotations
208 | Area_5/office_3/Annotations
209 | Area_5/office_40/Annotations
210 | Area_5/office_41/Annotations
211 | Area_5/office_42/Annotations
212 | Area_5/office_4/Annotations
213 | Area_5/office_5/Annotations
214 | Area_5/office_6/Annotations
215 | Area_5/office_7/Annotations
216 | Area_5/office_8/Annotations
217 | Area_5/office_9/Annotations
218 | Area_5/pantry_1/Annotations
219 | Area_5/storage_1/Annotations
220 | Area_5/storage_2/Annotations
221 | Area_5/storage_3/Annotations
222 | Area_5/storage_4/Annotations
223 | Area_5/WC_1/Annotations
224 | Area_5/WC_2/Annotations
225 | Area_6/conferenceRoom_1/Annotations
226 | Area_6/copyRoom_1/Annotations
227 | Area_6/hallway_1/Annotations
228 | Area_6/hallway_2/Annotations
229 | Area_6/hallway_3/Annotations
230 | Area_6/hallway_4/Annotations
231 | Area_6/hallway_5/Annotations
232 | Area_6/hallway_6/Annotations
233 | Area_6/lounge_1/Annotations
234 | Area_6/office_10/Annotations
235 | Area_6/office_11/Annotations
236 | Area_6/office_12/Annotations
237 | Area_6/office_13/Annotations
238 | Area_6/office_14/Annotations
239 | Area_6/office_15/Annotations
240 | Area_6/office_16/Annotations
241 | Area_6/office_17/Annotations
242 | Area_6/office_18/Annotations
243 | Area_6/office_19/Annotations
244 | Area_6/office_1/Annotations
245 | Area_6/office_20/Annotations
246 | Area_6/office_21/Annotations
247 | Area_6/office_22/Annotations
248 | Area_6/office_23/Annotations
249 | Area_6/office_24/Annotations
250 | Area_6/office_25/Annotations
251 | Area_6/office_26/Annotations
252 | Area_6/office_27/Annotations
253 | Area_6/office_28/Annotations
254 | Area_6/office_29/Annotations
255 | Area_6/office_2/Annotations
256 | Area_6/office_30/Annotations
257 | Area_6/office_31/Annotations
258 | Area_6/office_32/Annotations
259 | Area_6/office_33/Annotations
260 | Area_6/office_34/Annotations
261 | Area_6/office_35/Annotations
262 | Area_6/office_36/Annotations
263 | Area_6/office_37/Annotations
264 | Area_6/office_3/Annotations
265 | Area_6/office_4/Annotations
266 | Area_6/office_5/Annotations
267 | Area_6/office_6/Annotations
268 | Area_6/office_7/Annotations
269 | Area_6/office_8/Annotations
270 | Area_6/office_9/Annotations
271 | Area_6/openspace_1/Annotations
272 | Area_6/pantry_1/Annotations
273 |
--------------------------------------------------------------------------------
/data_utils/meta/class_names.txt:
--------------------------------------------------------------------------------
1 | ceiling
2 | floor
3 | wall
4 | beam
5 | column
6 | window
7 | door
8 | table
9 | chair
10 | sofa
11 | bookcase
12 | board
13 | clutter
14 |
--------------------------------------------------------------------------------
/images/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/images/example.jpg
--------------------------------------------------------------------------------
/images/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/images/model.jpg
--------------------------------------------------------------------------------
/models/SCNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
4 | from z_order import *
5 |
6 |
7 | def get_relation_zorder_sample(input_u, input_v, random_sample=True, sample_size=64):
8 | batchsize, in_uchannels, length = input_u.shape
9 | _, in_vchannels, _ = input_v.shape
10 | device = input_u.device
11 | if not random_sample:
12 | sample_size = length
13 |
14 | input_u = input_u.permute(0, 2, 1)
15 | input_v = input_v.permute(0, 2, 1)
16 | ides = z_order_point_sample(input_u[:, :, :3], sample_size)
17 |
18 | batch_indices = torch.arange(batchsize, dtype=torch.long).to(device).view(batchsize, 1).repeat(1, sample_size)
19 |
20 | temp_relationu = input_u[batch_indices, ides, :].permute(0, 2, 1)
21 | temp_relationv = input_v[batch_indices, ides, :].permute(0, 2, 1)
22 | input_u = input_u.permute(0, 2, 1)
23 | input_v = input_v.permute(0, 2, 1)
24 | relation_u = torch.cat([input_u.view(batchsize, -1, length, 1).repeat(1, 1, 1, sample_size),
25 | temp_relationu.view(batchsize, -1, 1, sample_size).repeat(1, 1, length, 1)], dim=1)
26 | relation_v = torch.cat([input_v.view(batchsize, -1, length, 1).repeat(1, 1, 1, sample_size),
27 | temp_relationv.view(batchsize, -1, 1, sample_size).repeat(1, 1, length, 1)], dim=1)
28 |
29 | return relation_u, relation_v, temp_relationu, temp_relationv
30 |
31 |
32 |
33 | class PSCNChannelAttention(nn.Module):
34 | def __init__(self, channel_in):
35 | super(PSCNChannelAttention, self).__init__()
36 | self.avg_pool = nn.AdaptiveAvgPool1d(1)  # global adaptive average pooling
37 | self.max_pool = nn.AdaptiveMaxPool1d(1)
38 |
39 | self.fc1 = nn.Sequential(
40 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1),
41 | nn.ReLU(inplace=True),
42 | nn.Conv1d(channel_in // 2, channel_in, bias=False, kernel_size=1),
43 | )
44 | self.fc2 = nn.Sequential(
45 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1),
46 | nn.ReLU(inplace=True),
47 | nn.Conv1d(channel_in // 2, channel_in, bias=False, kernel_size=1),
48 |
49 | )
50 | self.sigmoid = nn.Sigmoid()
51 |
52 | def forward(self, x):
53 | batch_size, channel_num, _ = x.size()
54 |
55 | avg_out = self.avg_pool(x).view(batch_size, channel_num, 1)  # squeeze: pool over all points
56 | max_out = self.max_pool(x).view(batch_size, channel_num, 1)
57 |
58 | avg_y = self.fc1(avg_out).view(batch_size, channel_num, 1)  # FC layers produce channel attention weights with global context
59 | max_y = self.fc2(max_out).view(batch_size, channel_num, 1)
60 | out = self.sigmoid(avg_y + max_y)
61 | return out  # attention weights, one per channel
62 |
63 | # return x + x * out.expand_as(x)  # apply attention to every channel
64 |
65 |
66 | class PSCNSpatialAttention(nn.Module):
67 | def __init__(self, channel_in, channel_out):
68 | super(PSCNSpatialAttention, self).__init__()
69 | self.avg_pool = nn.AdaptiveAvgPool1d(1)  # global adaptive average pooling
70 | self.max_pool = nn.AdaptiveMaxPool1d(1)
71 |
72 | self.fc1 = nn.Sequential(
73 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1),
74 | nn.ReLU(inplace=True),
75 | nn.Conv1d(channel_in // 2, channel_out, bias=False, kernel_size=1),
76 | )
77 | self.fc2 = nn.Sequential(
78 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1),
79 | nn.ReLU(inplace=True),
80 | nn.Conv1d(channel_in // 2, channel_out, bias=False, kernel_size=1),
81 |
82 | )
83 | self.sigmoid = nn.Sigmoid()
84 |
85 | def forward(self, x):
86 | batch_size, channel_num, point_num = x.size()
87 |
88 | avg_out = self.avg_pool(x).view(batch_size, channel_num, 1)  # squeeze: pool over all points
89 | max_out = self.max_pool(x).view(batch_size, channel_num, 1)
90 |
91 | avg_y = self.fc1(avg_out).view(batch_size, 1, point_num)  # FC layers map pooled channels to per-point spatial weights with global context
92 | max_y = self.fc2(max_out).view(batch_size, 1, point_num)
93 | out = self.sigmoid(avg_y + max_y)
94 | return out  # attention weights, one per spatial position
95 | # return x + x * out.expand_as(x)  # apply attention at every spatial position
96 | # return x + x * out.expand_as(x), out  # apply attention at every spatial position
97 |
98 |
99 | class PointSCN(nn.Module):
100 | def __init__(self, in_uchannels, in_vchannels, random_sample=True, sample_size=64):
101 | super(PointSCN, self).__init__()
102 |
103 | self.random_sample = random_sample
104 | self.sample_size = sample_size
105 |
106 | self.conv_gu = nn.Conv2d(2 * in_uchannels, 2 * in_uchannels, 1)
107 | self.bn1 = nn.BatchNorm2d(2 * in_uchannels)
108 |
109 | self.conv_gv = nn.Conv2d(2 * in_vchannels, 2 * in_vchannels, 1)
110 | self.bn2 = nn.BatchNorm2d(2 * in_vchannels)
111 |
112 | self.conv_uv = nn.Conv2d(2 * in_uchannels + 2 * in_vchannels, in_vchannels, 1)
113 | self.bn3 = nn.BatchNorm2d(in_vchannels)
114 |
115 | self.conv_f = nn.Conv1d(in_vchannels, in_vchannels, 1)
116 | self.bn4 = nn.BatchNorm1d(in_vchannels)
117 |
118 | def forward(self, input_u, input_v):
119 | """
120 | Input:
121 | input_u: input points position data, [B, C, N]
122 | input_v: input points data, [B, D, N]
123 | Return:
124 | relation_uv: correlation-enhanced point features concatenated
125 | with the input positions, [B, D+C, N]
126 | """
127 |
128 | relation_u, relation_v, _, _ = get_relation_zorder_sample(input_u, input_v, random_sample=self.random_sample,
129 | sample_size=self.sample_size)
130 |
131 | relation_uv = torch.cat(
132 | [F.relu(self.bn1(self.conv_gu(relation_u))), F.relu(self.bn2(self.conv_gv(relation_v)))], dim=1)
133 |
134 | relation_uv = F.relu(self.bn3(self.conv_uv(relation_uv)))
135 |
136 | relation_uv = torch.max(relation_uv, 3)[0]
137 |
138 | relation_uv = F.relu(self.bn4(self.conv_f(relation_uv)))
139 | relation_uv = torch.cat([input_v + relation_uv, input_u], dim=1)
140 |
141 | return relation_uv
142 |
143 |
144 | class get_model(nn.Module):
145 | def __init__(self, num_class, normal_channel=True):
146 | super(get_model, self).__init__()
147 | in_channel = 3 if normal_channel else 0
148 | self.normal_channel = normal_channel
149 | self.sa1 = PointNetSetAbstractionMsg(256, [0.1, 0.4], [16, 128], in_channel,
150 | [[32, 32, 64], [64, 96, 128]])
151 |
152 | self.PSCN1 = PointSCN(3, 64 + 128, random_sample=True)
153 | self.attention1 = PSCNChannelAttention(128 + 64 + 3)
154 | self.attention3 = PSCNSpatialAttention(128 + 64 + 3, 256)
155 |
156 | self.sa2 = PointNetSetAbstraction(None, None, None, 128 + 64 + 3 + 3, [256, 512, 1024], True)  # group_all: global feature via max pooling
157 |
158 | self.fc1 = nn.Linear(1024, 512)
159 | self.bn1 = nn.BatchNorm1d(512)
160 | self.drop1 = nn.Dropout(0.4)
161 | self.fc2 = nn.Linear(512, 256)
162 | self.bn2 = nn.BatchNorm1d(256)
163 | self.drop2 = nn.Dropout(0.5)
164 | self.fc3 = nn.Linear(256, num_class)
165 |
166 | def forward(self, xyz):
167 | B, _, _ = xyz.shape
168 | if self.normal_channel:
169 | norm = xyz[:, 3:, :]
170 | xyz = xyz[:, :3, :]
171 | else:
172 | norm = None
173 | l1_xyz, l1_points = self.sa1(xyz, norm)
174 | l1_points = self.PSCN1(l1_xyz, l1_points)
175 | l1_points_att1 = self.attention1(l1_points)
176 | l1_points_att3 = self.attention3(l1_points)
177 | l1_points = l1_points * l1_points_att1.expand_as(l1_points) * l1_points_att3.expand_as(l1_points)
178 | l3_xyz, l3_points = self.sa2(l1_xyz, l1_points)
179 | x = l3_points.view(B, 1024)
180 | x = self.drop1(F.relu(self.bn1(self.fc1(x))))
181 | x = self.drop2(F.relu(self.bn2(self.fc2(x))))
182 | x = self.fc3(x)
183 | x = F.log_softmax(x, -1)
184 | return x, l3_points
185 |
186 |
187 | class get_loss(nn.Module):
188 | def __init__(self):
189 | super(get_loss, self).__init__()
190 |
191 | def forward(self, pred, target, trans_feat):
192 | total_loss = F.nll_loss(pred, target)
193 |
194 | return total_loss
195 |
--------------------------------------------------------------------------------
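
The classifier above takes channels-first input `[B, C, N]`, with C = 6 when normals are used. Below is a minimal forward-pass sketch (not part of the repository; it assumes you run from the repository root so that SCNet's own `utils` and `z_order` imports resolve):

```python
import sys
sys.path.append('./models')  # assumption: run from the repository root

import torch
from SCNet import get_model

# Dummy batch: 2 clouds, xyz + normals (6 channels), 1024 points each.
model = get_model(num_class=40, normal_channel=True).eval()
points = torch.rand(2, 6, 1024)
with torch.no_grad():
    log_probs, global_feat = model(points)
print(log_probs.shape)    # torch.Size([2, 40]) -- log-softmax class scores
print(global_feat.shape)  # pooled global feature from the final SA layer
```
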
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 |
8 | def timeit(tag, t):
9 | print("{}: {}s".format(tag, time() - t))
10 | return time()
11 |
12 |
13 | def pc_normalize(pc):
14 | l = pc.shape[0]
15 | centroid = np.mean(pc, axis=0)
16 | pc = pc - centroid
17 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
18 | pc = pc / m
19 | return pc
20 |
21 |
22 | def square_distance(src, dst):
23 | """
24 | Calculate the squared Euclidean distance between each pair of points.
25 |
26 | src^T * dst = xn * xm + yn * ym + zn * zm;
27 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
28 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
29 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
30 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
31 |
32 | Input:
33 | src: source points, [B, N, C]
34 | dst: target points, [B, M, C]
35 | Output:
36 | dist: per-point square distance, [B, N, M]
37 | """
38 | B, N, _ = src.shape
39 | _, M, _ = dst.shape
40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
41 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
43 | return dist
44 |
45 |
46 | def index_points(points, idx):
47 | """
48 |
49 | Input:
50 | points: input points data, [B, N, C]
51 | idx: sample index data, [B, S]
52 | Return:
53 | new_points: indexed points data, [B, S, C]
54 | """
55 | device = points.device
56 | B = points.shape[0]
57 | view_shape = list(idx.shape)
58 | view_shape[1:] = [1] * (len(view_shape) - 1)
59 | repeat_shape = list(idx.shape)
60 | repeat_shape[0] = 1
61 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
62 | new_points = points[batch_indices, idx, :]
63 | return new_points
64 |
65 |
66 | def farthest_point_sample(xyz, npoint):
67 | """
68 | Input:
69 | xyz: pointcloud data, [B, N, 3]
70 | npoint: number of samples
71 | Return:
72 | centroids: sampled pointcloud index, [B, npoint]
73 | """
74 | device = xyz.device
75 | B, N, C = xyz.shape
76 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
77 | distance = torch.ones(B, N).to(device) * 1e10
78 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
79 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
80 | for i in range(npoint):
81 | centroids[:, i] = farthest
82 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
83 | dist = torch.sum((xyz - centroid) ** 2, -1)
84 | mask = dist < distance
85 | distance[mask] = dist[mask]
86 | farthest = torch.max(distance, -1)[1]
87 | return centroids
88 |
89 |
90 | def query_ball_point(radius, nsample, xyz, new_xyz):
91 | """
92 | Input:
93 | radius: local region radius
94 | nsample: max sample number in local region
95 | xyz: all points, [B, N, 3]
96 | new_xyz: query points, [B, S, 3]
97 | Return:
98 | group_idx: grouped points index, [B, S, nsample]
99 | """
100 | device = xyz.device
101 | B, N, C = xyz.shape
102 | _, S, _ = new_xyz.shape
103 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
104 | sqrdists = square_distance(new_xyz, xyz)
105 | group_idx[sqrdists > radius ** 2] = N
106 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
107 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
108 | mask = group_idx == N
109 | group_idx[mask] = group_first[mask]
110 | return group_idx
111 |
112 |
113 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
114 | """
115 | Input:
116 | npoint:
117 | radius:
118 | nsample:
119 | xyz: input points position data, [B, N, 3]
120 | points: input points data, [B, N, D]
121 | Return:
122 | new_xyz: sampled points position data, [B, npoint, 3]
123 | new_points: sampled points data, [B, npoint, nsample, 3+D]
124 | """
125 | B, N, C = xyz.shape
126 | S = npoint
127 | # fps_idx = z_order.z_order_point_sample(xyz, npoint)  # [B, npoint]
128 | fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
129 | new_xyz = index_points(xyz, fps_idx)
130 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
131 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
132 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
133 |
134 | if points is not None:
135 | grouped_points = index_points(points, idx)
136 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
137 | else:
138 | new_points = grouped_xyz_norm
139 | if returnfps:
140 | return new_xyz, new_points, grouped_xyz, fps_idx
141 | else:
142 | return new_xyz, new_points
143 |
144 |
145 | def sample_and_group_all(xyz, points):
146 | """
147 | Input:
148 | xyz: input points position data, [B, N, 3]
149 | points: input points data, [B, N, D]
150 | Return:
151 | new_xyz: sampled points position data, [B, 1, 3]
152 | new_points: sampled points data, [B, 1, N, 3+D]
153 | """
154 | device = xyz.device
155 | B, N, C = xyz.shape
156 | new_xyz = torch.zeros(B, 1, C).to(device)
157 | grouped_xyz = xyz.view(B, 1, N, C)
158 | if points is not None:
159 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
160 | else:
161 | new_points = grouped_xyz
162 | return new_xyz, new_points
163 |
164 |
165 | class PointNetSetAbstraction(nn.Module):
166 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
167 | super(PointNetSetAbstraction, self).__init__()
168 | self.npoint = npoint
169 | self.radius = radius
170 | self.nsample = nsample
171 | self.mlp_convs = nn.ModuleList()
172 | self.mlp_bns = nn.ModuleList()
173 | last_channel = in_channel
174 | for out_channel in mlp:
175 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
176 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
177 | last_channel = out_channel
178 | self.group_all = group_all
179 |
180 | def forward(self, xyz, points):
181 | """
182 | Input:
183 | xyz: input points position data, [B, C, N]
184 | points: input points data, [B, D, N]
185 | Return:
186 | new_xyz: sampled points position data, [B, C, S]
187 | new_points_concat: sample points feature data, [B, D', S]
188 | """
189 | xyz = xyz.permute(0, 2, 1)
190 | if points is not None:
191 | points = points.permute(0, 2, 1)
192 |
193 | if self.group_all:
194 | new_xyz, new_points = sample_and_group_all(xyz, points)
195 | else:
196 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
197 | # new_xyz: sampled points position data, [B, npoint, C]
198 | # new_points: sampled points data, [B, npoint, nsample, C+D]
199 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
200 | for i, conv in enumerate(self.mlp_convs):
201 | bn = self.mlp_bns[i]
202 | new_points = F.relu(bn(conv(new_points)))
203 |
204 | new_points = torch.max(new_points, 2)[0]
205 | new_xyz = new_xyz.permute(0, 2, 1)
206 | return new_xyz, new_points
207 |
208 | class PointNetSetAbstractionMsg(nn.Module):
209 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
210 | super(PointNetSetAbstractionMsg, self).__init__()
211 | self.npoint = npoint
212 | self.radius_list = radius_list
213 | self.nsample_list = nsample_list
214 | self.conv_blocks = nn.ModuleList()
215 | self.bn_blocks = nn.ModuleList()
216 | for i in range(len(mlp_list)):
217 | convs = nn.ModuleList()
218 | bns = nn.ModuleList()
219 | last_channel = in_channel + 3
220 | for out_channel in mlp_list[i]:
221 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
222 | bns.append(nn.BatchNorm2d(out_channel))
223 | last_channel = out_channel
224 | self.conv_blocks.append(convs)
225 | self.bn_blocks.append(bns)
226 |
227 | def forward(self, xyz, points):
228 | """
229 | Input:
230 | xyz: input points position data, [B, C, N]
231 | points: input points data, [B, D, N]
232 | Return:
233 | new_xyz: sampled points position data, [B, C, S]
234 | new_points_concat: sample points feature data, [B, D', S]
235 | """
236 | xyz = xyz.permute(0, 2, 1)
237 | if points is not None:
238 | points = points.permute(0, 2, 1)
239 |
240 | B, N, C = xyz.shape
241 | S = self.npoint
242 |
243 | # new_xyz = index_points(xyz, z_order.z_order_point_sample(xyz, S))
244 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
245 | new_points_list = []
246 | for i, radius in enumerate(self.radius_list):
247 | K = self.nsample_list[i]
248 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
249 | grouped_xyz = index_points(xyz, group_idx)
250 | grouped_xyz -= new_xyz.view(B, S, 1, C)
251 | if points is not None:
252 | grouped_points = index_points(points, group_idx)
253 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
254 | else:
255 | grouped_points = grouped_xyz
256 |
257 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
258 | for j in range(len(self.conv_blocks[i])):
259 | conv = self.conv_blocks[i][j]
260 | bn = self.bn_blocks[i][j]
261 | grouped_points = F.relu(bn(conv(grouped_points)))
262 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
263 | new_points_list.append(new_points)
264 |
265 | new_xyz = new_xyz.permute(0, 2, 1)
266 | new_points_concat = torch.cat(new_points_list, dim=1)
267 | return new_xyz, new_points_concat
268 |
269 | class PointNetFeaturePropagation(nn.Module):
270 | def __init__(self, in_channel, mlp):
271 | super(PointNetFeaturePropagation, self).__init__()
272 | self.mlp_convs = nn.ModuleList()
273 | self.mlp_bns = nn.ModuleList()
274 | last_channel = in_channel
275 | for out_channel in mlp:
276 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
277 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
278 | last_channel = out_channel
279 |
280 | def forward(self, xyz1, xyz2, points1, points2):
281 | """
282 | Input:
283 | xyz1: input points position data, [B, C, N]
284 | xyz2: sampled input points position data, [B, C, S]
285 | points1: input points data, [B, D, N]
286 | points2: sampled input points data, [B, D, S]
287 | Return:
288 | new_points: upsampled points data, [B, D', N]
289 | """
290 | xyz1 = xyz1.permute(0, 2, 1)
291 | xyz2 = xyz2.permute(0, 2, 1)
292 |
293 | points2 = points2.permute(0, 2, 1)
294 | B, N, C = xyz1.shape
295 | _, S, _ = xyz2.shape
296 |
297 | if S == 1:
298 | interpolated_points = points2.repeat(1, N, 1)
299 | else:
300 | dists = square_distance(xyz1, xyz2)
301 | dists, idx = dists.sort(dim=-1)
302 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
303 |
304 | dist_recip = 1.0 / (dists + 1e-8)
305 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
306 | weight = dist_recip / norm
307 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
308 |
309 | if points1 is not None:
310 | points1 = points1.permute(0, 2, 1)
311 | new_points = torch.cat([points1, interpolated_points], dim=-1)
312 | else:
313 | new_points = interpolated_points
314 |
315 | new_points = new_points.permute(0, 2, 1)
316 | for i, conv in enumerate(self.mlp_convs):
317 | bn = self.mlp_bns[i]
318 | new_points = F.relu(bn(conv(new_points)))
319 | return new_points
320 |
--------------------------------------------------------------------------------
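
The set-abstraction layers above are composed from `farthest_point_sample`, `index_points` and `query_ball_point`, all of which lean on the `square_distance` identity. A quick shape and sanity check (a sketch, under the same run-from-repo-root assumption as before):

```python
import sys
sys.path.append('./models')  # assumption: run from the repository root

import torch
from utils import farthest_point_sample, index_points, query_ball_point, square_distance

xyz = torch.rand(2, 1024, 3)                         # [B, N, 3] positions
idx = farthest_point_sample(xyz, 256)                # [B, 256] FPS indices
new_xyz = index_points(xyz, idx)                     # [B, 256, 3] sampled centroids
group_idx = query_ball_point(0.2, 32, xyz, new_xyz)  # [B, 256, 32] ball neighbours

# square_distance implements |u|^2 + |v|^2 - 2 u.v; compare it against the
# direct pairwise computation.
d = square_distance(new_xyz, xyz)                    # [B, 256, 1024]
ref = ((new_xyz[:, :, None, :] - xyz[:, None, :, :]) ** 2).sum(-1)
print(torch.allclose(d, ref, atol=1e-5))             # True up to float error
```
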
/models/z_order.py:
--------------------------------------------------------------------------------
1 | # import numpy
2 | import numpy as np
3 | import torch
4 |
5 | from matplotlib import pyplot as plt
6 | import time
7 |
8 |
9 | def round_to_int_32(data):
10 | """
11 | Takes a Numpy array of float values between
12 | -1 and 1, and rounds them to significant
13 | 32-bit integer values, to be used in the
14 | morton code computation
15 |
16 | :param data: multidimensional numpy array
17 | :return: same as data but in 32-bit int format
18 | """
19 | # first we rescale points to 0-512
20 | data = 256 * (data + 1)
21 | # map into the 21-bit integer range consumed by the Morton code and convert to int
22 | data = np.round(2 ** 21 - data).astype(dtype=np.int32)
23 |
24 | return data
25 |
26 |
27 | def split_by_3(x):
28 | """
29 | Method to separate bits of a 32-bit integer
30 | by 3 positions apart, using the magic bits
31 | https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/
32 |
33 | :param x: 32-bit integer
34 | :return: x with bits separated
35 | """
36 | # we only look at 21 bits, since we want to generate
37 | # a 64-bit code eventually (3 x 21 bits = 63 bits, which
38 | # is the maximum we can fit in a 64-bit code)
39 | x &= 0x1fffff # only take first 21 bits
40 | # shift left 32 bits, OR with self, and 00011111000000000000000000000000000000001111111111111111
41 | x = (x | (x << 32)) & 0x1f00000000ffff
42 | # shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111
43 | x = (x | (x << 16)) & 0x1f0000ff0000ff
44 | # shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000000000
45 | x = (x | (x << 8)) & 0x100f00f00f00f00f
46 | # shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000100000000
47 | x = (x | (x << 4)) & 0x10c30c30c30c30c3
48 | # shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001
49 | x = (x | (x << 2)) & 0x1249249249249249
50 |
51 | return x
52 |
53 |
54 | def get_z_order(x, y, z):
55 | """
56 | Given 3 arrays of corresponding x, y, z
57 | coordinates, compute the morton (or z) code for
58 | each point and return an index array
59 | We compute the Morton order as follows:
60 | 1- Split all coordinates by 3 (add 2 zeros between bits)
61 | 2- Shift bits left by 1 for y and 2 for z
62 | 3- Interleave x, shifted y, and shifted z
63 | The Morton order is the final interleaved bit sequence
64 |
65 | :param x: x coordinates
66 | :param y: y coordinates
67 | :param z: z coordinates
68 | :return: index array with morton code
69 | """
70 | res = 0
71 | res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2
72 |
73 | return res
74 |
75 |
76 | def get_z_values(data):
77 | """
78 | Computes the z values for a point array
79 | :param data: Nx3 array of x, y, and z location
80 |
81 | :return: Nx1 array of z values
82 | """
83 | data = data.cpu().detach().numpy()
84 | points_round = round_to_int_32(data) # convert to int
85 | z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2])
86 |
87 | return z
88 |
89 |
90 | def pointnet_index_points(points, idx):
91 | """
92 |
93 | Input:
94 | points: input points data, [B, N, C]
95 | idx: sample index data, [B, S]
96 | Return:
97 | new_points: indexed points data, [B, S, C]
98 | """
99 | device = points.device
100 | B = points.shape[0]
101 | view_shape = list(idx.shape)
102 | view_shape[1:] = [1] * (len(view_shape) - 1)
103 | # print(view_shape)
104 | repeat_shape = list(idx.shape)
105 | repeat_shape[0] = 1
106 | # print(repeat_shape)
107 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
108 | # print(batch_indices)
109 | # print(batch_indices.shape, idx.shape)
110 | new_points = points[batch_indices, idx, :]
111 | # print(new_points)
112 | return new_points
113 |
114 |
115 | def farthest_point_sample(xyz, npoint):
116 | """
117 | Input:
118 | xyz: pointcloud data, [B, N, 3]
119 | npoint: number of samples
120 | Return:
121 | centroids: sampled pointcloud index, [B, npoint]
122 | """
123 | device = xyz.device
124 | B, N, C = xyz.shape
125 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
126 | distance = torch.ones(B, N).to(device) * 1e10
127 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
128 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
129 | for i in range(npoint):
130 | centroids[:, i] = farthest
131 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
132 | dist = torch.sum((xyz - centroid) ** 2, -1)
133 | mask = dist < distance
134 | distance[mask] = dist[mask]
135 | farthest = torch.max(distance, -1)[1]
136 | return centroids
137 |
138 |
139 | def z_order_point_sample(xyz, npoints):
140 | """
141 | Input:
142 | xyz: pointcloud data, [B, N, 3]
143 | npoints: number of samples
144 | Return:
145 | centroids: sampled pointcloud index, [B, npoint]
146 | """
147 | device = xyz.device
148 | B, N, C = xyz.shape
149 |
150 | if npoints >= N:
151 | return torch.arange(N, dtype=torch.long).to(device).view(1, N).repeat(B, 1)
152 |
153 | centroids = torch.zeros(B, npoints, dtype=int).to(device)
154 | for batch_idx in range(B):
155 | z = get_z_values(xyz[batch_idx])
156 | z = np.argsort(z)
157 | centroids[batch_idx, :] = torch.from_numpy(
158 | z[torch.linspace(0, N - 1, steps=npoints, dtype=int)].reshape(1, npoints))
159 |
160 | return centroids
161 |
--------------------------------------------------------------------------------
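
`split_by_3` spreads the low 21 bits of an integer three positions apart, and `get_z_order` interleaves the spread x, y, z bits into a single Morton code. A worked check with plain Python ints (arbitrary precision, so the 64-bit masks behave exactly as the comments describe); the import path is again an assumption:

```python
import sys
sys.path.append('./models')  # assumption: run from the repository root

from z_order import split_by_3, get_z_order

# split_by_3 moves bit i of the input to position 3i.
assert split_by_3(0b1) == 0b1
assert split_by_3(0b11) == 0b001001
# get_z_order interleaves: bit i of x lands at 3i, of y at 3i+1, of z at 3i+2.
assert get_z_order(1, 0, 0) == 0b001
assert get_z_order(0, 1, 0) == 0b010
assert get_z_order(0, 0, 1) == 0b100
assert get_z_order(3, 3, 3) == 0b111111
print('Morton interleave checks passed')
```
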
/paint.py:
--------------------------------------------------------------------------------
1 | import re
2 | from matplotlib import *
3 | import matplotlib.pyplot as plt
4 | import os
5 | import numpy as np
6 | import pandas as pd
7 | from scipy import interpolate
8 | from scipy.interpolate import make_interp_spline
9 |
10 | pathlist = [
11 | r'C:\Users\LENOVO\Desktop\Pointnet\Pointnet_Pointnet2_pytorch-master\log\classification\test10\logs\pointnet_srn_cls_msg.txt']
12 |
13 |
14 | def read_log(log_path_list, logs, train_acc=True, test_acc=True, loss=True):
15 |
16 | fig = plt.figure(num=1, figsize=(8, 4))
17 | ax1 = fig.add_subplot(121)
18 | ax2 = fig.add_subplot(122)
19 |
20 | # plot setup
21 | ax1.set_xlabel('epoch')
22 | ax1.set_title('accuracy')
23 | ax1.axis([0, 200, 0.9, 0.95])
24 | ax2.set_xlabel('epoch')
25 | ax2.set_title('total_loss')
26 | ax2.axis([0, 100, 0, 300])
27 | for idx, path in enumerate(log_path_list):
28 | log_dir = path
29 | print(path)
30 | with open(log_dir, "r", encoding="utf-8") as f:
31 | content = f.read()
32 |
33 | epoch = re.findall(r'Epoch \d[0-9]*', content, re.M)
34 | Train_Instance_Accuracy = re.findall(r'Train Instance Accuracy: .*', content, re.M)
35 | if loss:
36 | Total_loss = re.findall(r'Total loss: .*', content, re.M)
37 | Test_Instance_Accuracy = re.findall(r'Test Instance Accuracy: .*,', content, re.M)
38 |
39 | data_length = min([len(epoch), len(Train_Instance_Accuracy), len(Test_Instance_Accuracy)])
40 | for i in range(data_length):
41 | epoch[i] = int(epoch[i].strip('Epoch '))
42 | Train_Instance_Accuracy[i] = float(Train_Instance_Accuracy[i].strip('Train Instance Accuracy: '))
43 | if loss:
44 | Total_loss[i] = float(Total_loss[i].strip('Total loss: '))
45 | Test_Instance_Accuracy[i] = float(Test_Instance_Accuracy[i].strip('Test Instance Accuracy: ,'))
46 | plt.rcParams["font.family"] = "SimHei"
47 | if loss:
48 | total_loss_x = epoch[:data_length]
49 | total_loss_y = Total_loss[:data_length]
50 |
51 | test_accuracy_x = epoch[:data_length]
52 | test_accuracy_y = Test_Instance_Accuracy[:data_length]
53 |
54 | train_accuracy_x = epoch[:data_length]
55 | train_accuracy_y = Train_Instance_Accuracy[:data_length]
56 | if test_acc:
57 | # test_accuracy_x = np.array(test_accuracy_x)
58 | # test_accuracy_y = np.array(test_accuracy_x)
59 | # x_new = np.linspace(test_accuracy_x.min(), test_accuracy_x.max(),
60 | # 1000) # 1000 represents number of points to make between T.min and T.max
61 | # y_smooth = make_interp_spline(test_accuracy_x, test_accuracy_y)(x_new)
62 | ax1.plot(test_accuracy_x, test_accuracy_y, label='test_accuracy_{}'.format(logs[idx]))
63 | # ax1.plot(x_new, y_smooth, label='test_accuracy_{}'.format(logs[idx]))
64 |
65 | ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True)
66 | if train_acc:
67 | ax1.plot(train_accuracy_x, train_accuracy_y, label='train_accuracy_{}'.format(logs[idx]))
68 | ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True)
69 | if loss:
70 | ax2.plot(total_loss_x, total_loss_y, label='loss_{}'.format(logs[idx]))
71 | ax2.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True)
72 | plt.show()
73 |
74 |
75 | def read_all_paths(log_dir):
76 | class_dirs = os.listdir(log_dir)
77 | paths = [os.path.join(log_dir, i) + r'\logs' for i in class_dirs]
78 | all_paths = [os.path.join(i, os.listdir(i)[0]) for i in paths]
79 | return all_paths
80 |
81 |
82 | def read_log_path(log_name, classification=True):
83 | log_dir = []
84 | if classification:
85 | for i in range(len(log_name)):
86 | log_dir.append('log/classification/' + log_name[i] + "/logs/" + \
87 | os.listdir('log/classification/' + log_name[i] + r'\logs')[0])
88 | return log_dir
89 |
90 |
91 | if __name__ == "__main__":
92 | logs = ['Ours', 'Pointnet++']
93 | lists = read_log_path(logs)
94 | read_log(lists, logs, train_acc=False, loss=True)
95 |
96 | # data extraction
97 | # log_dir = r'C:\Users\LENOVO\Desktop\Pointnet\Pointnet_Pointnet2_pytorch-master\log\classification\test10\logs\pointnet_srn_cls_msg.txt'
98 | #
99 | # with open(log_dir, "r", encoding="utf-8") as f:
100 | # content = f.read()
101 | # epoch = re.findall(r'Epoch \d[0-9]*', content, re.M)
102 | # Train_Instance_Accuracy = re.findall(r'Train Instance Accuracy: .*', content, re.M)
103 | # Total_loss = re.findall(r'Total loss: .*', content, re.M)
104 | # Test_Instance_Accuracy = re.findall(r'Test Instance Accuracy: .*,', content, re.M)
105 | # for i in range(len(epoch)):
106 | # epoch[i] = int(epoch[i].strip('Epoch '))
107 | # Train_Instance_Accuracy[i] = float(Train_Instance_Accuracy[i].strip('Train Instance Accuracy: '))
108 | # Total_loss[i] = float(Total_loss[i].strip('Total loss: '))
109 | # Test_Instance_Accuracy[i] = float(Test_Instance_Accuracy[i].strip('Test Instance Accuracy: ,'))
110 | #
111 | # plt.rcParams["font.family"] = "SimHei"
112 | # total_loss_x = epoch
113 | # total_loss_y = Total_loss
114 | # test_accuracy_x = epoch
115 | # test_accuracy_y = Test_Instance_Accuracy
116 | # train_accuracy_x = epoch
117 | # train_accuracy_y = Train_Instance_Accuracy
118 | #
119 | # fig = plt.figure(num=1, figsize=(4, 4))
120 | # ax1 = fig.add_subplot(121)
121 | # ax2 = fig.add_subplot(122)
122 | # # plot setup
123 | # ax1.set_xlabel('epoch')
124 | # ax1.set_title('accuracy')
125 | # ax2.set_xlabel('epoch')
126 | # ax2.set_title('total_loss')
127 | #
128 | # ax1.plot(test_accuracy_x, test_accuracy_y, "r-", label='test_accuracy')
129 | #
130 | # ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True)
131 | # ax1.plot(train_accuracy_x, train_accuracy_y, "b-", label='train_accuracy')
132 | # ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True)
133 | #
134 | # ax2.plot(total_loss_x, total_loss_y, "c-")
135 | # plt.show()
136 |
--------------------------------------------------------------------------------
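
`read_log` recovers metrics from the text logs with regular expressions plus `str.strip`. Note that `strip` removes a *character set*, not a literal prefix; it happens to work for these labels. A stand-in log line (the format follows the training logger, the values are invented):

```python
import re

# Hypothetical line in the shape the training logger emits.
line = 'Epoch 42: Train Instance Accuracy: 0.912345'

epoch = int(re.findall(r'Epoch \d[0-9]*', line)[0].strip('Epoch '))
acc = float(re.findall(r'Train Instance Accuracy: .*', line)[0]
            .strip('Train Instance Accuracy: '))
print(epoch, acc)  # 42 0.912345
```
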
/provider.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def normalize_data(batch_data):
4 | """ Normalize the batch data, use coordinates of the block centered at origin,
5 | Input:
6 | BxNxC array
7 | Output:
8 | BxNxC array
9 | """
10 | B, N, C = batch_data.shape
11 | normal_data = np.zeros((B, N, C))
12 | for b in range(B):
13 | pc = batch_data[b]
14 | centroid = np.mean(pc, axis=0)
15 | pc = pc - centroid
16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
17 | pc = pc / m
18 | normal_data[b] = pc
19 | return normal_data
20 |
21 |
22 | def shuffle_data(data, labels):
23 | """ Shuffle data and labels.
24 | Input:
25 | data: B,N,... numpy array
26 | label: B,... numpy array
27 | Return:
28 | shuffled data, label and shuffle indices
29 | """
30 | idx = np.arange(len(labels))
31 | np.random.shuffle(idx)
32 | return data[idx, ...], labels[idx], idx
33 |
34 | def shuffle_points(batch_data):
35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior.
36 | Use the same shuffling idx for the entire batch.
37 | Input:
38 | BxNxC array
39 | Output:
40 | BxNxC array
41 | """
42 | idx = np.arange(batch_data.shape[1])
43 | np.random.shuffle(idx)
44 | return batch_data[:,idx,:]
45 |
46 | def rotate_point_cloud(batch_data):
47 | """ Randomly rotate the point clouds to augument the dataset
48 | rotation is per shape based along up direction
49 | Input:
50 | BxNx3 array, original batch of point clouds
51 | Return:
52 | BxNx3 array, rotated batch of point clouds
53 | """
54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
55 | for k in range(batch_data.shape[0]):
56 | rotation_angle = np.random.uniform() * 2 * np.pi
57 | cosval = np.cos(rotation_angle)
58 | sinval = np.sin(rotation_angle)
59 | rotation_matrix = np.array([[cosval, 0, sinval],
60 | [0, 1, 0],
61 | [-sinval, 0, cosval]])
62 | shape_pc = batch_data[k, ...]
63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
64 | return rotated_data
65 |
66 | def rotate_point_cloud_z(batch_data):
67 | """ Randomly rotate the point clouds to augument the dataset
68 | rotation is per shape based along up direction
69 | Input:
70 | BxNx3 array, original batch of point clouds
71 | Return:
72 | BxNx3 array, rotated batch of point clouds
73 | """
74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
75 | for k in range(batch_data.shape[0]):
76 | rotation_angle = np.random.uniform() * 2 * np.pi
77 | cosval = np.cos(rotation_angle)
78 | sinval = np.sin(rotation_angle)
79 | rotation_matrix = np.array([[cosval, sinval, 0],
80 | [-sinval, cosval, 0],
81 | [0, 0, 1]])
82 | shape_pc = batch_data[k, ...]
83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
84 | return rotated_data
85 |
86 | def rotate_point_cloud_with_normal(batch_xyz_normal):
87 | ''' Randomly rotate XYZ, normal point cloud.
88 | Input:
89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
90 | Output:
91 | B,N,6, rotated XYZ, normal point cloud
92 | '''
93 | for k in range(batch_xyz_normal.shape[0]):
94 | rotation_angle = np.random.uniform() * 2 * np.pi
95 | cosval = np.cos(rotation_angle)
96 | sinval = np.sin(rotation_angle)
97 | rotation_matrix = np.array([[cosval, 0, sinval],
98 | [0, 1, 0],
99 | [-sinval, 0, cosval]])
100 | shape_pc = batch_xyz_normal[k,:,0:3]
101 | shape_normal = batch_xyz_normal[k,:,3:6]
102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
104 | return batch_xyz_normal
105 |
106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
107 | """ Randomly perturb the point clouds by small rotations
108 | Input:
109 | BxNx6 array, original batch of point clouds and point normals
110 | Return:
111 | BxNx3 array, rotated batch of point clouds
112 | """
113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
114 | for k in range(batch_data.shape[0]):
115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
116 | Rx = np.array([[1,0,0],
117 | [0,np.cos(angles[0]),-np.sin(angles[0])],
118 | [0,np.sin(angles[0]),np.cos(angles[0])]])
119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
120 | [0,1,0],
121 | [-np.sin(angles[1]),0,np.cos(angles[1])]])
122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
123 | [np.sin(angles[2]),np.cos(angles[2]),0],
124 | [0,0,1]])
125 | R = np.dot(Rz, np.dot(Ry,Rx))
126 | shape_pc = batch_data[k,:,0:3]
127 | shape_normal = batch_data[k,:,3:6]
128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
130 | return rotated_data
131 |
132 |
133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle):
134 | """ Rotate the point cloud along up direction with certain angle.
135 | Input:
136 | BxNx3 array, original batch of point clouds
137 | Return:
138 | BxNx3 array, rotated batch of point clouds
139 | """
140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
141 | for k in range(batch_data.shape[0]):
142 | #rotation_angle = np.random.uniform() * 2 * np.pi
143 | cosval = np.cos(rotation_angle)
144 | sinval = np.sin(rotation_angle)
145 | rotation_matrix = np.array([[cosval, 0, sinval],
146 | [0, 1, 0],
147 | [-sinval, 0, cosval]])
148 | shape_pc = batch_data[k,:,0:3]
149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
150 | return rotated_data
151 |
152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
153 | """ Rotate the point cloud along up direction with certain angle.
154 | Input:
155 | BxNx6 array, original batch of point clouds with normal
156 | scalar, angle of rotation
157 | Return:
158 | BxNx6 array, rotated batch of point clouds with normals
159 | """
160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
161 | for k in range(batch_data.shape[0]):
162 | #rotation_angle = np.random.uniform() * 2 * np.pi
163 | cosval = np.cos(rotation_angle)
164 | sinval = np.sin(rotation_angle)
165 | rotation_matrix = np.array([[cosval, 0, sinval],
166 | [0, 1, 0],
167 | [-sinval, 0, cosval]])
168 | shape_pc = batch_data[k,:,0:3]
169 | shape_normal = batch_data[k,:,3:6]
170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
172 | return rotated_data
173 |
174 |
175 |
176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
177 | """ Randomly perturb the point clouds by small rotations
178 | Input:
179 | BxNx3 array, original batch of point clouds
180 | Return:
181 | BxNx3 array, rotated batch of point clouds
182 | """
183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
184 | for k in range(batch_data.shape[0]):
185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
186 | Rx = np.array([[1,0,0],
187 | [0,np.cos(angles[0]),-np.sin(angles[0])],
188 | [0,np.sin(angles[0]),np.cos(angles[0])]])
189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
190 | [0,1,0],
191 | [-np.sin(angles[1]),0,np.cos(angles[1])]])
192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
193 | [np.sin(angles[2]),np.cos(angles[2]),0],
194 | [0,0,1]])
195 | R = np.dot(Rz, np.dot(Ry,Rx))
196 | shape_pc = batch_data[k, ...]
197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
198 | return rotated_data
199 |
200 |
201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
202 | """ Randomly jitter points. jittering is per point.
203 | Input:
204 | BxNx3 array, original batch of point clouds
205 | Return:
206 | BxNx3 array, jittered batch of point clouds
207 | """
208 | B, N, C = batch_data.shape
209 | assert(clip > 0)
210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
211 | jittered_data += batch_data
212 | return jittered_data
213 |
214 | def shift_point_cloud(batch_data, shift_range=0.1):
215 | """ Randomly shift point cloud. Shift is per point cloud.
216 | Input:
217 | BxNx3 array, original batch of point clouds
218 | Return:
219 | BxNx3 array, shifted batch of point clouds
220 | """
221 | B, N, C = batch_data.shape
222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3))
223 | for batch_index in range(B):
224 | batch_data[batch_index,:,:] += shifts[batch_index,:]
225 | return batch_data
226 |
227 |
228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
229 | """ Randomly scale the point cloud. Scale is per point cloud.
230 | Input:
231 | BxNx3 array, original batch of point clouds
232 | Return:
233 | BxNx3 array, scaled batch of point clouds
234 | """
235 | B, N, C = batch_data.shape
236 | scales = np.random.uniform(scale_low, scale_high, B)
237 | for batch_index in range(B):
238 | batch_data[batch_index,:,:] *= scales[batch_index]
239 | return batch_data
240 |
241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
242 | ''' batch_pc: BxNx3 '''
243 | for b in range(batch_pc.shape[0]):
244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
246 | if len(drop_idx)>0:
247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
248 | return batch_pc
249 |
250 |
251 |
252 |
--------------------------------------------------------------------------------
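
These are exactly the three augmentations `train.py` chains per batch: dropout that collapses points onto the first point, a per-cloud random scale, and a per-cloud random shift. A sketch on random data:

```python
import numpy as np
import provider  # assumption: run from the repository root

batch = np.random.uniform(-1, 1, (8, 1024, 6)).astype(np.float32)  # xyz + normals
batch = provider.random_point_dropout(batch)              # dropped points become copies of point 0
batch[:, :, 0:3] = provider.random_scale_point_cloud(batch[:, :, 0:3])  # scale in [0.8, 1.25]
batch[:, :, 0:3] = provider.shift_point_cloud(batch[:, :, 0:3])         # shift in [-0.1, 0.1]
print(batch.shape)  # (8, 1024, 6) -- shapes unchanged, only values move
```
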
/test.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
6 | import argparse
7 | import numpy as np
8 | import os
9 | import torch
10 | import logging
11 | from tqdm import tqdm
12 | import sys
13 | import importlib
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | ROOT_DIR = BASE_DIR
17 |
18 |
19 | def parse_args():
20 | '''PARAMETERS'''
21 | parser = argparse.ArgumentParser('Testing')
22 | parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
23 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
24 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
25 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
26 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
27 | parser.add_argument('--log_dir', type=str, required=True, help='Experiment root')
28 | parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
29 | parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
30 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
31 | return parser.parse_args()
32 |
33 |
34 | def test(model, loader, num_class=40, vote_num=1):
35 | mean_correct = []
36 | classifier = model.eval()
37 | class_acc = np.zeros((num_class, 3))
38 |
39 | for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
40 | if not args.use_cpu:
41 | points, target = points.cuda(), target.cuda()
42 |
43 | # input()
44 | points = points.transpose(2, 1)
45 | vote_pool = torch.zeros(target.size()[0], num_class).cuda()
46 |
47 | for _ in range(vote_num):
48 | pred, _ = classifier(points)
49 | vote_pool += pred
50 | pred = vote_pool / vote_num
51 |
52 | pred_choice = pred.data.max(1)[1]
53 | # print(pred_choice)
54 | # print("pred_choice.shape", pred_choice.shape)
55 | for cat in np.unique(target.cpu()):
56 | # print('\n',cat.shape)
57 | # print(pred_choice[target == cat])
58 | classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
59 | # print("classacc.shape", classacc.shape)
60 | class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
61 | class_acc[cat, 1] += 1
62 |
63 | correct = pred_choice.eq(target.long().data).cpu().sum()
64 | mean_correct.append(correct.item() / float(points.size()[0]))
65 |
66 | class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
67 |
68 | class_mean_acc = np.mean(class_acc[:, 2])
69 | instance_acc = np.mean(mean_correct)
70 | return instance_acc, class_mean_acc, class_acc
71 |
72 |
73 | def main(args):
74 | def log_string(str):
75 | logger.info(str)
76 | print(str)
77 |
78 | '''HYPER PARAMETER'''
79 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
80 |
81 | '''CREATE DIR'''
82 | experiment_dir = 'log/classification/' + args.log_dir
83 | sys.path.append(experiment_dir)
84 | '''LOG'''
85 | args = parse_args()
86 | logger = logging.getLogger("Model")
87 | logger.setLevel(logging.INFO)
88 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
89 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
90 | file_handler.setLevel(logging.INFO)
91 | file_handler.setFormatter(formatter)
92 | logger.addHandler(file_handler)
93 | log_string('PARAMETER ...')
94 | log_string(args)
95 |
96 | '''DATA LOADING'''
97 | log_string('Load dataset ...')
98 | data_path = 'data/modelnet40_normal_resampled/'
99 |
100 | test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False)
101 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
102 | num_workers=10)
103 |
104 | '''MODEL LOADING'''
105 | num_class = args.num_category
106 | model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
107 | log_string(model_name)
108 | model = importlib.import_module(model_name)
109 |
110 | classifier = model.get_model(num_class, normal_channel=args.use_normals)
111 |
112 | # for name, parameters in classifier.named_parameters():
113 | # print(name, ':', parameters)
114 | if not args.use_cpu:
115 | classifier = classifier.cuda()
116 |
117 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
118 | classifier.load_state_dict(checkpoint['model_state_dict'])
119 |
120 | if num_class == 10:
121 | catfile = os.path.join('./data/modelnet40_normal_resampled', 'modelnet10_shape_names.txt')
122 | else:
123 | catfile = os.path.join('./data/modelnet40_normal_resampled', 'modelnet40_shape_names.txt')
124 |
125 | cats = [line.rstrip() for line in open(catfile)]
126 | cls_to_tag = dict(zip(range(len(cats)), cats))
127 |
128 | with torch.no_grad():
129 | instance_acc, class_mean_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes,
130 | num_class=num_class)
131 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_mean_acc))
132 |
133 | for i in range(num_class):
134 | log_string('Class %s Accuracy: %f' % (cls_to_tag[i], class_acc[i, 2]))
135 |
136 |
137 | if __name__ == '__main__':
138 | args = parse_args()
139 | main(args)
140 |
--------------------------------------------------------------------------------
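
`test()` votes by summing the log-softmax outputs of `num_votes` forward passes, averaging, and taking a single argmax. The same arithmetic in isolation (random stand-in predictions instead of a real classifier):

```python
import torch

num_votes, B, num_class = 3, 4, 40
vote_pool = torch.zeros(B, num_class)
for _ in range(num_votes):
    # Stand-in for classifier(points): random log-softmax scores.
    vote_pool += torch.randn(B, num_class).log_softmax(dim=-1)
pred = vote_pool / num_votes  # averaged log-probabilities
print(pred.max(1)[1])         # [B] voted class per sample
```
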
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 |
6 | import os
7 | import sys
8 | import torch
9 | import numpy as np
10 |
11 | import datetime
12 | import logging
13 | import provider
14 | import importlib
15 | import shutil
16 | import argparse
17 |
18 | from pathlib import Path
19 | from tqdm import tqdm
20 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
21 | from torch.utils.tensorboard import SummaryWriter
22 | # from torchstat import stat
23 | # from thop import profile
24 | # from thop import clever_format
25 |
26 | # default `log_dir` is "runs" - we'll be more specific here
27 |
28 |
29 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
30 | ROOT_DIR = BASE_DIR
31 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
32 |
33 |
34 | def parse_args():
35 | '''PARAMETERS'''
36 | parser = argparse.ArgumentParser('training')
37 | parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
38 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
39 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
40 | parser.add_argument('--model', default='SCNet', help='model name [default: SCNet]')
41 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
42 | parser.add_argument('--epoch', default=200, type=int, help='number of epochs in training')
43 | parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
44 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
45 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
46 | parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
47 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
48 | parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
49 | parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
50 | parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
51 | return parser.parse_args()
52 |
53 |
54 | def inplace_relu(m):
55 | classname = m.__class__.__name__
56 | if classname.find('ReLU') != -1:
57 | m.inplace = True
58 |
59 |
60 | def test(model, loader, num_class=40):
61 | mean_correct = []
62 | class_acc = np.zeros((num_class, 3))
63 | classifier = model.eval()
64 |
65 | for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
66 |
67 | if not args.use_cpu:
68 | points, target = points.cuda(), target.cuda()
69 |
70 | points = points.transpose(2, 1)
71 | pred, _ = classifier(points)
72 | pred_choice = pred.data.max(1)[1]
73 |
74 | for cat in np.unique(target.cpu()):
75 | classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
76 | class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
77 | class_acc[cat, 1] += 1
78 |
79 | correct = pred_choice.eq(target.long().data).cpu().sum()
80 | mean_correct.append(correct.item() / float(points.size()[0]))
81 |
82 | class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
83 | class_acc = np.mean(class_acc[:, 2])
84 | instance_acc = np.mean(mean_correct)
85 |
86 | return instance_acc, class_acc
87 |
88 |
89 | def main(args):
90 | def log_string(str):
91 | logger.info(str)
92 | print(str)
93 |
94 | '''HYPER PARAMETER'''
95 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
96 |
97 | '''CREATE DIR'''
98 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
99 | exp_dir = Path('./log/')
100 | exp_dir.mkdir(exist_ok=True)
101 | exp_dir = exp_dir.joinpath('classification')
102 | exp_dir.mkdir(exist_ok=True)
103 | if args.log_dir is None:
104 | exp_dir = exp_dir.joinpath(timestr)
105 | else:
106 | exp_dir = exp_dir.joinpath(args.log_dir)
107 | exp_dir.mkdir(exist_ok=True)
108 | checkpoints_dir = exp_dir.joinpath('checkpoints/')
109 | checkpoints_dir.mkdir(exist_ok=True)
110 | log_dir = exp_dir.joinpath('logs/')
111 | log_dir.mkdir(exist_ok=True)
112 |
113 | writer_dir = exp_dir.joinpath('SummaryWriter/')
114 | writer_dir.mkdir(exist_ok=True)
115 | writer = SummaryWriter(writer_dir)
116 |
117 | '''LOG'''
118 | args = parse_args()
119 | logger = logging.getLogger("Model")
120 | logger.setLevel(logging.INFO)
121 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
122 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
123 | file_handler.setLevel(logging.INFO)
124 | file_handler.setFormatter(formatter)
125 | logger.addHandler(file_handler)
126 | log_string('PARAMETER ...')
127 | log_string(args)
128 |
129 | '''DATA LOADING'''
130 | log_string('Load dataset ...')
131 | data_path = 'data/modelnet40_normal_resampled/'
132 |
133 | train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
134 | test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
135 | trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
136 | num_workers=10, drop_last=True)
137 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
138 | num_workers=10)
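# Each batch from these loaders is (points, target): points is [B, N, C] with
# C = 6 (xyz + normals) when --use_normals is set, otherwise C = 3; it is
# transposed to [B, C, N] before entering the model. drop_last=True keeps the
# training batch size constant, which keeps BatchNorm statistics stable.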
139 |
140 | '''MODEL LOADING'''
141 | num_class = args.num_category
142 | model = importlib.import_module(args.model)
143 | shutil.copy('./models/%s.py' % args.model, str(exp_dir))
144 | shutil.copy('models/utils.py', str(exp_dir))
145 | shutil.copy('./train.py', str(exp_dir))
146 |
147 | shutil.copy('models/SCNet.py', str(exp_dir))
148 | shutil.copy('models/z_order.py', str(exp_dir))
149 |
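# Copying the model definition and its helpers into exp_dir snapshots the exact
# sources used for this run, so the experiment stays reproducible even if the
# files under models/ change later.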
150 | classifier = model.get_model(num_class, normal_channel=args.use_normals)
151 | criterion = model.get_loss()
152 | classifier.apply(inplace_relu)
153 |
154 | if not args.use_cpu:
155 | classifier = classifier.cuda()
156 | criterion = criterion.cuda()
157 |
158 | try:
159 | checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
160 | start_epoch = checkpoint['epoch']
161 | classifier.load_state_dict(checkpoint['model_state_dict'])
162 | log_string('Use pretrain model')
163 | except Exception:  # no usable checkpoint at this path
164 | log_string('No existing model, starting training from scratch...')
165 | start_epoch = 0
166 |
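# The checkpoint dict mirrors the `state` saved at the end of each improving
# epoch below (epoch, instance_acc, class_acc, model_state_dict,
# optimizer_state_dict). Only the model weights are restored here; resuming the
# optimizer as well would additionally need
#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# placed after the optimizer is constructed below.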
167 | if args.optimizer == 'Adam':
168 | optimizer = torch.optim.Adam(
169 | classifier.parameters(),
170 | lr=args.learning_rate,
171 | betas=(0.9, 0.999),
172 | eps=1e-08,
173 | weight_decay=args.decay_rate
174 | )
175 | else:
176 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
177 |
178 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
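# StepLR multiplies the learning rate by gamma=0.7 every 20 epochs, so with the
# default lr=0.001: epochs 0-19 -> 1e-3, 20-39 -> 7e-4, 40-59 -> 4.9e-4, etc.
# The same schedule wraps the SGD branch as well, starting from lr=0.01.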
179 | global_epoch = 0
180 | global_step = 0
181 | best_instance_acc = 0.0
182 | best_class_acc = 0.0
183 |
184 | '''TRAINING'''
185 | logger.info('Start training...')
186 | for epoch in range(start_epoch, args.epoch):
187 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
188 | mean_correct = []
189 | classifier = classifier.train()
190 | total_loss = 0
191 | # the LR scheduler is stepped once per epoch, after the optimizer updates (see end of epoch loop)
192 | for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader),
193 | smoothing=0.9):
194 | optimizer.zero_grad()
195 |
196 | points = points.data.numpy()
197 | points = provider.random_point_dropout(points)
198 | points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
199 | points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
200 | points = torch.Tensor(points)
201 | points = points.transpose(2, 1)
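# The provider augmentations above operate on the numpy batch [B, N, C]:
# random_point_dropout replaces a random subset of each cloud's points with
# copies of its first point, while the scale and shift are applied only to the
# xyz columns (0:3), leaving normal channels untouched. The exact ranges
# (scale factor, shift magnitude) are defined in provider.py.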
202 |
203 | if not args.use_cpu:
204 | points, target = points.cuda(), target.cuda()
205 | # optional FLOPs/params profiling (requires the thop package):
206 | # flops, params = profile(classifier, (points,)); flops, params = clever_format([flops, params], "%.3f")
207 | # print(flops, params)
208 | pred, trans_feat = classifier(points)
209 | loss = criterion(pred, target.long(), trans_feat)
210 | total_loss += loss.item()
211 |
212 | pred_choice = pred.data.max(1)[1]
213 |
214 | correct = pred_choice.eq(target.long().data).cpu().sum()
215 | mean_correct.append(correct.item() / float(points.size()[0]))
216 | loss.backward()
217 | optimizer.step()
218 | global_step += 1
219 | scheduler.step()  # step the LR schedule after this epoch's optimizer updates
220 | train_instance_acc = np.mean(mean_correct)
221 | log_string('Train Instance Accuracy: %f' % train_instance_acc)
222 | log_string('Total loss: %f' % total_loss)
223 | writer.add_scalar('Accuracy/Train Instance Accuracy', train_instance_acc, epoch + 1)
224 | writer.add_scalar('Total loss', total_loss, epoch + 1)
225 | with torch.no_grad():
226 | instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)
227 |
228 | if (instance_acc >= best_instance_acc):
229 | best_instance_acc = instance_acc
230 | best_epoch = epoch + 1
231 |
232 | if (class_acc >= best_class_acc):
233 | best_class_acc = class_acc
234 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
235 | writer.add_scalar('Accuracy/Test Instance Accuracy', instance_acc, epoch + 1)
236 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
237 |
238 | if (instance_acc >= best_instance_acc):
239 | logger.info('Save model...')
240 | savepath = str(checkpoints_dir) + '/best_model.pth'
241 | log_string('Saving at %s' % savepath)
242 | state = {
243 | 'epoch': best_epoch,
244 | 'instance_acc': instance_acc,
245 | 'class_acc': class_acc,
246 | 'model_state_dict': classifier.state_dict(),
247 | 'optimizer_state_dict': optimizer.state_dict(),
248 | }
249 | torch.save(state, savepath)
250 | global_epoch += 1
251 |
252 | logger.info('End of training...')
253 |
254 |
255 | if __name__ == '__main__':
256 | args = parse_args()
257 | main(args)
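# The best-performing weights land at
# log/classification/<log_dir>/checkpoints/best_model.pth; rerunning train.py
# with the same --log_dir resumes from that checkpoint via the torch.load block
# in main().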
258 |
--------------------------------------------------------------------------------