├── README.md
├── checkpoint
│   ├── best_model.pth
│   └── eval.txt
├── data_utils
│   ├── ModelNetDataLoader.py
│   ├── S3DISDataLoader.py
│   ├── ShapeNetDataLoader.py
│   ├── __pycache__
│   │   ├── ModelNetDataLoader.cpython-37.pyc
│   │   └── ShapeNetDataLoader.cpython-37.pyc
│   ├── collect_indoor3d_data.py
│   ├── indoor3d_util.py
│   └── meta
│       ├── anno_paths.txt
│       └── class_names.txt
├── images
│   ├── example.jpg
│   └── model.jpg
├── models
│   ├── SCNet.py
│   ├── utils.py
│   └── z_order.py
├── paint.py
├── provider.py
├── test.py
└── train.py
/README.md: -------------------------------------------------------------------------------- 1 | # PointSCNet: Point Cloud Structure and Correlation Learning based on Space Filling Curve guided Sampling 2 | 3 | ## Description 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pointscnet-point-cloud-structure-and/3d-point-cloud-classification-on-modelnet40)](https://paperswithcode.com/sota/3d-point-cloud-classification-on-modelnet40?p=pointscnet-point-cloud-structure-and) 5 | 6 | This repository contains the code for our paper: [PointSCNet: Point Cloud Structure and Correlation Learning based on Space Filling Curve guided Sampling](https://doi.org/10.3390/sym14010008) 7 | 8 |
11 | 12 | 13 | ## Environment setup 14 | 15 | The current code has been tested on Ubuntu 18.04 with CUDA 11, Python 3.6.9, torch 1.10.0 and torchvision 0.11.3. 16 | Our pipeline builds on a [PyTorch implementation of PointNet++](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). 17 | 18 | 19 | ## Classification (ModelNet10/40) 20 | ### Data Preparation 21 | Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `data/modelnet40_normal_resampled/`. 22 | 23 | 24 | 25 | 26 | ### Run 27 | 28 | ``` 29 | python train.py --model SCNet --log_dir SCNet_log --use_normals --process_data 30 | ``` 31 | 32 | * --model: model name (e.g. `SCNet`) 33 | * --log_dir: path to the log directory 34 | * --use_normals: use point normals as additional input features 35 | * --process_data: preprocess the data once and cache it on disk 36 | 37 | ## Test 38 | 39 | ``` 40 | python test.py --log_dir SCNet_log --use_normals 41 | ``` 42 | 43 | 44 | ## Performance 45 | | Model | ModelNet40 accuracy (%) | 46 | |--|--| 47 | | PointNet (Official) | 89.2| 48 | | PointNet2 (Official) | 91.9 | 49 | | PointSCNet | **93.7**| 50 | 51 | 52 | ## Citation 53 | Please cite our paper if you find it useful in your research: 54 | 55 | ``` 56 | @article{chen2022pointscnet, 57 | title={PointSCNet: Point Cloud Structure and Correlation Learning Based on Space-Filling Curve-Guided Sampling}, 58 | author={Chen, Xingye and Wu, Yiqi and Xu, Wenjie and Li, Jin and Dong, Huaiyi and Chen, Yilin}, 59 | journal={Symmetry}, 60 | volume={14}, 61 | number={1}, 62 | pages={8}, 63 | year={2022}, 64 | publisher={Multidisciplinary Digital Publishing Institute} 65 | } 66 | ``` 67 | 68 | ## Contact 69 | If you have any questions, please contact cxy@cug.edu.cn 70 | -------------------------------------------------------------------------------- /checkpoint/best_model.pth: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/checkpoint/best_model.pth -------------------------------------------------------------------------------- /checkpoint/eval.txt: -------------------------------------------------------------------------------- 1 | 2021-11-30 18:50:58,111 - Model - INFO - PARAMETER ... 2 | 2021-11-30 18:50:58,111 - Model - INFO - Namespace(batch_size=24, gpu='0', log_dir='att_dot_test', num_category=40, num_point=1024, num_votes=2, use_cpu=False, use_normals=True, use_uniform_sample=False) 3 | 2021-11-30 18:50:58,111 - Model - INFO - Load dataset ... 4 | 2021-11-30 18:50:58,117 - Model - INFO - PointSCNet 5 | 2021-11-30 18:51:38,716 - Model - INFO - Test Instance Accuracy: 0.937217, Class Accuracy: 0.914043 6 | 2021-11-30 18:51:38,716 - Model - INFO - Class airplane Accuracy: 1.000000 7 | 2021-11-30 18:51:38,716 - Model - INFO - Class bathtub Accuracy: 0.986111 8 | 2021-11-30 18:51:38,716 - Model - INFO - Class bed Accuracy: 1.000000 9 | 2021-11-30 18:51:38,716 - Model - INFO - Class bench Accuracy: 0.845238 10 | 2021-11-30 18:51:38,716 - Model - INFO - Class bookshelf Accuracy: 0.988889 11 | 2021-11-30 18:51:38,716 - Model - INFO - Class bottle Accuracy: 0.960714 12 | 2021-11-30 18:51:38,716 - Model - INFO - Class bowl Accuracy: 1.000000 13 | 2021-11-30 18:51:38,716 - Model - INFO - Class car Accuracy: 1.000000 14 | 2021-11-30 18:51:38,716 - Model - INFO - Class chair Accuracy: 0.991667 15 | 2021-11-30 18:51:38,716 - Model - INFO - Class cone Accuracy: 1.000000 16 | 2021-11-30 18:51:38,717 - Model - INFO - Class cup Accuracy: 0.750000 17 | 2021-11-30 18:51:38,717 - Model - INFO - Class curtain Accuracy: 0.880952 18 | 2021-11-30 18:51:38,717 - Model - INFO - Class desk Accuracy: 0.903472 19 | 2021-11-30 18:51:38,717 - Model - INFO - Class door Accuracy: 0.968750 20 | 2021-11-30 18:51:38,717 - Model - INFO - Class dresser Accuracy: 0.883333 21 | 2021-11-30 18:51:38,717 - Model - INFO - 
Class flower_pot Accuracy: 0.111111 22 | 2021-11-30 18:51:38,717 - Model - INFO - Class glass_box Accuracy: 0.949242 23 | 2021-11-30 18:51:38,717 - Model - INFO - Class guitar Accuracy: 1.000000 24 | 2021-11-30 18:51:38,717 - Model - INFO - Class keyboard Accuracy: 1.000000 25 | 2021-11-30 18:51:38,717 - Model - INFO - Class lamp Accuracy: 1.000000 26 | 2021-11-30 18:51:38,717 - Model - INFO - Class laptop Accuracy: 1.000000 27 | 2021-11-30 18:51:38,717 - Model - INFO - Class mantel Accuracy: 0.986111 28 | 2021-11-30 18:51:38,717 - Model - INFO - Class monitor Accuracy: 0.991667 29 | 2021-11-30 18:51:38,717 - Model - INFO - Class night_stand Accuracy: 0.798611 30 | 2021-11-30 18:51:38,717 - Model - INFO - Class person Accuracy: 1.000000 31 | 2021-11-30 18:51:38,717 - Model - INFO - Class piano Accuracy: 0.940000 32 | 2021-11-30 18:51:38,717 - Model - INFO - Class plant Accuracy: 0.833333 33 | 2021-11-30 18:51:38,717 - Model - INFO - Class radio Accuracy: 0.750000 34 | 2021-11-30 18:51:38,717 - Model - INFO - Class range_hood Accuracy: 0.925000 35 | 2021-11-30 18:51:38,717 - Model - INFO - Class sink Accuracy: 0.900000 36 | 2021-11-30 18:51:38,717 - Model - INFO - Class sofa Accuracy: 1.000000 37 | 2021-11-30 18:51:38,717 - Model - INFO - Class stairs Accuracy: 1.000000 38 | 2021-11-30 18:51:38,717 - Model - INFO - Class stool Accuracy: 0.968750 39 | 2021-11-30 18:51:38,717 - Model - INFO - Class table Accuracy: 0.880000 40 | 2021-11-30 18:51:38,717 - Model - INFO - Class tent Accuracy: 0.968750 41 | 2021-11-30 18:51:38,717 - Model - INFO - Class toilet Accuracy: 1.000000 42 | 2021-11-30 18:51:38,717 - Model - INFO - Class tv_stand Accuracy: 0.900000 43 | 2021-11-30 18:51:38,717 - Model - INFO - Class vase Accuracy: 0.850000 44 | 2021-11-30 18:51:38,717 - Model - INFO - Class wardrobe Accuracy: 0.700000 45 | 2021-11-30 18:51:38,717 - Model - INFO - Class xbox Accuracy: 0.950000 46 | -------------------------------------------------------------------------------- 
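The log above reports two metrics: instance accuracy (0.937217, which matches the 93.7 listed for PointSCNet in the README's Performance table) and class accuracy (0.914043). The following minimal NumPy sketch shows how the two differ under their common definitions. Instance accuracy is the overall fraction of correctly classified shapes. Class accuracy is the unweighted mean of per-class accuracies. The function name `instance_and_class_accuracy` is hypothetical. Because `test.py` is not shown here, its exact aggregation (for example, averaging per-batch ratios) may differ from this sketch.

```python
import numpy as np


def instance_and_class_accuracy(pred, target, num_classes):
    """Return (instance accuracy, mean per-class accuracy).

    pred, target: 1-D sequences of predicted / ground-truth class indices.
    Classes absent from `target` are skipped when averaging.
    Hypothetical helper, not taken from the repository's test.py.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    correct = pred == target

    # Instance accuracy: fraction of all test shapes classified correctly.
    instance_acc = float(correct.mean())

    # Class accuracy: accuracy inside each class, then an unweighted mean,
    # so a small class counts as much as a large one.
    per_class = [correct[target == c].mean()
                 for c in range(num_classes) if np.any(target == c)]
    class_acc = float(np.mean(per_class))
    return instance_acc, class_acc


if __name__ == "__main__":
    # Toy example: the single shape of class 2 is misclassified.
    inst, cls = instance_and_class_accuracy(
        [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 1, 2], num_classes=3)
    print(f"instance={inst:.4f} class={cls:.4f}")  # instance=0.8333 class=0.6667
```

Every class counts equally in the class mean. A weak class such as `flower_pot` (0.111111 above) therefore lowers the class accuracy much more than it lowers the instance accuracy.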
/data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Xu Yan 3 | @file: ModelNet.py 4 | @time: 2021/3/19 15:51 5 | ''' 6 | import os 7 | import numpy as np 8 | import warnings 9 | import pickle 10 | 11 | from tqdm import tqdm 12 | from torch.utils.data import Dataset 13 | 14 | warnings.filterwarnings('ignore') 15 | 16 | 17 | def pc_normalize(pc): 18 | centroid = np.mean(pc, axis=0) 19 | pc = pc - centroid 20 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 21 | pc = pc / m 22 | return pc 23 | 24 | 25 | def farthest_point_sample(point, npoint): 26 | """ 27 | Input: 28 | point: pointcloud data, [N, D]; only the first 3 columns (xyz) are used for distances 29 | npoint: number of samples 30 | Return: 31 | point: the sampled points (all D columns), [npoint, D] 32 | """ 33 | N, D = point.shape 34 | xyz = point[:,:3] 35 | centroids = np.zeros((npoint,)) 36 | distance = np.ones((N,)) * 1e10 37 | farthest = np.random.randint(0, N) 38 | for i in range(npoint): 39 | centroids[i] = farthest 40 | centroid = xyz[farthest, :] 41 | dist = np.sum((xyz - centroid) ** 2, -1) 42 | mask = dist < distance 43 | distance[mask] = dist[mask] 44 | farthest = np.argmax(distance, -1) 45 | point = point[centroids.astype(np.int32)] 46 | return point 47 | 48 | 49 | class ModelNetDataLoader(Dataset): 50 | def __init__(self, root, args, split='train', process_data=False): 51 | self.root = root 52 | self.npoints = args.num_point 53 | self.process_data = process_data 54 | self.uniform = args.use_uniform_sample 55 | self.use_normals = args.use_normals 56 | self.num_category = args.num_category 57 | 58 | if self.num_category == 10: 59 | self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt') 60 | else: 61 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') 62 | 63 | self.cat = [line.rstrip() for line in open(self.catfile)] 64 | self.classes = dict(zip(self.cat, range(len(self.cat)))) 65 | 66 | shape_ids = {} 67 | if 
self.num_category == 10: 68 |
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))] 69 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))] 70 | else: 71 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] 72 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 73 | 74 | assert (split == 'train' or split == 'test') 75 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 76 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 77 | in range(len(shape_ids[split]))] 78 | print('The size of %s data is %d' % (split, len(self.datapath))) 79 | 80 | if self.uniform: 81 | self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints)) 82 | else: 83 | self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints)) 84 | 85 | if self.process_data: 86 | if not os.path.exists(self.save_path): 87 | print('Processing data %s (only runs the first time)...'
% self.save_path) 88 | self.list_of_points = [None] * len(self.datapath) 89 | self.list_of_labels = [None] * len(self.datapath) 90 | 91 | for index in tqdm(range(len(self.datapath)), total=len(self.datapath)): 92 | fn = self.datapath[index] 93 | cls = self.classes[self.datapath[index][0]] 94 | cls = np.array([cls]).astype(np.int32) 95 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 96 | 97 | if self.uniform: 98 | point_set = farthest_point_sample(point_set, self.npoints) 99 | else: 100 | point_set = point_set[0:self.npoints, :] 101 | 102 | self.list_of_points[index] = point_set 103 | self.list_of_labels[index] = cls 104 | 105 | with open(self.save_path, 'wb') as f: 106 | pickle.dump([self.list_of_points, self.list_of_labels], f) 107 | else: 108 | print('Load processed data from %s...' % self.save_path) 109 | with open(self.save_path, 'rb') as f: 110 | self.list_of_points, self.list_of_labels = pickle.load(f) 111 | 112 | def __len__(self): 113 | return len(self.datapath) 114 | 115 | def _get_item(self, index): 116 | if self.process_data: 117 | point_set, label = self.list_of_points[index], self.list_of_labels[index] 118 | else: 119 | fn = self.datapath[index] 120 | cls = self.classes[self.datapath[index][0]] 121 | label = np.array([cls]).astype(np.int32) 122 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 123 | 124 | if self.uniform: 125 | point_set = farthest_point_sample(point_set, self.npoints) 126 | else: 127 | point_set = point_set[0:self.npoints, :] 128 | 129 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 130 | if not self.use_normals: 131 | point_set = point_set[:, 0:3] 132 | 133 | return point_set, label[0] 134 | 135 | def __getitem__(self, index): 136 | return self._get_item(index) 137 | 138 | 139 | if __name__ == '__main__': 140 | import argparse 141 | import torch 142 | args = argparse.Namespace(num_point=1024, use_uniform_sample=False, use_normals=True, num_category=40) # settings taken from 
checkpoint/eval.txt 143 | data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', args, split='train') 144 | 145 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12,
shuffle=True,num_workers=10) 146 | for point, label in DataLoader: 147 | print(point.shape) 148 | print(label.shape) 149 | -------------------------------------------------------------------------------- /data_utils/S3DISDataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from tqdm import tqdm 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class S3DISDataset(Dataset): 9 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None): 10 | super().__init__() 11 | self.num_point = num_point 12 | self.block_size = block_size 13 | self.transform = transform 14 | rooms = sorted(os.listdir(data_root)) 15 | rooms = [room for room in rooms if 'Area_' in room] 16 | if split == 'train': 17 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room] 18 | else: 19 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room] 20 | 21 | self.room_points, self.room_labels = [], [] 22 | self.room_coord_min, self.room_coord_max = [], [] 23 | num_point_all = [] 24 | labelweights = np.zeros(13) 25 | 26 | for room_name in tqdm(rooms_split, total=len(rooms_split)): 27 | room_path = os.path.join(data_root, room_name) 28 | room_data = np.load(room_path) # xyzrgbl, N*7 29 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N 30 | tmp, _ = np.histogram(labels, range(14)) 31 | labelweights += tmp 32 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 33 | self.room_points.append(points), self.room_labels.append(labels) 34 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 35 | num_point_all.append(labels.size) 36 | labelweights = labelweights.astype(np.float32) 37 | labelweights = labelweights / np.sum(labelweights) 38 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 39 | 
print(self.labelweights) 40 | sample_prob = num_point_all / np.sum(num_point_all) 41 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point) 42 | room_idxs = [] 43 | for index in range(len(rooms_split)): 44 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter))) 45 | self.room_idxs = np.array(room_idxs) 46 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split)) 47 | 48 | def __getitem__(self, idx): 49 | room_idx = self.room_idxs[idx] 50 | points = self.room_points[room_idx] # N * 6 51 | labels = self.room_labels[room_idx] # N 52 | N_points = points.shape[0] 53 | 54 | while (True): 55 | center = points[np.random.choice(N_points)][:3] 56 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0] 57 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0] 58 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0] 59 | if point_idxs.size > 1024: 60 | break 61 | 62 | if point_idxs.size >= self.num_point: 63 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False) 64 | else: 65 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True) 66 | 67 | # normalize 68 | selected_points = points[selected_point_idxs, :] # num_point * 6 69 | current_points = np.zeros((self.num_point, 9)) # num_point * 9 70 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0] 71 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1] 72 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2] 73 | selected_points[:, 0] = selected_points[:, 0] - center[0] 74 | selected_points[:, 1] = selected_points[:, 1] - center[1] 75 | selected_points[:, 3:6] /= 255.0 76 | current_points[:, 0:6] = selected_points 77 | current_labels = labels[selected_point_idxs] 78 | if self.transform is not None: 79 | 
current_points, current_labels = self.transform(current_points, current_labels) 80 | return current_points, current_labels 81 | 82 | def __len__(self): 83 | return len(self.room_idxs) 84 | 85 | class ScannetDatasetWholeScene(): 86 | # prepares whole scenes so that every point receives a prediction 87 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001): 88 | self.block_points = block_points 89 | self.block_size = block_size 90 | self.padding = padding 91 | self.root = root 92 | self.split = split 93 | self.stride = stride 94 | self.scene_points_num = [] 95 | assert split in ['train', 'test'] 96 | if self.split == 'train': 97 | self.file_list = [d for d in os.listdir(root) if 'Area_%d' % test_area not in d] 98 | else: 99 | self.file_list = [d for d in os.listdir(root) if 'Area_%d' % test_area in d] 100 | self.scene_points_list = [] 101 | self.semantic_labels_list = [] 102 | self.room_coord_min, self.room_coord_max = [], [] 103 | for file in self.file_list: 104 | data = np.load(os.path.join(root, file)) 105 | points = data[:, :3] 106 | self.scene_points_list.append(data[:, :6]) 107 | self.semantic_labels_list.append(data[:, 6]) 108 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 109 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 110 | assert len(self.scene_points_list) == len(self.semantic_labels_list) 111 | 112 | labelweights = np.zeros(13) 113 | for seg in self.semantic_labels_list: 114 | tmp, _ = np.histogram(seg, range(14)) 115 | self.scene_points_num.append(seg.shape[0]) 116 | labelweights += tmp 117 | labelweights = labelweights.astype(np.float32) 118 | labelweights = labelweights / np.sum(labelweights) 119 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 120 | 121 | def __getitem__(self, index): 122 | point_set_ini = self.scene_points_list[index] 123 | points = point_set_ini[:,:6] 124 | labels =
self.semantic_labels_list[index] 125 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 126 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1) 127 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1) 128 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([]) 129 | for index_y in range(0, grid_y): 130 | for index_x in range(0, grid_x): 131 | s_x = coord_min[0] + index_x * self.stride 132 | e_x = min(s_x + self.block_size, coord_max[0]) 133 | s_x = e_x - self.block_size 134 | s_y = coord_min[1] + index_y * self.stride 135 | e_y = min(s_y + self.block_size, coord_max[1]) 136 | s_y = e_y - self.block_size 137 | point_idxs = np.where( 138 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & ( 139 | points[:, 1] <= e_y + self.padding))[0] 140 | if point_idxs.size == 0: 141 | continue 142 | num_batch = int(np.ceil(point_idxs.size / self.block_points)) 143 | point_size = int(num_batch * self.block_points) 144 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True 145 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace) 146 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat)) 147 | np.random.shuffle(point_idxs) 148 | data_batch = points[point_idxs, :] 149 | normlized_xyz = np.zeros((point_size, 3)) 150 | normlized_xyz[:, 0] = data_batch[:, 0] / coord_max[0] 151 | normlized_xyz[:, 1] = data_batch[:, 1] / coord_max[1] 152 | normlized_xyz[:, 2] = data_batch[:, 2] / coord_max[2] 153 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0) 154 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0) 155 | data_batch[:, 3:6] /= 255.0 156 | data_batch = np.concatenate((data_batch, normlized_xyz), axis=1) 157 | label_batch = 
labels[point_idxs].astype(int) 158 | batch_weight = self.labelweights[label_batch] 159 | 160 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch 161 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch 162 | sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight 163 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs 164 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1])) 165 | label_room = label_room.reshape((-1, self.block_points)) 166 | sample_weight = sample_weight.reshape((-1, self.block_points)) 167 | index_room = index_room.reshape((-1, self.block_points)) 168 | return data_room, label_room, sample_weight, index_room 169 | 170 | def __len__(self): 171 | return len(self.scene_points_list) 172 | 173 | if __name__ == '__main__': 174 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/' 175 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01 176 | 177 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None) 178 | print('point data size:', point_data.__len__()) 179 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape) 180 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape) 181 | import torch, time, random 182 | manual_seed = 123 183 | random.seed(manual_seed) 184 | np.random.seed(manual_seed) 185 | torch.manual_seed(manual_seed) 186 | torch.cuda.manual_seed_all(manual_seed) 187 | def worker_init_fn(worker_id): 188 | random.seed(manual_seed + worker_id) 189 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn) 190 | for idx in range(4): 191 | end = time.time() 192 | for i, (input, target) in enumerate(train_loader): 193 | print('time:
{}/{}--{}'.format(i+1, len(train_loader), time.time() - end)) 194 | end = time.time() -------------------------------------------------------------------------------- /data_utils/ShapeNetDataLoader.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import json 4 | import warnings 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | warnings.filterwarnings('ignore') 8 | 9 | def pc_normalize(pc): 10 | centroid = np.mean(pc, axis=0) 11 | pc = pc - centroid 12 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 13 | pc = pc / m 14 | return pc 15 | 16 | class PartNormalDataset(Dataset): 17 | def __init__(self,root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False): 18 | self.npoints = npoints 19 | self.root = root 20 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 21 | self.cat = {} 22 | self.normal_channel = normal_channel 23 | 24 | 25 | with open(self.catfile, 'r') as f: 26 | for line in f: 27 | ls = line.strip().split() 28 | self.cat[ls[0]] = ls[1] 29 | self.cat = {k: v for k, v in self.cat.items()} 30 | self.classes_original = dict(zip(self.cat, range(len(self.cat)))) 31 | 32 | if not class_choice is None: 33 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 34 | # print(self.cat) 35 | 36 | self.meta = {} 37 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: 38 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 39 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: 40 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 41 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: 42 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 43 | for item in self.cat: 44 | # print('category', 
item) 45 | self.meta[item] = [] 46 | dir_point = os.path.join(self.root, self.cat[item]) 47 | fns = sorted(os.listdir(dir_point)) 48 | # print(fns[0][0:-4]) 49 | if split == 'trainval': 50 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 51 | elif split == 'train': 52 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 53 | elif split == 'val': 54 | fns = [fn for fn in fns if fn[0:-4] in val_ids] 55 | elif split == 'test': 56 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 57 | else: 58 | print('Unknown split: %s. Exiting...' % (split)) 59 | exit(-1) 60 | 61 | # print(os.path.basename(fns)) 62 | for fn in fns: 63 | token = (os.path.splitext(os.path.basename(fn))[0]) 64 | self.meta[item].append(os.path.join(dir_point, token + '.txt')) 65 | 66 | self.datapath = [] 67 | for item in self.cat: 68 | for fn in self.meta[item]: 69 | self.datapath.append((item, fn)) 70 | 71 | self.classes = {} 72 | for i in self.cat.keys(): 73 | self.classes[i] = self.classes_original[i] 74 | 75 | # Mapping from category (e.g. 'Chair') to its list of part (segmentation) labels (e.g. [12, 13, 14, 15]) 76 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 77 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 78 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 79 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 80 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 81 | 82 | # for cat in sorted(self.seg_classes.keys()): 83 | # print(cat, self.seg_classes[cat]) 84 | 85 | self.cache = {} # from index to (point_set, cls, seg) tuple 86 | self.cache_size = 20000 87 | 88 | 89 | def __getitem__(self, index): 90 | if index in self.cache: 91 | point_set, cls, seg = self.cache[index] 92 | else: 93 | fn = self.datapath[index] 94 | cat = self.datapath[index][0] 95 | cls = self.classes[cat] 96 | cls = 
np.array([cls]).astype(np.int32) 97 |
data = np.loadtxt(fn[1]).astype(np.float32) 98 | if not self.normal_channel: 99 | point_set = data[:, 0:3] 100 | else: 101 | point_set = data[:, 0:6] 102 | seg = data[:, -1].astype(np.int32) 103 | if len(self.cache) < self.cache_size: 104 | self.cache[index] = (point_set, cls, seg) 105 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 106 | 107 | choice = np.random.choice(len(seg), self.npoints, replace=True) 108 | # resample 109 | point_set = point_set[choice, :] 110 | seg = seg[choice] 111 | 112 | return point_set, cls, seg 113 | 114 | def __len__(self): 115 | return len(self.datapath) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /data_utils/__pycache__/ModelNetDataLoader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/data_utils/__pycache__/ModelNetDataLoader.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/__pycache__/ShapeNetDataLoader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/data_utils/__pycache__/ShapeNetDataLoader.cpython-37.pyc -------------------------------------------------------------------------------- /data_utils/collect_indoor3d_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from indoor3d_util import DATA_PATH, collect_point_label 4 | 5 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | ROOT_DIR = os.path.dirname(BASE_DIR) 7 | sys.path.append(BASE_DIR) 8 | 9 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))] 10 | anno_paths = [os.path.join(DATA_PATH, p) for p in anno_paths] 11 | 12 | output_folder = 
os.path.join(ROOT_DIR, 'data/stanford_indoor3d') 13 | if not os.path.exists(output_folder): 14 | os.mkdir(output_folder) 15 | 16 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually. 17 | for anno_path in anno_paths: 18 | print(anno_path) 19 | try: 20 | elements = anno_path.split('/') 21 | out_filename = elements[-3]+'_'+elements[-2]+'.npy' # Area_1_hallway_1.npy 22 | collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy') 23 | except: 24 | print(anno_path, 'ERROR!!') 25 | -------------------------------------------------------------------------------- /data_utils/indoor3d_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import sys 5 | 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | ROOT_DIR = os.path.dirname(BASE_DIR) 8 | sys.path.append(BASE_DIR) 9 | 10 | DATA_PATH = os.path.join(ROOT_DIR, 'data','s3dis', 'Stanford3dDataset_v1.2_Aligned_Version') 11 | g_classes = [x.rstrip() for x in open(os.path.join(BASE_DIR, 'meta/class_names.txt'))] 12 | g_class2label = {cls: i for i,cls in enumerate(g_classes)} 13 | g_class2color = {'ceiling': [0,255,0], 14 | 'floor': [0,0,255], 15 | 'wall': [0,255,255], 16 | 'beam': [255,255,0], 17 | 'column': [255,0,255], 18 | 'window': [100,100,255], 19 | 'door': [200,200,100], 20 | 'table': [170,120,200], 21 | 'chair': [255,0,0], 22 | 'sofa': [200,100,100], 23 | 'bookcase': [10,200,100], 24 | 'board': [200,200,200], 25 | 'clutter': [50,50,50]} 26 | g_easy_view_labels = [7,8,9,10,11,1] 27 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes} 28 | 29 | 30 | # ----------------------------------------------------------------------------- 31 | # CONVERT ORIGINAL DATA TO OUR DATA_LABEL FILES 32 | # ----------------------------------------------------------------------------- 33 | 34 | def collect_point_label(anno_path, out_filename, 
file_format='txt'): 35 | """ Convert original dataset files to data_label file (each line is XYZRGBL). 36 | We aggregated all the points from each instance in the room. 37 | 38 | Args: 39 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/ 40 | out_filename: path to save collected points and labels (each line is XYZRGBL) 41 | file_format: txt or numpy, determines what file format to save. 42 | Returns: 43 | None 44 | Note: 45 | the points are shifted before save, the most negative point is now at origin. 46 | """ 47 | points_list = [] 48 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 49 | cls = os.path.basename(f).split('_')[0] 50 | print(f) 51 | if cls not in g_classes: # note: in some room there is 'staris' class.. 52 | cls = 'clutter' 53 | 54 | points = np.loadtxt(f) 55 | labels = np.ones((points.shape[0],1)) * g_class2label[cls] 56 | points_list.append(np.concatenate([points, labels], 1)) # Nx7 57 | 58 | data_label = np.concatenate(points_list, 0) 59 | xyz_min = np.amin(data_label, axis=0)[0:3] 60 | data_label[:, 0:3] -= xyz_min 61 | 62 | if file_format=='txt': 63 | fout = open(out_filename, 'w') 64 | for i in range(data_label.shape[0]): 65 | fout.write('%f %f %f %d %d %d %d\n' % \ 66 | (data_label[i,0], data_label[i,1], data_label[i,2], 67 | data_label[i,3], data_label[i,4], data_label[i,5], 68 | data_label[i,6])) 69 | fout.close() 70 | elif file_format=='numpy': 71 | np.save(out_filename, data_label) 72 | else: 73 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' 
% \ 74 | (file_format)) 75 | exit() 76 | 77 | def data_to_obj(data,name='example.obj',no_wall=True): 78 | fout = open(name, 'w') 79 | label = data[:, -1].astype(int) 80 | for i in range(data.shape[0]): 81 | if no_wall and ((label[i] == 2) or (label[i]==0)): 82 | continue 83 | fout.write('v %f %f %f %d %d %d\n' % \ 84 | (data[i, 0], data[i, 1], data[i, 2], data[i, 3], data[i, 4], data[i, 5])) 85 | fout.close() 86 | 87 | def point_label_to_obj(input_filename, out_filename, label_color=True, easy_view=False, no_wall=False): 88 | """ For visualization of a room from data_label file, 89 | input_filename: each line is X Y Z R G B L 90 | out_filename: OBJ filename, 91 | visualize input file by coloring each point with its label color 92 | easy_view: only visualize furniture and floor 93 | """ 94 | data_label = np.loadtxt(input_filename) 95 | data = data_label[:, 0:6] 96 | label = data_label[:, -1].astype(int) 97 | fout = open(out_filename, 'w') 98 | for i in range(data.shape[0]): 99 | color = g_label2color[label[i]] 100 | if easy_view and (label[i] not in g_easy_view_labels): 101 | continue 102 | if no_wall and ((label[i] == 2) or (label[i]==0)): 103 | continue 104 | if label_color: 105 | fout.write('v %f %f %f %d %d %d\n' % \ 106 | (data[i,0], data[i,1], data[i,2], color[0], color[1], color[2])) 107 | else: 108 | fout.write('v %f %f %f %d %d %d\n' % \ 109 | (data[i,0], data[i,1], data[i,2], data[i,3], data[i,4], data[i,5])) 110 | fout.close() 111 | 112 | 113 | 114 | # ----------------------------------------------------------------------------- 115 | # PREPARE BLOCK DATA FOR DEEPNETS TRAINING/TESTING 116 | # ----------------------------------------------------------------------------- 117 | 118 | def sample_data(data, num_sample): 119 | """ data is in N x ... 120 | we want to keep num_sample x ... of them. 121 | if N > num_sample, we will randomly keep num_sample of them. 122 | if N < num_sample, we will randomly duplicate samples.
123 | """ 124 | N = data.shape[0] 125 | if (N == num_sample): 126 | return data, range(N) 127 | elif (N > num_sample): 128 | sample = np.random.choice(N, num_sample, replace=False) 129 | return data[sample, ...], sample 130 | else: 131 | sample = np.random.choice(N, num_sample-N) 132 | dup_data = data[sample, ...] 133 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample) 134 | 135 | def sample_data_label(data, label, num_sample): 136 | new_data, sample_indices = sample_data(data, num_sample) 137 | new_label = label[sample_indices] 138 | return new_data, new_label 139 | 140 | def room2blocks(data, label, num_point, block_size=1.0, stride=1.0, 141 | random_sample=False, sample_num=None, sample_aug=1): 142 | """ Prepare block training data. 143 | Args: 144 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 145 | assumes the data is shifted (min point is origin) and aligned 146 | (aligned with XYZ axis) 147 | label: N size uint8 numpy array from 0-12 148 | num_point: int, how many points to sample in each block 149 | block_size: float, physical size of the block in meters 150 | stride: float, stride for block sweeping 151 | random_sample: bool, if True, we will randomly sample blocks in the room 152 | sample_num: int, if random sample, how many blocks to sample 153 | [default: room area] 154 | sample_aug: int, if random sample and sample_num is None, multiplier on the default block count 155 | Returns: 156 | block_datas: K x num_point x 6 np array of XYZRGB, RGB is in [0,1] 157 | block_labels: K x num_point np array of uint8 labels 158 | 159 | TODO: for this version, blocking is in fixed, non-overlapping pattern.
160 | """ 161 | assert(stride<=block_size) 162 | 163 | limit = np.amax(data, 0)[0:3] 164 | 165 | # Get the corner locations for our sampling blocks 166 | xbeg_list = [] 167 | ybeg_list = [] 168 | if not random_sample: 169 | num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1 170 | num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1 171 | for i in range(num_block_x): 172 | for j in range(num_block_y): 173 | xbeg_list.append(i*stride) 174 | ybeg_list.append(j*stride) 175 | else: 176 | num_block_x = int(np.ceil(limit[0] / block_size)) 177 | num_block_y = int(np.ceil(limit[1] / block_size)) 178 | if sample_num is None: 179 | sample_num = num_block_x * num_block_y * sample_aug 180 | for _ in range(sample_num): 181 | xbeg = np.random.uniform(-block_size, limit[0]) 182 | ybeg = np.random.uniform(-block_size, limit[1]) 183 | xbeg_list.append(xbeg) 184 | ybeg_list.append(ybeg) 185 | 186 | # Collect blocks 187 | block_data_list = [] 188 | block_label_list = [] 189 | idx = 0 190 | for idx in range(len(xbeg_list)): 191 | xbeg = xbeg_list[idx] 192 | ybeg = ybeg_list[idx] 193 | xcond = (data[:,0]<=xbeg+block_size) & (data[:,0]>=xbeg) 194 | ycond = (data[:,1]<=ybeg+block_size) & (data[:,1]>=ybeg) 195 | cond = xcond & ycond 196 | if np.sum(cond) < 100: # discard block if it has fewer than 100 points.
197 | continue 198 | 199 | block_data = data[cond, :] 200 | block_label = label[cond] 201 | 202 | # randomly subsample data 203 | block_data_sampled, block_label_sampled = \ 204 | sample_data_label(block_data, block_label, num_point) 205 | block_data_list.append(np.expand_dims(block_data_sampled, 0)) 206 | block_label_list.append(np.expand_dims(block_label_sampled, 0)) 207 | 208 | return np.concatenate(block_data_list, 0), \ 209 | np.concatenate(block_label_list, 0) 210 | 211 | 212 | def room2blocks_plus(data_label, num_point, block_size, stride, 213 | random_sample, sample_num, sample_aug): 214 | """ room2blocks on an N x 7 (XYZRGBL) data_label array, with RGB scaled to [0,1]. 215 | """ 216 | data = data_label[:,0:6] 217 | data[:,3:6] /= 255.0 218 | label = data_label[:,-1].astype(np.uint8) 219 | 220 | return room2blocks(data, label, num_point, block_size, stride, 221 | random_sample, sample_num, sample_aug) 222 | 223 | def room2blocks_wrapper(data_label_filename, num_point, block_size=1.0, stride=1.0, 224 | random_sample=False, sample_num=None, sample_aug=1): 225 | if data_label_filename[-3:] == 'txt': 226 | data_label = np.loadtxt(data_label_filename) 227 | elif data_label_filename[-3:] == 'npy': 228 | data_label = np.load(data_label_filename) 229 | else: 230 | print('Unknown file type! exiting.') 231 | exit() 232 | return room2blocks_plus(data_label, num_point, block_size, stride, 233 | random_sample, sample_num, sample_aug) 234 | 235 | def room2blocks_plus_normalized(data_label, num_point, block_size, stride, 236 | random_sample, sample_num, sample_aug): 237 | """ room2blocks on an N x 7 (XYZRGBL) data_label array, with RGB scaled to [0,1].
238 | for each block centralize XYZ, add normalized XYZ as 678 channels 239 | """ 240 | data = data_label[:,0:6] 241 | data[:,3:6] /= 255.0 242 | label = data_label[:,-1].astype(np.uint8) 243 | max_room_x = max(data[:,0]) 244 | max_room_y = max(data[:,1]) 245 | max_room_z = max(data[:,2]) 246 | 247 | data_batch, label_batch = room2blocks(data, label, num_point, block_size, stride, 248 | random_sample, sample_num, sample_aug) 249 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 250 | for b in range(data_batch.shape[0]): 251 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 252 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 253 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 254 | minx = min(data_batch[b, :, 0]) 255 | miny = min(data_batch[b, :, 1]) 256 | data_batch[b, :, 0] -= (minx+block_size/2) 257 | data_batch[b, :, 1] -= (miny+block_size/2) 258 | new_data_batch[:, :, 0:6] = data_batch 259 | return new_data_batch, label_batch 260 | 261 | 262 | def room2blocks_wrapper_normalized(data_label_filename, num_point, block_size=1.0, stride=1.0, 263 | random_sample=False, sample_num=None, sample_aug=1): 264 | if data_label_filename[-3:] == 'txt': 265 | data_label = np.loadtxt(data_label_filename) 266 | elif data_label_filename[-3:] == 'npy': 267 | data_label = np.load(data_label_filename) 268 | else: 269 | print('Unknown file type! exiting.') 270 | exit() 271 | return room2blocks_plus_normalized(data_label, num_point, block_size, stride, 272 | random_sample, sample_num, sample_aug) 273 | 274 | def room2samples(data, label, sample_num_point): 275 | """ Prepare whole room samples. 
276 | 277 | Args: 278 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 279 | assumes the data is shifted (min point is origin) and 280 | aligned (aligned with XYZ axis) 281 | label: N size uint8 numpy array from 0-12 282 | sample_num_point: int, number of points in each sample 283 | Returns: 284 | sample_datas: K x sample_num_point x 6 285 | numpy array of XYZRGB, RGB is in [0,1] 286 | sample_labels: K x sample_num_point x 1 np array of labels (float dtype) 287 | """ 288 | N = data.shape[0] 289 | order = np.arange(N) 290 | np.random.shuffle(order) 291 | data = data[order, :] 292 | label = label[order] 293 | 294 | batch_num = int(np.ceil(N / float(sample_num_point))) 295 | sample_datas = np.zeros((batch_num, sample_num_point, 6)) 296 | sample_labels = np.zeros((batch_num, sample_num_point, 1)) 297 | 298 | for i in range(batch_num): 299 | beg_idx = i*sample_num_point 300 | end_idx = min((i+1)*sample_num_point, N) 301 | num = end_idx - beg_idx 302 | sample_datas[i,0:num,:] = data[beg_idx:end_idx, :] 303 | sample_labels[i,0:num,0] = label[beg_idx:end_idx] 304 | if num < sample_num_point: 305 | makeup_indices = np.random.choice(N, sample_num_point - num) 306 | sample_datas[i,num:,:] = data[makeup_indices, :] 307 | sample_labels[i,num:,0] = label[makeup_indices] 308 | return sample_datas, sample_labels 309 | 310 | def room2samples_plus_normalized(data_label, num_point): 311 | """ room2samples on an N x 7 (XYZRGBL) data_label array, with RGB scaled to [0,1].
312 | adds room-normalized XYZ as channels 6-8 (per-sample XYZ centering is commented out below) 313 | """ 314 | data = data_label[:,0:6] 315 | data[:,3:6] /= 255.0 316 | label = data_label[:,-1].astype(np.uint8) 317 | max_room_x = max(data[:,0]) 318 | max_room_y = max(data[:,1]) 319 | max_room_z = max(data[:,2]) 320 | #print(max_room_x, max_room_y, max_room_z) 321 | 322 | data_batch, label_batch = room2samples(data, label, num_point) 323 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 324 | for b in range(data_batch.shape[0]): 325 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 326 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 327 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 328 | #minx = min(data_batch[b, :, 0]) 329 | #miny = min(data_batch[b, :, 1]) 330 | #data_batch[b, :, 0] -= (minx+block_size/2) 331 | #data_batch[b, :, 1] -= (miny+block_size/2) 332 | new_data_batch[:, :, 0:6] = data_batch 333 | return new_data_batch, label_batch 334 | 335 | 336 | def room2samples_wrapper_normalized(data_label_filename, num_point): 337 | if data_label_filename[-3:] == 'txt': 338 | data_label = np.loadtxt(data_label_filename) 339 | elif data_label_filename[-3:] == 'npy': 340 | data_label = np.load(data_label_filename) 341 | else: 342 | print('Unknown file type! exiting.') 343 | exit() 344 | return room2samples_plus_normalized(data_label, num_point) 345 | 346 | 347 | # ----------------------------------------------------------------------------- 348 | # EXTRACT INSTANCE BBOX FROM ORIGINAL DATA (for detection evaluation) 349 | # ----------------------------------------------------------------------------- 350 | 351 | def collect_bounding_box(anno_path, out_filename): 352 | """ Compute bounding boxes from each instance in original dataset files on 353 | one room. **We assume the bbox is aligned with XYZ coordinate.** 354 | 355 | Args: 356 | anno_path: path to annotations. e.g.
Area_1/office_2/Annotations/ 357 | out_filename: path to save instance bounding boxes for that room. 358 | each line is x1 y1 z1 x2 y2 z2 label, 359 | where (x1,y1,z1) is the point on the diagonal closer to origin 360 | Returns: 361 | None 362 | Note: 363 | room points are shifted, the most negative point is now at origin. 364 | """ 365 | bbox_label_list = [] 366 | 367 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 368 | cls = os.path.basename(f).split('_')[0] 369 | if cls not in g_classes: # note: in some room there is 'staris' class.. 370 | cls = 'clutter' 371 | points = np.loadtxt(f) 372 | label = g_class2label[cls] 373 | # Compute tightest axis aligned bounding box 374 | xyz_min = np.amin(points[:, 0:3], axis=0) 375 | xyz_max = np.amax(points[:, 0:3], axis=0) 376 | ins_bbox_label = np.expand_dims( 377 | np.concatenate([xyz_min, xyz_max, np.array([label])], 0), 0) 378 | bbox_label_list.append(ins_bbox_label) 379 | 380 | bbox_label = np.concatenate(bbox_label_list, 0) 381 | room_xyz_min = np.amin(bbox_label[:, 0:3], axis=0) 382 | bbox_label[:, 0:3] -= room_xyz_min 383 | bbox_label[:, 3:6] -= room_xyz_min 384 | 385 | fout = open(out_filename, 'w') 386 | for i in range(bbox_label.shape[0]): 387 | fout.write('%f %f %f %f %f %f %d\n' % \ 388 | (bbox_label[i,0], bbox_label[i,1], bbox_label[i,2], 389 | bbox_label[i,3], bbox_label[i,4], bbox_label[i,5], 390 | bbox_label[i,6])) 391 | fout.close() 392 | 393 | def bbox_label_to_obj(input_filename, out_filename_prefix, easy_view=False): 394 | """ Visualization of bounding boxes. 
395 | 396 | Args: 397 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 398 | out_filename_prefix: OBJ filename prefix, 399 | visualize object by g_label2color 400 | easy_view: if True, only visualize furniture and floor 401 | Returns: 402 | output a list of OBJ file and MTL files with the same prefix 403 | """ 404 | bbox_label = np.loadtxt(input_filename) 405 | bbox = bbox_label[:, 0:6] 406 | label = bbox_label[:, -1].astype(int) 407 | v_cnt = 0 # count vertex 408 | ins_cnt = 0 # count instance 409 | for i in range(bbox.shape[0]): 410 | if easy_view and (label[i] not in g_easy_view_labels): 411 | continue 412 | obj_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.obj' 413 | mtl_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.mtl' 414 | fout_obj = open(obj_filename, 'w') 415 | fout_mtl = open(mtl_filename, 'w') 416 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 417 | 418 | length = bbox[i, 3:6] - bbox[i, 0:3] 419 | a = length[0] 420 | b = length[1] 421 | c = length[2] 422 | x = bbox[i, 0] 423 | y = bbox[i, 1] 424 | z = bbox[i, 2] 425 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 426 | 427 | material = 'material%d' % (ins_cnt) 428 | fout_obj.write('usemtl %s\n' % (material)) 429 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 430 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 431 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 432 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 433 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 434 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 435 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 436 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 437 | fout_obj.write('g default\n') 438 | v_cnt = 0 # for individual box 439 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 440 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 441 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 
442 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 443 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 444 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 445 | fout_obj.write('\n') 446 | 447 | fout_mtl.write('newmtl %s\n' % (material)) 448 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 449 | fout_mtl.write('\n') 450 | fout_obj.close() 451 | fout_mtl.close() 452 | 453 | v_cnt += 8 454 | ins_cnt += 1 455 | 456 | def bbox_label_to_obj_room(input_filename, out_filename_prefix, easy_view=False, permute=None, center=False, exclude_table=False): 457 | """ Visualization of bounding boxes. 458 | 459 | Args: 460 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 461 | out_filename_prefix: OBJ filename prefix, 462 | visualize object by g_label2color 463 | easy_view: if True, only visualize furniture and floor 464 | permute: if not None, permute XYZ for rendering, e.g. [0 2 1] 465 | center: if True, move obj to have zero origin 466 | Returns: 467 | output a list of OBJ file and MTL files with the same prefix 468 | """ 469 | bbox_label = np.loadtxt(input_filename) 470 | bbox = bbox_label[:, 0:6] 471 | if permute is not None: 472 | assert(len(permute)==3) 473 | permute = np.array(permute) 474 | bbox[:,0:3] = bbox[:,permute] 475 | bbox[:,3:6] = bbox[:,permute+3] 476 | if center: 477 | xyz_max = np.amax(bbox[:,3:6], 0) 478 | bbox[:,0:3] -= (xyz_max/2.0) 479 | bbox[:,3:6] -= (xyz_max/2.0) 480 | bbox /= np.max(xyz_max/2.0) 481 | label = bbox_label[:, -1].astype(int) 482 | obj_filename = out_filename_prefix+'.obj' 483 | mtl_filename = out_filename_prefix+'.mtl' 484 | 485 | fout_obj = open(obj_filename, 'w') 486 | fout_mtl = open(mtl_filename, 'w') 487 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 488 | v_cnt = 0 # count vertex 489 | ins_cnt = 0 # count instance 490 | for i in range(bbox.shape[0]): 491 | if easy_view and (label[i] not in g_easy_view_labels): 
492 | continue 493 | if exclude_table and label[i] == g_classes.index('table'): 494 | continue 495 | 496 | length = bbox[i, 3:6] - bbox[i, 0:3] 497 | a = length[0] 498 | b = length[1] 499 | c = length[2] 500 | x = bbox[i, 0] 501 | y = bbox[i, 1] 502 | z = bbox[i, 2] 503 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 504 | 505 | material = 'material%d' % (ins_cnt) 506 | fout_obj.write('usemtl %s\n' % (material)) 507 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 508 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 509 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 510 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 511 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 512 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 513 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 514 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 515 | fout_obj.write('g default\n') 516 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 517 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 518 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 519 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 520 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 521 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 522 | fout_obj.write('\n') 523 | 524 | fout_mtl.write('newmtl %s\n' % (material)) 525 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 526 | fout_mtl.write('\n') 527 | 528 | v_cnt += 8 529 | ins_cnt += 1 530 | 531 | fout_obj.close() 532 | fout_mtl.close() 533 | 534 | 535 | def collect_point_bounding_box(anno_path, out_filename, file_format): 536 | """ Compute bounding boxes from each instance in original dataset files on 537 | one room. **We assume the bbox is aligned with XYZ coordinate.** 538 | Save both the point XYZRGB and the bounding box for the point's 539 | parent element. 
540 | 541 | Args: 542 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/ 543 | out_filename: path to save instance bounding boxes for each point, 544 | plus the point's XYZRGBL 545 | each line is XYZRGBL offsetX offsetY offsetZ a b c, 546 | where cx = X+offsetX, cy = Y+offsetY, cz = Z+offsetZ 547 | where (cx,cy,cz) is the center of the box, a,b,c are distances from center 548 | to the surfaces of the box, i.e. x1 = cx-a, x2 = cx+a, y1=cy-b etc. 549 | file_format: output file format, txt or numpy 550 | Returns: 551 | None 552 | 553 | Note: 554 | room points are shifted so that the most negative point is at the origin. 555 | """ 556 | point_bbox_list = [] 557 | 558 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 559 | cls = os.path.basename(f).split('_')[0] 560 | if cls not in g_classes: # note: in some room there is 'staris' class.. 561 | cls = 'clutter' 562 | points = np.loadtxt(f) # Nx6 563 | label = g_class2label[cls] # scalar class index 564 | # Compute tightest axis aligned bounding box 565 | xyz_min = np.amin(points[:, 0:3], axis=0) # 3, 566 | xyz_max = np.amax(points[:, 0:3], axis=0) # 3, 567 | xyz_center = (xyz_min + xyz_max) / 2 568 | dimension = (xyz_max - xyz_min) / 2 569 | 570 | xyz_offsets = xyz_center - points[:,0:3] # Nx3 571 | dimensions = np.ones((points.shape[0],3)) * dimension # Nx3 572 | labels = np.ones((points.shape[0],1)) * label # Nx1 573 | point_bbox_list.append(np.concatenate([points, labels, 574 | xyz_offsets, dimensions], 1)) # Nx13 575 | 576 | point_bbox = np.concatenate(point_bbox_list, 0) # (total points in room) x 13 577 | room_xyz_min = np.amin(point_bbox[:, 0:3], axis=0) 578 | point_bbox[:, 0:3] -= room_xyz_min 579 | 580 | if file_format == 'txt': 581 | fout = open(out_filename, 'w') 582 | for i in range(point_bbox.shape[0]): 583 | fout.write('%f %f %f %d %d %d %d %f %f %f %f %f %f\n' % \ 584 | (point_bbox[i,0], point_bbox[i,1], point_bbox[i,2], 585 | point_bbox[i,3], point_bbox[i,4], point_bbox[i,5], 586 | point_bbox[i,6], 587 | point_bbox[i,7],
point_bbox[i,8], point_bbox[i,9], 588 | point_bbox[i,10], point_bbox[i,11], point_bbox[i,12])) 589 | 590 | fout.close() 591 | elif file_format == 'numpy': 592 | np.save(out_filename, point_bbox) 593 | else: 594 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \ 595 | (file_format)) 596 | exit() 597 | 598 | 599 | -------------------------------------------------------------------------------- /data_utils/meta/anno_paths.txt: -------------------------------------------------------------------------------- 1 | Area_1/conferenceRoom_1/Annotations 2 | Area_1/conferenceRoom_2/Annotations 3 | Area_1/copyRoom_1/Annotations 4 | Area_1/hallway_1/Annotations 5 | Area_1/hallway_2/Annotations 6 | Area_1/hallway_3/Annotations 7 | Area_1/hallway_4/Annotations 8 | Area_1/hallway_5/Annotations 9 | Area_1/hallway_6/Annotations 10 | Area_1/hallway_7/Annotations 11 | Area_1/hallway_8/Annotations 12 | Area_1/office_10/Annotations 13 | Area_1/office_11/Annotations 14 | Area_1/office_12/Annotations 15 | Area_1/office_13/Annotations 16 | Area_1/office_14/Annotations 17 | Area_1/office_15/Annotations 18 | Area_1/office_16/Annotations 19 | Area_1/office_17/Annotations 20 | Area_1/office_18/Annotations 21 | Area_1/office_19/Annotations 22 | Area_1/office_1/Annotations 23 | Area_1/office_20/Annotations 24 | Area_1/office_21/Annotations 25 | Area_1/office_22/Annotations 26 | Area_1/office_23/Annotations 27 | Area_1/office_24/Annotations 28 | Area_1/office_25/Annotations 29 | Area_1/office_26/Annotations 30 | Area_1/office_27/Annotations 31 | Area_1/office_28/Annotations 32 | Area_1/office_29/Annotations 33 | Area_1/office_2/Annotations 34 | Area_1/office_30/Annotations 35 | Area_1/office_31/Annotations 36 | Area_1/office_3/Annotations 37 | Area_1/office_4/Annotations 38 | Area_1/office_5/Annotations 39 | Area_1/office_6/Annotations 40 | Area_1/office_7/Annotations 41 | Area_1/office_8/Annotations 42 | Area_1/office_9/Annotations 43 | Area_1/pantry_1/Annotations 44 | 
Area_1/WC_1/Annotations 45 | Area_2/auditorium_1/Annotations 46 | Area_2/auditorium_2/Annotations 47 | Area_2/conferenceRoom_1/Annotations 48 | Area_2/hallway_10/Annotations 49 | Area_2/hallway_11/Annotations 50 | Area_2/hallway_12/Annotations 51 | Area_2/hallway_1/Annotations 52 | Area_2/hallway_2/Annotations 53 | Area_2/hallway_3/Annotations 54 | Area_2/hallway_4/Annotations 55 | Area_2/hallway_5/Annotations 56 | Area_2/hallway_6/Annotations 57 | Area_2/hallway_7/Annotations 58 | Area_2/hallway_8/Annotations 59 | Area_2/hallway_9/Annotations 60 | Area_2/office_10/Annotations 61 | Area_2/office_11/Annotations 62 | Area_2/office_12/Annotations 63 | Area_2/office_13/Annotations 64 | Area_2/office_14/Annotations 65 | Area_2/office_1/Annotations 66 | Area_2/office_2/Annotations 67 | Area_2/office_3/Annotations 68 | Area_2/office_4/Annotations 69 | Area_2/office_5/Annotations 70 | Area_2/office_6/Annotations 71 | Area_2/office_7/Annotations 72 | Area_2/office_8/Annotations 73 | Area_2/office_9/Annotations 74 | Area_2/storage_1/Annotations 75 | Area_2/storage_2/Annotations 76 | Area_2/storage_3/Annotations 77 | Area_2/storage_4/Annotations 78 | Area_2/storage_5/Annotations 79 | Area_2/storage_6/Annotations 80 | Area_2/storage_7/Annotations 81 | Area_2/storage_8/Annotations 82 | Area_2/storage_9/Annotations 83 | Area_2/WC_1/Annotations 84 | Area_2/WC_2/Annotations 85 | Area_3/conferenceRoom_1/Annotations 86 | Area_3/hallway_1/Annotations 87 | Area_3/hallway_2/Annotations 88 | Area_3/hallway_3/Annotations 89 | Area_3/hallway_4/Annotations 90 | Area_3/hallway_5/Annotations 91 | Area_3/hallway_6/Annotations 92 | Area_3/lounge_1/Annotations 93 | Area_3/lounge_2/Annotations 94 | Area_3/office_10/Annotations 95 | Area_3/office_1/Annotations 96 | Area_3/office_2/Annotations 97 | Area_3/office_3/Annotations 98 | Area_3/office_4/Annotations 99 | Area_3/office_5/Annotations 100 | Area_3/office_6/Annotations 101 | Area_3/office_7/Annotations 102 | Area_3/office_8/Annotations 103 | 
Area_3/office_9/Annotations 104 | Area_3/storage_1/Annotations 105 | Area_3/storage_2/Annotations 106 | Area_3/WC_1/Annotations 107 | Area_3/WC_2/Annotations 108 | Area_4/conferenceRoom_1/Annotations 109 | Area_4/conferenceRoom_2/Annotations 110 | Area_4/conferenceRoom_3/Annotations 111 | Area_4/hallway_10/Annotations 112 | Area_4/hallway_11/Annotations 113 | Area_4/hallway_12/Annotations 114 | Area_4/hallway_13/Annotations 115 | Area_4/hallway_14/Annotations 116 | Area_4/hallway_1/Annotations 117 | Area_4/hallway_2/Annotations 118 | Area_4/hallway_3/Annotations 119 | Area_4/hallway_4/Annotations 120 | Area_4/hallway_5/Annotations 121 | Area_4/hallway_6/Annotations 122 | Area_4/hallway_7/Annotations 123 | Area_4/hallway_8/Annotations 124 | Area_4/hallway_9/Annotations 125 | Area_4/lobby_1/Annotations 126 | Area_4/lobby_2/Annotations 127 | Area_4/office_10/Annotations 128 | Area_4/office_11/Annotations 129 | Area_4/office_12/Annotations 130 | Area_4/office_13/Annotations 131 | Area_4/office_14/Annotations 132 | Area_4/office_15/Annotations 133 | Area_4/office_16/Annotations 134 | Area_4/office_17/Annotations 135 | Area_4/office_18/Annotations 136 | Area_4/office_19/Annotations 137 | Area_4/office_1/Annotations 138 | Area_4/office_20/Annotations 139 | Area_4/office_21/Annotations 140 | Area_4/office_22/Annotations 141 | Area_4/office_2/Annotations 142 | Area_4/office_3/Annotations 143 | Area_4/office_4/Annotations 144 | Area_4/office_5/Annotations 145 | Area_4/office_6/Annotations 146 | Area_4/office_7/Annotations 147 | Area_4/office_8/Annotations 148 | Area_4/office_9/Annotations 149 | Area_4/storage_1/Annotations 150 | Area_4/storage_2/Annotations 151 | Area_4/storage_3/Annotations 152 | Area_4/storage_4/Annotations 153 | Area_4/WC_1/Annotations 154 | Area_4/WC_2/Annotations 155 | Area_4/WC_3/Annotations 156 | Area_4/WC_4/Annotations 157 | Area_5/conferenceRoom_1/Annotations 158 | Area_5/conferenceRoom_2/Annotations 159 | Area_5/conferenceRoom_3/Annotations 160 | 
Area_5/hallway_10/Annotations 161 | Area_5/hallway_11/Annotations 162 | Area_5/hallway_12/Annotations 163 | Area_5/hallway_13/Annotations 164 | Area_5/hallway_14/Annotations 165 | Area_5/hallway_15/Annotations 166 | Area_5/hallway_1/Annotations 167 | Area_5/hallway_2/Annotations 168 | Area_5/hallway_3/Annotations 169 | Area_5/hallway_4/Annotations 170 | Area_5/hallway_5/Annotations 171 | Area_5/hallway_6/Annotations 172 | Area_5/hallway_7/Annotations 173 | Area_5/hallway_8/Annotations 174 | Area_5/hallway_9/Annotations 175 | Area_5/lobby_1/Annotations 176 | Area_5/office_10/Annotations 177 | Area_5/office_11/Annotations 178 | Area_5/office_12/Annotations 179 | Area_5/office_13/Annotations 180 | Area_5/office_14/Annotations 181 | Area_5/office_15/Annotations 182 | Area_5/office_16/Annotations 183 | Area_5/office_17/Annotations 184 | Area_5/office_18/Annotations 185 | Area_5/office_19/Annotations 186 | Area_5/office_1/Annotations 187 | Area_5/office_20/Annotations 188 | Area_5/office_21/Annotations 189 | Area_5/office_22/Annotations 190 | Area_5/office_23/Annotations 191 | Area_5/office_24/Annotations 192 | Area_5/office_25/Annotations 193 | Area_5/office_26/Annotations 194 | Area_5/office_27/Annotations 195 | Area_5/office_28/Annotations 196 | Area_5/office_29/Annotations 197 | Area_5/office_2/Annotations 198 | Area_5/office_30/Annotations 199 | Area_5/office_31/Annotations 200 | Area_5/office_32/Annotations 201 | Area_5/office_33/Annotations 202 | Area_5/office_34/Annotations 203 | Area_5/office_35/Annotations 204 | Area_5/office_36/Annotations 205 | Area_5/office_37/Annotations 206 | Area_5/office_38/Annotations 207 | Area_5/office_39/Annotations 208 | Area_5/office_3/Annotations 209 | Area_5/office_40/Annotations 210 | Area_5/office_41/Annotations 211 | Area_5/office_42/Annotations 212 | Area_5/office_4/Annotations 213 | Area_5/office_5/Annotations 214 | Area_5/office_6/Annotations 215 | Area_5/office_7/Annotations 216 | Area_5/office_8/Annotations 217 | 
Area_5/office_9/Annotations 218 | Area_5/pantry_1/Annotations 219 | Area_5/storage_1/Annotations 220 | Area_5/storage_2/Annotations 221 | Area_5/storage_3/Annotations 222 | Area_5/storage_4/Annotations 223 | Area_5/WC_1/Annotations 224 | Area_5/WC_2/Annotations 225 | Area_6/conferenceRoom_1/Annotations 226 | Area_6/copyRoom_1/Annotations 227 | Area_6/hallway_1/Annotations 228 | Area_6/hallway_2/Annotations 229 | Area_6/hallway_3/Annotations 230 | Area_6/hallway_4/Annotations 231 | Area_6/hallway_5/Annotations 232 | Area_6/hallway_6/Annotations 233 | Area_6/lounge_1/Annotations 234 | Area_6/office_10/Annotations 235 | Area_6/office_11/Annotations 236 | Area_6/office_12/Annotations 237 | Area_6/office_13/Annotations 238 | Area_6/office_14/Annotations 239 | Area_6/office_15/Annotations 240 | Area_6/office_16/Annotations 241 | Area_6/office_17/Annotations 242 | Area_6/office_18/Annotations 243 | Area_6/office_19/Annotations 244 | Area_6/office_1/Annotations 245 | Area_6/office_20/Annotations 246 | Area_6/office_21/Annotations 247 | Area_6/office_22/Annotations 248 | Area_6/office_23/Annotations 249 | Area_6/office_24/Annotations 250 | Area_6/office_25/Annotations 251 | Area_6/office_26/Annotations 252 | Area_6/office_27/Annotations 253 | Area_6/office_28/Annotations 254 | Area_6/office_29/Annotations 255 | Area_6/office_2/Annotations 256 | Area_6/office_30/Annotations 257 | Area_6/office_31/Annotations 258 | Area_6/office_32/Annotations 259 | Area_6/office_33/Annotations 260 | Area_6/office_34/Annotations 261 | Area_6/office_35/Annotations 262 | Area_6/office_36/Annotations 263 | Area_6/office_37/Annotations 264 | Area_6/office_3/Annotations 265 | Area_6/office_4/Annotations 266 | Area_6/office_5/Annotations 267 | Area_6/office_6/Annotations 268 | Area_6/office_7/Annotations 269 | Area_6/office_8/Annotations 270 | Area_6/office_9/Annotations 271 | Area_6/openspace_1/Annotations 272 | Area_6/pantry_1/Annotations 273 | 
-------------------------------------------------------------------------------- /data_utils/meta/class_names.txt: -------------------------------------------------------------------------------- 1 | ceiling 2 | floor 3 | wall 4 | beam 5 | column 6 | window 7 | door 8 | table 9 | chair 10 | sofa 11 | bookcase 12 | board 13 | clutter 14 | -------------------------------------------------------------------------------- /images/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/images/example.jpg -------------------------------------------------------------------------------- /images/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/PointSCNet/0f730f22afb00f8e899ce1b0d88f93ad57e938f1/images/model.jpg -------------------------------------------------------------------------------- /models/SCNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from utils import PointNetSetAbstractionMsg, PointNetSetAbstraction 4 | from z_order import * 5 | 6 | 7 | def get_relation_zorder_sample(input_u, input_v, random_sample=True, sample_size=64): 8 | batchsize, in_uchannels, length = input_u.shape 9 | _, in_vchannels, _ = input_v.shape 10 | device = input_u.device 11 | if not random_sample: 12 | sample_size = length 13 | 14 | input_u = input_u.permute(0, 2, 1) 15 | input_v = input_v.permute(0, 2, 1) 16 | ides = z_order_point_sample(input_u[:, :, :3], sample_size) 17 | 18 | batch_indices = torch.arange(batchsize, dtype=torch.long).to(device).view(batchsize, 1).repeat(1, sample_size) 19 | 20 | temp_relationu = input_u[batch_indices, ides, :].permute(0, 2, 1) 21 | temp_relationv = input_v[batch_indices, ides, :].permute(0, 2, 1) 22 | input_u = input_u.permute(0, 2, 1) 
23 | input_v = input_v.permute(0, 2, 1) 24 | relation_u = torch.cat([input_u.view(batchsize, -1, length, 1).repeat(1, 1, 1, sample_size), 25 | temp_relationu.view(batchsize, -1, 1, sample_size).repeat(1, 1, length, 1)], dim=1) 26 | relation_v = torch.cat([input_v.view(batchsize, -1, length, 1).repeat(1, 1, 1, sample_size), 27 | temp_relationv.view(batchsize, -1, 1, sample_size).repeat(1, 1, length, 1)], dim=1) 28 | 29 | return relation_u, relation_v, temp_relationu, temp_relationv 30 | 31 | 32 | 33 | class PSCNChannelAttention(nn.Module): 34 | def __init__(self, channel_in): 35 | super(PSCNChannelAttention, self).__init__() 36 | self.avg_pool = nn.AdaptiveAvgPool1d(1) # global adaptive pooling 37 | self.max_pool = nn.AdaptiveMaxPool1d(1) 38 | 39 | self.fc1 = nn.Sequential( 40 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1), 41 | nn.ReLU(inplace=True), 42 | nn.Conv1d(channel_in // 2, channel_in, bias=False, kernel_size=1), 43 | ) 44 | self.fc2 = nn.Sequential( 45 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1), 46 | nn.ReLU(inplace=True), 47 | nn.Conv1d(channel_in // 2, channel_in, bias=False, kernel_size=1), 48 | 49 | ) 50 | self.sigmoid = nn.Sigmoid() 51 | 52 | def forward(self, x): 53 | batch_size, channel_num, _ = x.size() 54 | 55 | avg_out = self.avg_pool(x).view(batch_size, channel_num, 1) # squeeze operation 56 | max_out = self.max_pool(x).view(batch_size, channel_num, 1) 57 | 58 | avg_y = self.fc1(avg_out).view(batch_size, channel_num, 1) # FC layers produce the channel attention weights, which carry global information 59 | max_y = self.fc2(max_out).view(batch_size, channel_num, 1) 60 | out = self.sigmoid(avg_y + max_y) 61 | return out # attention weights for each channel 62 | 63 | # return x + x * out.expand_as(x) # apply attention to each channel 64 | 65 | 66 | class PSCNSpatialAttention(nn.Module): 67 | def __init__(self, channel_in, channel_out): 68 | super(PSCNSpatialAttention, self).__init__() 69 | self.avg_pool = nn.AdaptiveAvgPool1d(1) # global adaptive pooling 70 | self.max_pool = nn.AdaptiveMaxPool1d(1) 71 | 72 | self.fc1 = nn.Sequential( 73 |
nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1), 74 | nn.ReLU(inplace=True), 75 | nn.Conv1d(channel_in // 2, channel_out, bias=False, kernel_size=1), 76 | ) 77 | self.fc2 = nn.Sequential( 78 | nn.Conv1d(channel_in, channel_in // 2, bias=False, kernel_size=1), 79 | nn.ReLU(inplace=True), 80 | nn.Conv1d(channel_in // 2, channel_out, bias=False, kernel_size=1), 81 | 82 | ) 83 | self.sigmoid = nn.Sigmoid() 84 | 85 | def forward(self, x): 86 | batch_size, channel_num, point_num = x.size() 87 | 88 | avg_out = self.avg_pool(x).view(batch_size, channel_num, 1) # squeeze operation 89 | max_out = self.max_pool(x).view(batch_size, channel_num, 1) 90 | 91 | avg_y = self.fc1(avg_out).view(batch_size, 1, point_num) # FC layers produce per-point (spatial) attention weights that carry global information 92 | max_y = self.fc2(max_out).view(batch_size, 1, point_num) 93 | out = self.sigmoid(avg_y + max_y) 94 | return out # attention weights applied to each spatial position (point) 95 | # return x + x * out.expand_as(x) # attention applied to each spatial position 96 | # return x + x * out.expand_as(x), out # attention applied to each spatial position 97 | 98 | 99 | class PointSCN(nn.Module): 100 | def __init__(self, in_uchannels, in_vchannels, random_sample=True, sample_size=64): 101 | super(PointSCN, self).__init__() 102 | 103 | self.random_sample = random_sample 104 | self.sample_size = sample_size 105 | 106 | self.conv_gu = nn.Conv2d(2 * in_uchannels, 2 * in_uchannels, 1) 107 | self.bn1 = nn.BatchNorm2d(2 * in_uchannels) 108 | 109 | self.conv_gv = nn.Conv2d(2 * in_vchannels, 2 * in_vchannels, 1) 110 | self.bn2 = nn.BatchNorm2d(2 * in_vchannels) 111 | 112 | self.conv_uv = nn.Conv2d(2 * in_uchannels + 2 * in_vchannels, in_vchannels, 1) 113 | self.bn3 = nn.BatchNorm2d(in_vchannels) 114 | 115 | self.conv_f = nn.Conv1d(in_vchannels, in_vchannels, 1) 116 | self.bn4 = nn.BatchNorm1d(in_vchannels) 117 | 118 | def forward(self, input_u, input_v): 119 | """ 120 | Input: 121 | input_u: input point positions, [B, C, N] 122 | input_v: input point features, [B, D, N] 123 | Return: 124 | relation_uv: input_v plus the learned correlation term, 125 |
concatenated with input_u along the channel dim, [B, D + C, N] 126 | """ 127 | 128 | relation_u, relation_v, _, _ = get_relation_zorder_sample(input_u, input_v, random_sample=self.random_sample, 129 | sample_size=self.sample_size) 130 | 131 | relation_uv = torch.cat( 132 | [F.relu(self.bn1(self.conv_gu(relation_u))), F.relu(self.bn2(self.conv_gv(relation_v)))], dim=1) 133 | 134 | relation_uv = F.relu(self.bn3(self.conv_uv(relation_uv))) 135 | 136 | relation_uv = torch.max(relation_uv, 3)[0] 137 | 138 | relation_uv = F.relu(self.bn4(self.conv_f(relation_uv))) 139 | relation_uv = torch.cat([input_v + relation_uv, input_u], dim=1) 140 | 141 | return relation_uv 142 | 143 | 144 | class get_model(nn.Module): 145 | def __init__(self, num_class, normal_channel=True): 146 | super(get_model, self).__init__() 147 | in_channel = 3 if normal_channel else 0 148 | self.normal_channel = normal_channel 149 | self.sa1 = PointNetSetAbstractionMsg(256, [0.1, 0.4], [16, 128], in_channel, 150 | [[32, 32, 64], [64, 96, 128]]) 151 | 152 | self.PSCN1 = PointSCN(3, 64 + 128, random_sample=True) 153 | self.attention1 = PSCNChannelAttention(128 + 64 + 3) 154 | self.attention3 = PSCNSpatialAttention(128 + 64 + 3, 256) 155 | 156 | self.sa2 = PointNetSetAbstraction(None, None, None, 128 + 64 + 3 + 3, [256, 512, 1024], True) # global feature obtained via pooling (group_all) 157 | 158 | self.fc1 = nn.Linear(1024, 512) 159 | self.bn1 = nn.BatchNorm1d(512) 160 | self.drop1 = nn.Dropout(0.4) 161 | self.fc2 = nn.Linear(512, 256) 162 | self.bn2 = nn.BatchNorm1d(256) 163 | self.drop2 = nn.Dropout(0.5) 164 | self.fc3 = nn.Linear(256, num_class) 165 | 166 | def forward(self, xyz): 167 | B, _, _ = xyz.shape 168 | if self.normal_channel: 169 | norm = xyz[:, 3:, :] 170 | xyz = xyz[:, :3, :] 171 | else: 172 | norm = None 173 | l1_xyz, l1_points = self.sa1(xyz, norm) 174 | l1_points = self.PSCN1(l1_xyz, l1_points) 175 | l1_points_att1 = self.attention1(l1_points) 176 | l1_points_att3 = self.attention3(l1_points) 177 | l1_points = l1_points *
l1_points_att1.expand_as(l1_points) * l1_points_att3.expand_as(l1_points) 178 | l3_xyz, l3_points = self.sa2(l1_xyz, l1_points) 179 | x = l3_points.view(B, 1024) 180 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 181 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 182 | x = self.fc3(x) 183 | x = F.log_softmax(x, -1) 184 | return x, l3_points 185 | 186 | 187 | class get_loss(nn.Module): 188 | def __init__(self): 189 | super(get_loss, self).__init__() 190 | 191 | def forward(self, pred, target, trans_feat): 192 | total_loss = F.nll_loss(pred, target) 193 | 194 | return total_loss 195 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | def timeit(tag, t): 9 | print("{}: {}s".format(tag, time() - t)) 10 | return time() 11 | 12 | 13 | def pc_normalize(pc): 14 | l = pc.shape[0] 15 | centroid = np.mean(pc, axis=0) 16 | pc = pc - centroid 17 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 18 | pc = pc / m 19 | return pc 20 | 21 | 22 | def square_distance(src, dst): 23 | """ 24 | Calculate the squared Euclidean distance between every pair of points.
25 | 26 | src^T * dst = xn * xm + yn * ym + zn * zm; 27 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 28 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 29 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 30 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 31 | 32 | Input: 33 | src: source points, [B, N, C] 34 | dst: target points, [B, M, C] 35 | Output: 36 | dist: per-point squared distance, [B, N, M] 37 | """ 38 | B, N, _ = src.shape 39 | _, M, _ = dst.shape 40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 41 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 43 | return dist 44 | 45 | 46 | def index_points(points, idx): 47 | """ 48 | 49 | Input: 50 | points: input points data, [B, N, C] 51 | idx: sample index data, [B, S] (or [B, S, K]) 52 | Return: 53 | new_points: indexed points data, [B, S, C] (or [B, S, K, C]) 54 | """ 55 | device = points.device 56 | B = points.shape[0] 57 | view_shape = list(idx.shape) 58 | view_shape[1:] = [1] * (len(view_shape) - 1) 59 | repeat_shape = list(idx.shape) 60 | repeat_shape[0] = 1 61 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 62 | new_points = points[batch_indices, idx, :] 63 | return new_points 64 | 65 | 66 | def farthest_point_sample(xyz, npoint): 67 | """ 68 | Input: 69 | xyz: pointcloud data, [B, N, 3] 70 | npoint: number of samples 71 | Return: 72 | centroids: sampled pointcloud index, [B, npoint] 73 | """ 74 | device = xyz.device 75 | B, N, C = xyz.shape 76 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 77 | distance = torch.ones(B, N).to(device) * 1e10 78 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 79 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 80 | for i in range(npoint): 81 | centroids[:, i] = farthest 82 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 83 | dist = torch.sum((xyz - centroid) ** 2, -1) 84 | mask = dist < distance 85 | distance[mask] = dist[mask] 86
| farthest = torch.max(distance, -1)[1] 87 | return centroids 88 | 89 | 90 | def query_ball_point(radius, nsample, xyz, new_xyz): 91 | """ 92 | Input: 93 | radius: local region radius 94 | nsample: max sample number in local region 95 | xyz: all points, [B, N, 3] 96 | new_xyz: query points, [B, S, 3] 97 | Return: 98 | group_idx: grouped points index, [B, S, nsample] 99 | """ 100 | device = xyz.device 101 | B, N, C = xyz.shape 102 | _, S, _ = new_xyz.shape 103 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 104 | sqrdists = square_distance(new_xyz, xyz) 105 | group_idx[sqrdists > radius ** 2] = N 106 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 107 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 108 | mask = group_idx == N 109 | group_idx[mask] = group_first[mask] 110 | return group_idx 111 | 112 | 113 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 114 | """ 115 | Input: 116 | npoint: number of centroids to sample (FPS) 117 | radius: ball query radius 118 | nsample: max number of points in each local region 119 | xyz: input points position data, [B, N, 3] 120 | points: input points data, [B, N, D] 121 | Return: 122 | new_xyz: sampled points position data, [B, npoint, 3] 123 | new_points: sampled points data, [B, npoint, nsample, 3+D] 124 | """ 125 | B, N, C = xyz.shape 126 | S = npoint 127 | # fps_idx = z_order.z_order_point_sample(xyz, npoint) # [B, npoint] 128 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 129 | new_xyz = index_points(xyz, fps_idx) 130 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 131 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 132 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 133 | 134 | if points is not None: 135 | grouped_points = index_points(points, idx) 136 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 137 | else: 138 | new_points = grouped_xyz_norm 139 | if returnfps: 140 | return new_xyz, new_points,
grouped_xyz, fps_idx 141 | else: 142 | return new_xyz, new_points 143 | 144 | 145 | def sample_and_group_all(xyz, points): 146 | """ 147 | Input: 148 | xyz: input points position data, [B, N, 3] 149 | points: input points data, [B, N, D] 150 | Return: 151 | new_xyz: sampled points position data, [B, 1, 3] 152 | new_points: sampled points data, [B, 1, N, 3+D] 153 | """ 154 | device = xyz.device 155 | B, N, C = xyz.shape 156 | new_xyz = torch.zeros(B, 1, C).to(device) 157 | grouped_xyz = xyz.view(B, 1, N, C) 158 | if points is not None: 159 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 160 | else: 161 | new_points = grouped_xyz 162 | return new_xyz, new_points 163 | 164 | 165 | class PointNetSetAbstraction(nn.Module): 166 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 167 | super(PointNetSetAbstraction, self).__init__() 168 | self.npoint = npoint 169 | self.radius = radius 170 | self.nsample = nsample 171 | self.mlp_convs = nn.ModuleList() 172 | self.mlp_bns = nn.ModuleList() 173 | last_channel = in_channel 174 | for out_channel in mlp: 175 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 176 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 177 | last_channel = out_channel 178 | self.group_all = group_all 179 | 180 | def forward(self, xyz, points): 181 | """ 182 | Input: 183 | xyz: input points position data, [B, C, N] 184 | points: input points data, [B, D, N] 185 | Return: 186 | new_xyz: sampled points position data, [B, C, S] 187 | new_points_concat: sample points feature data, [B, D', S] 188 | """ 189 | xyz = xyz.permute(0, 2, 1) 190 | if points is not None: 191 | points = points.permute(0, 2, 1) 192 | 193 | if self.group_all: 194 | new_xyz, new_points = sample_and_group_all(xyz, points) 195 | else: 196 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 197 | # new_xyz: sampled points position data, [B, npoint, C] 198 | # new_points: sampled points 
data, [B, npoint, nsample, C+D] 199 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 200 | for i, conv in enumerate(self.mlp_convs): 201 | bn = self.mlp_bns[i] 202 | new_points = F.relu(bn(conv(new_points))) 203 | 204 | new_points = torch.max(new_points, 2)[0] 205 | new_xyz = new_xyz.permute(0, 2, 1) 206 | return new_xyz, new_points 207 | 208 | class PointNetSetAbstractionMsg(nn.Module): 209 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 210 | super(PointNetSetAbstractionMsg, self).__init__() 211 | self.npoint = npoint 212 | self.radius_list = radius_list 213 | self.nsample_list = nsample_list 214 | self.conv_blocks = nn.ModuleList() 215 | self.bn_blocks = nn.ModuleList() 216 | for i in range(len(mlp_list)): 217 | convs = nn.ModuleList() 218 | bns = nn.ModuleList() 219 | last_channel = in_channel + 3 220 | for out_channel in mlp_list[i]: 221 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 222 | bns.append(nn.BatchNorm2d(out_channel)) 223 | last_channel = out_channel 224 | self.conv_blocks.append(convs) 225 | self.bn_blocks.append(bns) 226 | 227 | def forward(self, xyz, points): 228 | """ 229 | Input: 230 | xyz: input points position data, [B, C, N] 231 | points: input points data, [B, D, N] 232 | Return: 233 | new_xyz: sampled points position data, [B, C, S] 234 | new_points_concat: sample points feature data, [B, D', S] 235 | """ 236 | xyz = xyz.permute(0, 2, 1) 237 | if points is not None: 238 | points = points.permute(0, 2, 1) 239 | 240 | B, N, C = xyz.shape 241 | S = self.npoint 242 | 243 | # new_xyz = index_points(xyz, z_order.z_order_point_sample(xyz, S)) 244 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 245 | new_points_list = [] 246 | for i, radius in enumerate(self.radius_list): 247 | K = self.nsample_list[i] 248 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 249 | grouped_xyz = index_points(xyz, group_idx) 250 | grouped_xyz -= new_xyz.view(B, S, 1, C) 251 | if points 
is not None: 252 | grouped_points = index_points(points, group_idx) 253 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 254 | else: 255 | grouped_points = grouped_xyz 256 | 257 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 258 | for j in range(len(self.conv_blocks[i])): 259 | conv = self.conv_blocks[i][j] 260 | bn = self.bn_blocks[i][j] 261 | grouped_points = F.relu(bn(conv(grouped_points))) 262 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 263 | new_points_list.append(new_points) 264 | 265 | new_xyz = new_xyz.permute(0, 2, 1) 266 | new_points_concat = torch.cat(new_points_list, dim=1) 267 | return new_xyz, new_points_concat 268 | 269 | class PointNetFeaturePropagation(nn.Module): 270 | def __init__(self, in_channel, mlp): 271 | super(PointNetFeaturePropagation, self).__init__() 272 | self.mlp_convs = nn.ModuleList() 273 | self.mlp_bns = nn.ModuleList() 274 | last_channel = in_channel 275 | for out_channel in mlp: 276 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 277 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 278 | last_channel = out_channel 279 | 280 | def forward(self, xyz1, xyz2, points1, points2): 281 | """ 282 | Input: 283 | xyz1: input points position data, [B, C, N] 284 | xyz2: sampled input points position data, [B, C, S] 285 | points1: input points data, [B, D, N] 286 | points2: sampled input points data, [B, D, S] 287 | Return: 288 | new_points: upsampled points data, [B, D', N] 289 | """ 290 | xyz1 = xyz1.permute(0, 2, 1) 291 | xyz2 = xyz2.permute(0, 2, 1) 292 | 293 | points2 = points2.permute(0, 2, 1) 294 | B, N, C = xyz1.shape 295 | _, S, _ = xyz2.shape 296 | 297 | if S == 1: 298 | interpolated_points = points2.repeat(1, N, 1) 299 | else: 300 | dists = square_distance(xyz1, xyz2) 301 | dists, idx = dists.sort(dim=-1) 302 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 303 | 304 | dist_recip = 1.0 / (dists + 1e-8) 305 | norm = torch.sum(dist_recip, dim=2, 
keepdim=True) 306 | weight = dist_recip / norm 307 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 308 | 309 | if points1 is not None: 310 | points1 = points1.permute(0, 2, 1) 311 | new_points = torch.cat([points1, interpolated_points], dim=-1) 312 | else: 313 | new_points = interpolated_points 314 | 315 | new_points = new_points.permute(0, 2, 1) 316 | for i, conv in enumerate(self.mlp_convs): 317 | bn = self.mlp_bns[i] 318 | new_points = F.relu(bn(conv(new_points))) 319 | return new_points 320 | -------------------------------------------------------------------------------- /models/z_order.py: -------------------------------------------------------------------------------- 1 | # import numpy 2 | import numpy as np 3 | import torch 4 | 5 | from matplotlib import pyplot as plt 6 | import time 7 | 8 | 9 | def round_to_int_32(data): 10 | """ 11 | Takes a Numpy array of float values between 12 | -1 and 1, and rounds them to integer values 13 | (stored as int64 so the bit interleaving cannot overflow), 14 | to be used in the morton code computation 15 | 16 | :param data: multidimensional numpy array 17 | :return: same as data but in 64-bit int format 18 | """ 19 | # first we rescale points to 0-512 20 | data = 256 * (data + 1) 21 | # now convert to int (int64: split_by_3 shifts values up to 32 bits left) 22 | data = np.round(2 ** 21 - data).astype(dtype=np.int64) 23 | 24 | return data 25 | 26 | 27 | def split_by_3(x): 28 | """ 29 | Method to separate the bits of an integer (lowest 21 bits) 30 | by 3 positions apart, using the magic bits 31 | https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/ 32 | 33 | :param x: 64-bit integer (array) 34 | :return: x with bits separated 35 | """ 36 | # we only look at 21 bits, since we want to generate 37 | # a 64-bit code eventually (3 x 21 bits = 63 bits, which 38 | # is the maximum we can fit in a 64-bit code) 39 | x &= 0x1fffff # only take first 21 bits 40 | # shift left 32 bits, OR with self, and
00011111000000000000000000000000000000001111111111111111 41 | x = (x | (x << 32)) & 0x1f00000000ffff 42 | # shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111 43 | x = (x | (x << 16)) & 0x1f0000ff0000ff 44 | # shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000001111 45 | x = (x | (x << 8)) & 0x100f00f00f00f00f 46 | # shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000011000011 47 | x = (x | (x << 4)) & 0x10c30c30c30c30c3 48 | # shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001 49 | x = (x | (x << 2)) & 0x1249249249249249 50 | 51 | return x 52 | 53 | 54 | def get_z_order(x, y, z): 55 | """ 56 | Given 3 arrays of corresponding x, y, z 57 | coordinates, compute the morton (or z) code for 58 | each point and return an array of Morton codes 59 | We compute the Morton order as follows: 60 | 1- Split all coordinates by 3 (add 2 zeros between bits) 61 | 2- Shift bits left by 1 for y and 2 for z 62 | 3- Interleave x, shifted y, and shifted z 63 | The Morton order is the final interleaved bit sequence 64 | 65 | :param x: x coordinates 66 | :param y: y coordinates 67 | :param z: z coordinates 68 | :return: array of Morton codes 69 | """ 70 | res = 0 71 | res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2 72 | 73 | return res 74 | 75 | 76 | def get_z_values(data): 77 | """ 78 | Computes the z values for a point array 79 | :param data: Nx3 array of x, y, and z location 80 | 81 | :return: length-N array of z values (Morton codes) 82 | """ 83 | data = data.cpu().detach().numpy() 84 | points_round = round_to_int_32(data) # convert to int 85 | z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2]) 86 | 87 | return z 88 | 89 | 90 | def pointnet_index_points(points, idx): 91 | """ 92 | 93 | Input: 94 | points: input points data, [B, N, C] 95 | idx: sample index data, [B, S] 96 | Return: 97 |
new_points: indexed points data, [B, S, C] 98 | """ 99 | device = points.device 100 | B = points.shape[0] 101 | view_shape = list(idx.shape) 102 | view_shape[1:] = [1] * (len(view_shape) - 1) 103 | # print(view_shape) 104 | repeat_shape = list(idx.shape) 105 | repeat_shape[0] = 1 106 | # print(repeat_shape) 107 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 108 | # print(batch_indices) 109 | # print(batch_indices.shape, idx.shape) 110 | new_points = points[batch_indices, idx, :] 111 | # print(new_points) 112 | return new_points 113 | 114 | 115 | def farthest_point_sample(xyz, npoint): 116 | """ 117 | Input: 118 | xyz: pointcloud data, [B, N, 3] 119 | npoint: number of samples 120 | Return: 121 | centroids: sampled pointcloud index, [B, npoint] 122 | """ 123 | device = xyz.device 124 | B, N, C = xyz.shape 125 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 126 | distance = torch.ones(B, N).to(device) * 1e10 127 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 128 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 129 | for i in range(npoint): 130 | centroids[:, i] = farthest 131 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 132 | dist = torch.sum((xyz - centroid) ** 2, -1) 133 | mask = dist < distance 134 | distance[mask] = dist[mask] 135 | farthest = torch.max(distance, -1)[1] 136 | return centroids 137 | 138 | 139 | def z_order_point_sample(xyz, npoints): 140 | """ 141 | Input: 142 | xyz: pointcloud data, [B, N, 3] 143 | npoints: number of samples 144 | Return: 145 | centroids: indices of points evenly spaced along the z-order curve, [B, npoints] 146 | """ 147 | device = xyz.device 148 | B, N, C = xyz.shape 149 | 150 | if npoints >= N: # nothing to subsample: return every index 0..N-1 in order 151 | return torch.arange(N, dtype=torch.long, device=device).view(1, N).repeat(B, 1) 152 | 153 | centroids = torch.zeros(B, npoints, dtype=int).to(device) 154 | for batch_idx in range(B): 155 | z = get_z_values(xyz[batch_idx]) 156 | z = np.argsort(z) 157 |
centroids[batch_idx, :] = torch.from_numpy( 158 | z[torch.linspace(0, N - 1, steps=npoints, dtype=int)].reshape(1, npoints)) 159 | 160 | return centroids 161 | -------------------------------------------------------------------------------- /paint.py: -------------------------------------------------------------------------------- 1 | import re 2 | from matplotlib import * 3 | import matplotlib.pyplot as plt 4 | import os 5 | import numpy as np 6 | import pandas as pd 7 | from scipy import interpolate 8 | from scipy.interpolate import make_interp_spline 9 | 10 | pathlist = [ 11 | r'C:\Users\LENOVO\Desktop\Pointnet\Pointnet_Pointnet2_pytorch-master\log\classification\test10\logs\pointnet_srn_cls_msg.txt'] 12 | 13 | 14 | def read_log(log_path_list, logs, train_acc=True, test_acc=True, loss=True): 15 | 16 | fig = plt.figure(num=1, figsize=(8, 4)) 17 | ax1 = fig.add_subplot(121) 18 | ax2 = fig.add_subplot(122) 19 | 20 | # plot setup 21 | ax1.set_xlabel('epoch') 22 | ax1.set_title('accuracy') 23 | ax1.axis([0, 200, 0.9, 0.95]) 24 | ax2.set_xlabel('epoch') 25 | ax2.set_title('total_loss') 26 | ax2.axis([0, 100, 0, 300]) 27 | for idx, path in enumerate(log_path_list): 28 | log_dir = path 29 | print(path) 30 | with open(log_dir, "r", encoding="utf-8") as f: 31 | content = f.read() 32 | 33 | epoch = re.findall(r'Epoch \d[0-9]*', content, re.M) 34 | Train_Instance_Accuracy = re.findall(r'Train Instance Accuracy: .*', content, re.M) 35 | if loss: 36 | Total_loss = re.findall(r'Total loss: .*', content, re.M) 37 | Test_Instance_Accuracy = re.findall(r'Test Instance Accuracy: .*,', content, re.M) 38 | 39 | data_length = min([len(epoch), len(Train_Instance_Accuracy), len(Test_Instance_Accuracy)]) 40 | for i in range(data_length): 41 | epoch[i] = int(epoch[i].strip('Epoch ')) 42 | Train_Instance_Accuracy[i] = float(Train_Instance_Accuracy[i].strip('Train Instance Accuracy: ')) 43 | if loss: 44 | Total_loss[i] = float(Total_loss[i].strip('Total loss: ')) 45 | Test_Instance_Accuracy[i] =
float(Test_Instance_Accuracy[i].strip('Test Instance Accuracy: ,')) 46 | plt.rcParams["font.family"] = "SimHei" 47 | if loss: 48 | total_loss_x = epoch[:data_length] 49 | total_loss_y = Total_loss[:data_length] 50 | 51 | test_accuracy_x = epoch[:data_length] 52 | test_accuracy_y = Test_Instance_Accuracy[:data_length] 53 | 54 | train_accuracy_x = epoch[:data_length] 55 | train_accuracy_y = Train_Instance_Accuracy[:data_length] 56 | if test_acc: 57 | # test_accuracy_x = np.array(test_accuracy_x) 58 | # test_accuracy_y = np.array(test_accuracy_x) 59 | # x_new = np.linspace(test_accuracy_x.min(), test_accuracy_x.max(), 60 | # 1000) # 1000 represents number of points to make between T.min and T.max 61 | # y_smooth = make_interp_spline(test_accuracy_x, test_accuracy_y)(x_new) 62 | ax1.plot(test_accuracy_x, test_accuracy_y, label='test_accuracy_{}'.format(logs[idx])) 63 | # ax1.plot(x_new, y_smooth, label='test_accuracy_{}'.format(logs[idx])) 64 | 65 | ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True) 66 | if train_acc: 67 | ax1.plot(train_accuracy_x, train_accuracy_y, label='train_accuracy_{}'.format(logs[idx])) 68 | ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True) 69 | if loss: 70 | ax2.plot(total_loss_x, total_loss_y, label='loss_{}'.format(logs[idx])) 71 | ax2.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True) 72 | plt.show() 73 | 74 | 75 | def read_all_paths(log_dir): 76 | class_dirs = os.listdir(log_dir) 77 | paths = [os.path.join(log_dir, i, 'logs') for i in class_dirs] 78 | all_paths = [os.path.join(i, os.listdir(i)[0]) for i in paths] 79 | return all_paths 80 | 81 | 82 | def read_log_path(log_name, classfication=True): 83 | log_dir = [] 84 | if classfication: 85 | for i in range(len(log_name)): 86 | logs_dir = os.path.join('log', 'classification', log_name[i], 'logs') 87 | log_dir.append(os.path.join(logs_dir, os.listdir(logs_dir)[0])) 88 | return log_dir 89 | 90 | 91 | if
__name__ == "__main__": 92 | logs = ['Ours', 'Pointnet++'] 93 | lists = read_log_path(logs) 94 | read_log(lists, logs, train_acc=False, loss=True) 95 | 96 | # extract data 97 | # log_dir = r'C:\Users\LENOVO\Desktop\Pointnet\Pointnet_Pointnet2_pytorch-master\log\classification\test10\logs\pointnet_srn_cls_msg.txt' 98 | # 99 | # with open(log_dir, "r", encoding="utf-8") as f: 100 | # content = f.read() 101 | # epoch = re.findall(r'Epoch \d[0-9]*', content, re.M) 102 | # Train_Instance_Accuracy = re.findall(r'Train Instance Accuracy: .*', content, re.M) 103 | # Total_loss = re.findall(r'Total loss: .*', content, re.M) 104 | # Test_Instance_Accuracy = re.findall(r'Test Instance Accuracy: .*,', content, re.M) 105 | # for i in range(len(epoch)): 106 | # epoch[i] = int(epoch[i].strip('Epoch ')) 107 | # Train_Instance_Accuracy[i] = float(Train_Instance_Accuracy[i].strip('Train Instance Accuracy: ')) 108 | # Total_loss[i] = float(Total_loss[i].strip('Total loss: ')) 109 | # Test_Instance_Accuracy[i] = float(Test_Instance_Accuracy[i].strip('Test Instance Accuracy: ,')) 110 | # 111 | # plt.rcParams["font.family"] = "SimHei" 112 | # total_loss_x = epoch 113 | # total_loss_y = Total_loss 114 | # test_accuracy_x = epoch 115 | # test_accuracy_y = Test_Instance_Accuracy 116 | # train_accuracy_x = epoch 117 | # train_accuracy_y = Train_Instance_Accuracy 118 | # 119 | # fig = plt.figure(num=1, figsize=(4, 4)) 120 | # ax1 = fig.add_subplot(121) 121 | # ax2 = fig.add_subplot(122) 122 | # # plot setup 123 | # ax1.set_xlabel('epoch') 124 | # ax1.set_title('accuracy') 125 | # ax2.set_xlabel('epoch') 126 | # ax2.set_title('total_loss') 127 | # 128 | # ax1.plot(test_accuracy_x, test_accuracy_y, "r-", label='test_accuracy') 129 | # 130 | # ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True) 131 | # ax1.plot(train_accuracy_x, train_accuracy_y, "b-", label='train_accuracy') 132 | # ax1.legend(loc='best', labelspacing=1, handlelength=4, fontsize=14, shadow=True) 133 | # 134 |
ax2.plot(total_loss_x, total_loss_y, "c-") 135 | # plt.show() 136 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize each point cloud in the batch: center it at the origin and scale it into the unit sphere. 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | labels: B,... numpy array 27 | Return: 28 | shuffled data, labels and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle the order of points in each point cloud; this changes FPS behavior. 36 | Use the same shuffling idx for the entire batch.
37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augment the dataset 48 | rotation is per shape, about the up (y) axis 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augment the dataset 68 | rotation is per shape, about the z axis 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud.
88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normal 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx6 array, rotated batch of point clouds and point normals 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 134 | """ Rotate the point
cloud along up direction with certain angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud along up direction with certain angle. 154 | Input: 155 | BxNx6 array, original batch of point clouds with normals 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds with normals 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. jittering is per point. 203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 | 228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud. 
230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNxC (C=6 when normals are used); dropped points are overwritten with the first point ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 6 | import argparse 7 | import numpy as np 8 | import os 9 | import torch 10 | import logging 11 | from tqdm import tqdm 12 | import sys 13 | import importlib 14 | 15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | ROOT_DIR = BASE_DIR 17 | 18 | 19 | def parse_args(): 20 | '''PARAMETERS''' 21 | parser = argparse.ArgumentParser('Testing') 22 | parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode') 23 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 24 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing') 25 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='evaluate on ModelNet10/40') 26 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number') 27 | parser.add_argument('--log_dir', type=str, required=True, help='Experiment root') 28 |
parser.add_argument('--use_normals', action='store_true', default=False, help='use normals') 29 | parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling') 30 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting') 31 | return parser.parse_args() 32 | 33 | 34 | def test(model, loader, num_class=40, vote_num=1): 35 | mean_correct = [] 36 | classifier = model.eval() 37 | class_acc = np.zeros((num_class, 3)) 38 | 39 | for j, (points, target) in tqdm(enumerate(loader), total=len(loader)): 40 | if not args.use_cpu: 41 | points, target = points.cuda(), target.cuda() 42 | 43 | # input() 44 | points = points.transpose(2, 1) 45 | vote_pool = torch.zeros(target.size()[0], num_class, device=points.device) # same device as the input, so --use_cpu also works 46 | 47 | for _ in range(vote_num): 48 | pred, _, = classifier(points) 49 | vote_pool += pred 50 | pred = vote_pool / vote_num 51 | 52 | pred_choice = pred.data.max(1)[1] 53 | # print(pred_choice) 54 | # print("pred_choice.shape", pred_choice.shape) 55 | for cat in np.unique(target.cpu()): 56 | # print('\n',cat.shape) 57 | # print(pred_choice[target == cat]) 58 | classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum() 59 | # print("classacc.shape", classacc.shape) 60 | class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0]) 61 | class_acc[cat, 1] += 1 62 | 63 | correct = pred_choice.eq(target.long().data).cpu().sum() 64 | mean_correct.append(correct.item() / float(points.size()[0])) 65 | 66 | class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1] 67 | 68 | class_mean_acc = np.mean(class_acc[:, 2]) 69 | instance_acc = np.mean(mean_correct) 70 | return instance_acc, class_mean_acc, class_acc 71 | 72 | 73 | def main(args): 74 | def log_string(str): 75 | logger.info(str) 76 | print(str) 77 | 78 | '''HYPER PARAMETER''' 79 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 80 | 81 | '''CREATE DIR''' 82 | experiment_dir = 'log/classification/' +
args.log_dir 83 | sys.path.append(experiment_dir) 84 | '''LOG''' 85 | args = parse_args() 86 | logger = logging.getLogger("Model") 87 | logger.setLevel(logging.INFO) 88 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 89 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) 90 | file_handler.setLevel(logging.INFO) 91 | file_handler.setFormatter(formatter) 92 | logger.addHandler(file_handler) 93 | log_string('PARAMETER ...') 94 | log_string(args) 95 | 96 | '''DATA LOADING''' 97 | log_string('Load dataset ...') 98 | data_path = 'data/modelnet40_normal_resampled/' 99 | 100 | test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False) 101 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, 102 | num_workers=10) 103 | 104 | '''MODEL LOADING''' 105 | num_class = args.num_category 106 | model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0] 107 | log_string(model_name) 108 | model = importlib.import_module(model_name) 109 | 110 | classifier = model.get_model(num_class, normal_channel=args.use_normals) 111 | 112 | # for name, parameters in classifier.named_parameters(): 113 | # print(name, ':', parameters) 114 | if not args.use_cpu: 115 | classifier = classifier.cuda() 116 | 117 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth', map_location='cpu' if args.use_cpu else None) # map GPU-saved tensors to CPU in --use_cpu mode 118 | classifier.load_state_dict(checkpoint['model_state_dict']) 119 | 120 | if num_class == 10: 121 | catfile = os.path.join('./data/modelnet40_normal_resampled', 'modelnet10_shape_names.txt') 122 | else: 123 | catfile = os.path.join('./data/modelnet40_normal_resampled', 'modelnet40_shape_names.txt') 124 | 125 | cats = [line.rstrip() for line in open(catfile)] 126 | cls_to_tag = dict(zip(range(len(cats)), cats)) 127 | 128 | with torch.no_grad(): 129 | instance_acc, class_mean_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes, 130 |
num_class=num_class) 131 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_mean_acc)) 132 | 133 | for i in range(num_class): 134 | log_string('Class %s Accuracy: %f' % (cls_to_tag[i], class_acc[i, 2])) 135 | 136 | 137 | if __name__ == '__main__': 138 | args = parse_args() 139 | main(args) 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | 6 | import os 7 | import sys 8 | import torch 9 | import numpy as np 10 | 11 | import datetime 12 | import logging 13 | import provider 14 | import importlib 15 | import shutil 16 | import argparse 17 | 18 | from pathlib import Path 19 | from tqdm import tqdm 20 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 21 | from torch.utils.tensorboard import SummaryWriter 22 | # from torchstat import stat 23 | # from thop import profile 24 | # from thop import clever_format 25 | 26 | # default `log_dir` is "runs" - we'll be more specific here 27 | 28 | 29 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 30 | ROOT_DIR = BASE_DIR 31 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 32 | 33 | 34 | def parse_args(): 35 | '''PARAMETERS''' 36 | parser = argparse.ArgumentParser('training') 37 | parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode') 38 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 39 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in training') 40 | parser.add_argument('--model', default='SCNet', help='model name [default: SCNet]') 41 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40') 42 | parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training') 43 | parser.add_argument('--learning_rate', 
default=0.001, type=float, help='learning rate in training') 44 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number') 45 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training') 46 | parser.add_argument('--log_dir', type=str, default=None, help='experiment root') 47 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay for Adam') 48 | parser.add_argument('--use_normals', action='store_true', default=False, help='use normals') 49 | parser.add_argument('--process_data', action='store_true', default=False, help='save data offline') 50 | parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling') 51 | return parser.parse_args() 52 | 53 | 54 | def inplace_relu(m): 55 | classname = m.__class__.__name__ 56 | if classname.find('ReLU') != -1: 57 | m.inplace = True 58 | 59 | 60 | def test(model, loader, num_class=40): 61 | mean_correct = [] 62 | class_acc = np.zeros((num_class, 3)) 63 | classifier = model.eval() 64 | 65 | for j, (points, target) in tqdm(enumerate(loader), total=len(loader)): 66 | 67 | if not args.use_cpu: 68 | points, target = points.cuda(), target.cuda() 69 | 70 | points = points.transpose(2, 1) 71 | pred, _ = classifier(points) 72 | pred_choice = pred.data.max(1)[1] 73 | 74 | for cat in np.unique(target.cpu()): 75 | classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum() 76 | class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0]) 77 | class_acc[cat, 1] += 1 78 | 79 | correct = pred_choice.eq(target.long().data).cpu().sum() 80 | mean_correct.append(correct.item() / float(points.size()[0])) 81 | 82 | class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1] 83 | class_acc = np.mean(class_acc[:, 2]) 84 | instance_acc = np.mean(mean_correct) 85 | 86 | return instance_acc, class_acc 87 | 88 | 89 | def main(args): 90 | def log_string(str): 91 | logger.info(str) 92 | print(str) 93 |
94 | '''HYPER PARAMETER''' 95 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 96 | 97 | '''CREATE DIR''' 98 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 99 | exp_dir = Path('./log/') 100 | exp_dir.mkdir(exist_ok=True) 101 | exp_dir = exp_dir.joinpath('classification') 102 | exp_dir.mkdir(exist_ok=True) 103 | if args.log_dir is None: 104 | exp_dir = exp_dir.joinpath(timestr) 105 | else: 106 | exp_dir = exp_dir.joinpath(args.log_dir) 107 | exp_dir.mkdir(exist_ok=True) 108 | checkpoints_dir = exp_dir.joinpath('checkpoints/') 109 | checkpoints_dir.mkdir(exist_ok=True) 110 | log_dir = exp_dir.joinpath('logs/') 111 | log_dir.mkdir(exist_ok=True) 112 | 113 | writer_dir = exp_dir.joinpath('SummaryWriter/') 114 | writer_dir.mkdir(exist_ok=True) 115 | writer = SummaryWriter(str(writer_dir)) 116 | 117 | '''LOG''' 118 | args = parse_args() 119 | logger = logging.getLogger("Model") 120 | logger.setLevel(logging.INFO) 121 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 122 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 123 | file_handler.setLevel(logging.INFO) 124 | file_handler.setFormatter(formatter) 125 | logger.addHandler(file_handler) 126 | log_string('PARAMETER ...') 127 | log_string(args) 128 | 129 | '''DATA LOADING''' 130 | log_string('Load dataset ...') 131 | data_path = 'data/modelnet40_normal_resampled/' 132 | 133 | train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data) 134 | test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data) 135 | trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 136 | num_workers=10, drop_last=True) 137 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, 138 | num_workers=10) 139 | 140 | '''MODEL LOADING''' 141 | num_class = args.num_category 142 | model =
importlib.import_module(args.model) 143 | shutil.copy('./models/%s.py' % args.model, str(exp_dir)) 144 | shutil.copy('models/utils.py', str(exp_dir)) 145 | shutil.copy('./train.py', str(exp_dir)) 146 | 147 | shutil.copy('models/SCNet.py', str(exp_dir)) 148 | shutil.copy('models/z_order.py', str(exp_dir)) 149 | 150 | classifier = model.get_model(num_class, normal_channel=args.use_normals) 151 | criterion = model.get_loss() 152 | classifier.apply(inplace_relu) 153 | 154 | if not args.use_cpu: 155 | classifier = classifier.cuda() 156 | criterion = criterion.cuda() 157 | 158 | try: 159 | checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth', map_location='cpu' if args.use_cpu else None) 160 | start_epoch = checkpoint['epoch'] 161 | classifier.load_state_dict(checkpoint['model_state_dict']) 162 | log_string('Use pretrained model') 163 | except Exception: # no usable checkpoint; do not swallow KeyboardInterrupt/SystemExit 164 | log_string('No existing model, starting training from scratch...') 165 | start_epoch = 0 166 | 167 | if args.optimizer == 'Adam': 168 | optimizer = torch.optim.Adam( 169 | classifier.parameters(), 170 | lr=args.learning_rate, 171 | betas=(0.9, 0.999), 172 | eps=1e-08, 173 | weight_decay=args.decay_rate 174 | ) 175 | else: 176 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 177 | 178 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7) 179 | global_epoch = 0 180 | global_step = 0 181 | best_instance_acc = 0.0 182 | best_class_acc = 0.0 183 | 184 | '''TRAINING''' 185 | logger.info('Start training...') 186 | for epoch in range(start_epoch, args.epoch): 187 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 188 | mean_correct = [] 189 | classifier = classifier.train() 190 | total_loss = 0 191 | scheduler.step() 192 | for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), 193 | smoothing=0.9): 194 | optimizer.zero_grad() 195 | 196 | points = points.data.numpy() 197 | points = provider.random_point_dropout(points) 198 | points[:, :,
0:3] = provider.random_scale_point_cloud(points[:, :, 0:3]) 199 | points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3]) 200 | points = torch.Tensor(points) 201 | points = points.transpose(2, 1) 202 | 203 | if not args.use_cpu: 204 | points, target = points.cuda(), target.cuda() 205 | # flops, params = profile(classifier, (points,)) 206 | # flops, params = clever_format([flops, params], "%.3f") 207 | # print(flops, params) 208 | pred, trans_feat = classifier(points) 209 | loss = criterion(pred, target.long(), trans_feat) 210 | total_loss += loss.item() 211 | 212 | pred_choice = pred.data.max(1)[1] 213 | 214 | correct = pred_choice.eq(target.long().data).cpu().sum() 215 | mean_correct.append(correct.item() / float(points.size()[0])) 216 | loss.backward() 217 | optimizer.step() 218 | global_step += 1 219 | 220 | train_instance_acc = np.mean(mean_correct) 221 | log_string('Train Instance Accuracy: %f' % train_instance_acc) 222 | log_string('Total loss: %f' % total_loss) 223 | writer.add_scalar('Accuracy/Train Instance Accuracy', train_instance_acc, epoch + 1) 224 | writer.add_scalar('Total loss', total_loss, epoch + 1) 225 | with torch.no_grad(): 226 | instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class) 227 | 228 | if (instance_acc >= best_instance_acc): 229 | best_instance_acc = instance_acc 230 | best_epoch = epoch + 1 231 | 232 | if (class_acc >= best_class_acc): 233 | best_class_acc = class_acc 234 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc)) 235 | writer.add_scalar('Accuracy/Test Instance Accuracy', instance_acc, epoch + 1) 236 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc)) 237 | 238 | if (instance_acc >= best_instance_acc): 239 | logger.info('Save model...') 240 | savepath = str(checkpoints_dir) + '/best_model.pth' 241 | log_string('Saving at %s' % savepath) 242 | state = { 243 | 'epoch': best_epoch, 244 | 
'instance_acc': instance_acc, 245 | 'class_acc': class_acc, 246 | 'model_state_dict': classifier.state_dict(), 247 | 'optimizer_state_dict': optimizer.state_dict(), 248 | } 249 | torch.save(state, savepath) 250 | global_epoch += 1 251 | 252 | logger.info('End of training...') 253 | 254 | 255 | if __name__ == '__main__': 256 | args = parse_args() 257 | main(args) 258 | --------------------------------------------------------------------------------
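The augmentation helpers in `provider.py` and the per-batch augmentation in `train.py` can be sanity-checked without a GPU or the ModelNet data. The following is a minimal NumPy sketch, not code from this repository: `perturbation_matrix`, `point_dropout`, and `scale_and_shift` are hypothetical re-implementations of the Rz·Ry·Rx composition in `rotate_perturbation_point_cloud_with_normal`, of `random_point_dropout`, and of `random_scale_point_cloud` followed by `shift_point_cloud`. The seeded `np.random.default_rng` generator and the random input cloud are assumptions made for reproducibility. `train.py` itself only applies dropout, scaling, and shifting; the small rotation is included here only to check that multiplying XYZ and normals by the same orthonormal matrix keeps the normals unit-length.

```python
import numpy as np


def perturbation_matrix(angles):
    """Hypothetical helper: compose Rz @ Ry @ Rx, the same order used in provider.py."""
    ax, ay, az = angles
    rx = np.array([[1, 0, 0],
                   [0, np.cos(ax), -np.sin(ax)],
                   [0, np.sin(ax), np.cos(ax)]])
    ry = np.array([[np.cos(ay), 0, np.sin(ay)],
                   [0, 1, 0],
                   [-np.sin(ay), 0, np.cos(ay)]])
    rz = np.array([[np.cos(az), -np.sin(az), 0],
                   [np.sin(az), np.cos(az), 0],
                   [0, 0, 1]])
    return rz @ ry @ rx


def point_dropout(batch, rng, max_ratio=0.875):
    """Hypothetical re-implementation of random_point_dropout: dropped points copy point 0,
    so the number of points N stays fixed. Works on a copy instead of in place."""
    out = batch.copy()
    for b in range(out.shape[0]):
        ratio = rng.random() * max_ratio
        drop = np.where(rng.random(out.shape[1]) <= ratio)[0]
        if len(drop) > 0:
            out[b, drop, :] = out[b, 0, :]
    return out


def scale_and_shift(xyz, rng, low=0.8, high=1.25, shift_range=0.1):
    """Hypothetical helper: one uniform scale per cloud, then one translation per cloud."""
    scales = rng.uniform(low, high, (xyz.shape[0], 1, 1))
    shifts = rng.uniform(-shift_range, shift_range, (xyz.shape[0], 1, 3))
    return xyz * scales + shifts


rng = np.random.default_rng(0)
B, N = 2, 1024
xyz = rng.standard_normal((B, N, 3))
normals = rng.standard_normal((B, N, 3))
normals /= np.linalg.norm(normals, axis=-1, keepdims=True)  # make normals unit length
batch = np.concatenate([xyz, normals], axis=-1)  # B x N x 6, the --use_normals layout

# Extra step (not in train.py): one small random rotation applied to XYZ and normals alike.
angles = np.clip(0.06 * rng.standard_normal(3), -0.18, 0.18)
R = perturbation_matrix(angles)
rotated = batch.copy()
rotated[..., 0:3] = batch[..., 0:3] @ R
rotated[..., 3:6] = batch[..., 3:6] @ R

# Order used in train.py: dropout on all channels, then scale and shift on XYZ only.
augmented = point_dropout(rotated, rng)
augmented[..., 0:3] = scale_and_shift(augmented[..., 0:3], rng)

print(augmented.shape)
print(np.allclose(R @ R.T, np.eye(3)))
print(np.allclose(np.linalg.norm(augmented[..., 3:6], axis=-1), 1.0))
```

Because `provider.py` multiplies row vectors on the right (`np.dot(points, R)`), each point is effectively rotated by the transpose of `R`; that is still a rotation, so the length checks hold either way. Scaling and shifting touch only channels 0:3, which is why the normals pass through the training-time augmentation unchanged apart from dropout duplicating the first point.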