├── .gitattributes ├── .gitignore ├── .vscode └── settings.json ├── README.md ├── data_utils ├── ModelNetDataLoader.py ├── S3DISDataLoader.py ├── ShapeNetDataLoader.py ├── __init__.py ├── collect_indoor3d_data.py ├── indoor3d_util.py └── meta │ ├── anno_paths.txt │ └── class_names.txt ├── image ├── plane1.png ├── plane2.png └── psn.png ├── models ├── PointSamplingNet.py ├── __init__.py ├── pointnet2_cls_ssg_psn.py ├── pointnet2_part_seg_ssg_psn.py ├── pointnet2_sem_seg_psn.py └── pointnet_util_psn.py ├── provider.py ├── test_cls.py ├── test_partseg.py ├── test_semseg.py ├── train_cls.py ├── train_partseg.py ├── train_semseg.py └── visualizer ├── build.sh ├── eulerangles.py ├── pc_utils.py ├── plyfile.py ├── render_balls_so.cpp └── show3d_balls.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point Sampling Net: Fast Subsampling and Local Grouping for Deep Learning on Point Cloud 2 | ### (Anonymous Currently) 3 | 4 | This repository is the implementation for our paper :
5 | *Point Sampling Net: Fast Subsampling and Local Grouping for Deep Learning on Point Cloud* 6 | 7 | ## Introduction 8 | ![Architecture of Point Sampling Net](https://github.com/psn-anonymous/PointSamplingNet/blob/master/image/psn.png "Architecture of Point Sampling Net")
9 | **Point Sampling Net** is a differentiable method for fast sampling and grouping in deep learning on point clouds, and it can be applied to mainstream point cloud deep learning models. Point Sampling Net performs the grouping and sampling tasks at the same time. Because it does not use the relationships between points as a grouping reference, its inference speed is independent of the number of points and it is friendly to parallel implementation, which effectively reduces the time spent on sampling and grouping.
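The paragraph above describes the idea only at a high level. The following is a hypothetical NumPy illustration of how score-based sampling and grouping *could* work without pairwise distances. It is not the code in `models/PointSamplingNet.py`. It assumes a shared per-point MLP whose widths are `[3] + mlp + [num_to_sample]`, matching the description of the *mlp* attribute in the Usage section. Each output channel is treated as one sample slot, the top `max_local_num` points of each slot form its local group, and the single top point is the sampled point. The names `mlp_scores` and `sample_and_group` are made up for this example. The selection step (`argsort`) is not differentiable, so the sketch leaves out how PSN passes gradients through sampling.

```python
import numpy as np


def mlp_scores(xyz, weights):
    """Hypothetical shared MLP: scores every point from its own (x, y, z) only.

    xyz: [B, N, 3]; weights: list of (W, b) with widths [3, *mlp, num_to_sample].
    Returns scores of shape [B, N, num_to_sample].
    """
    h = xyz
    for i, (w, b) in enumerate(weights):
        h = h @ w + b
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)  # ReLU on hidden layers only
    return h


def sample_and_group(xyz, feature, weights, max_local_num):
    """Hypothetical score-based sampling and grouping (not differentiable as written)."""
    scores = np.transpose(mlp_scores(xyz, weights), (0, 2, 1))  # [B, S, N]: one row per sample slot
    # Grouping: the max_local_num highest-scoring points of each slot form its local group.
    group_idx = np.argsort(-scores, axis=-1)[..., :max_local_num]  # [B, S, K]
    # Sampling: the highest-scoring point of each slot is the sampled point.
    sample_idx = group_idx[..., 0]  # [B, S]
    batch = np.arange(xyz.shape[0])[:, None]  # [B, 1], broadcasts against the index arrays
    sampled_points = xyz[batch, sample_idx]  # [B, S, 3]
    grouped_points = xyz[batch[:, :, None], group_idx]  # [B, S, K, 3]
    sampled_feature = feature[batch, sample_idx]  # [B, S, D]
    grouped_feature = feature[batch[:, :, None], group_idx]  # [B, S, K, D]
    return sampled_points, grouped_points, sampled_feature, grouped_feature


# Usage with the README's example sizes: num_to_sample=512, max_local_num=32, mlp=[32, 256].
rng = np.random.default_rng(0)
B, N, D, S, K = 2, 1024, 6, 512, 32
widths = [3, 32, 256, S]
weights = [(rng.standard_normal((i, o)) * 0.1, np.zeros(o)) for i, o in zip(widths[:-1], widths[1:])]
xyz = rng.standard_normal((B, N, 3))
feature = rng.standard_normal((B, N, D))
outputs = sample_and_group(xyz, feature, weights, K)
print([o.shape for o in outputs])
# [(2, 512, 3), (2, 512, 32, 3), (2, 512, 6), (2, 512, 32, 6)]
```

Unlike farthest point sampling followed by a ball query, this sketch computes no point-to-point distances. Each point is scored independently, which is what makes this style of approach easy to parallelize.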
10 | Point Sampling Net has been tested on PointNet++ [[1](#Reference)], PointConv [[2](#Reference)], RS-CNN [[3](#Reference)], and GAC [[4](#Reference)]. It has no obvious adverse effect on these models in classification, part segmentation, and scene segmentation tasks, while their training and inference speed is significantly improved. 11 | 12 | # Usage 13 | The [**CORE FILE**](https://github.com/psn-anonymous/PointSamplingNet/blob/master/models/PointSamplingNet.py) of Point Sampling Net: [models/PointSamplingNet.py](https://github.com/psn-anonymous/PointSamplingNet/blob/master/models/PointSamplingNet.py) 14 | 15 | ## Software Dependencies 16 | Python 3.7 or newer
17 | PyTorch 1.5 or newer
18 | NVIDIA® CUDA® Toolkit 9.2 or newer
19 | NVIDIA® CUDA® Deep Neural Network library (cuDNN) 7.2 or newer
20 |
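You can run a small helper like the one below to check a local setup against these minimums. It is not part of this repository, and the function names (`version_tuple`, `at_least`, `check_environment`) are hypothetical. It relies only on `sys.version_info`, `torch.__version__`, `torch.version.cuda` and `torch.backends.cudnn.version()`. The cuDNN value is reported raw rather than compared.

```python
import sys


def version_tuple(text):
    """Leading numeric components of a version string, e.g. '1.7.0+cu102' -> (1, 7, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def at_least(found, minimum):
    """True if version string `found` is >= the tuple `minimum`; missing parts count as 0."""
    got = version_tuple(found)
    width = max(len(got), len(minimum))
    got += (0,) * (width - len(got))
    minimum = tuple(minimum) + (0,) * (width - len(minimum))
    return got >= minimum


def check_environment():
    """Compare the local setup with the README minimums (Python 3.7, PyTorch 1.5, CUDA 9.2)."""
    report = {"python": tuple(sys.version_info[:2]) >= (3, 7)}
    try:
        import torch
    except ImportError:
        report["torch"] = False  # PyTorch is not installed
        return report
    report["torch"] = at_least(torch.__version__, (1, 5))
    # torch.version.cuda is None for CPU-only builds of PyTorch.
    report["cuda"] = torch.version.cuda is not None and at_least(torch.version.cuda, (9, 2))
    # cuDNN comes back as an integer (7605 means 7.6.5) or None; it is reported raw here.
    report["cudnn"] = torch.backends.cudnn.version()
    return report


if __name__ == "__main__":
    print(check_environment())
```

The helper compares versions as integer tuples rather than strings, so `1.10.0` correctly counts as newer than `1.5`.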
21 | You can easily install the software dependencies with **conda**: 22 | ```shell 23 | conda install pytorch cudatoolkit cudnn -c pytorch 24 | ``` 25 | 26 | ## Import Point Sampling Net PyTorch Module 27 | With `models/` on your Python path, you can import the PSN PyTorch module by: 28 | ```python 29 | import PointSamplingNet as psn 30 | ``` 31 | ## Native PSN 32 | ### Defining 33 | ```python 34 | psn_layer = psn.PointSamplingNet(num_to_sample = 512, max_local_num = 32, mlp = [32, 256]) 35 | ``` 36 | Attribute *mlp* lists only the middle (hidden) channels of PSN, because the channels of the first and last layers are fixed to 3 (the coordinates) and the sampling number (*num_to_sample*), respectively. 37 | ### Forward Propagation 38 | ```python 39 | sampled_points, grouped_points, sampled_feature, grouped_feature = psn_layer(coordinate = {coordinates of point cloud}, feature = {feature of point cloud}) 40 | ``` 41 | *sampled_points* is the sampled points, *grouped_points* is the grouped points.
42 | *sampled_feature* is the sampled feature, *grouped_feature* is the grouped feature.
43 | *{coordinates of point cloud}* is a torch.Tensor object, its shape is [*batch size*, *number of points*, *3*]
44 | *{feature of point cloud}* is a torch.Tensor object, its shape is [*batch size*, *number of points*, *D*], where *D* is the number of feature channels. 45 | 46 | ## PSN with Multi-Scale Grouping 47 | ### Defining 48 | ```python 49 | psn_msg_layer = psn.PointSamplingNetMSG(num_to_sample = 512, msg_n = [32, 64], mlp = [32, 256]) 50 | ``` 51 | Attribute *msg_n* is the list of local group sizes *n*, one per scale. 52 | ### Forward Propagation 53 | ```python 54 | sampled_points, grouped_points_msg, sampled_feature, grouped_feature_msg = psn_msg_layer(coordinate = {coordinates of point cloud}, feature = {feature of point cloud}) 55 | ``` 56 | *sampled_points* is the sampled points, *grouped_points_msg* is the list of multi-scale grouped points.
57 | *sampled_feature* is the sampled feature, *grouped_feature_msg* is the list of multi-scale grouped features. 58 | 59 | 60 | # Visualization 61 | ### Sampling 62 | ![Visualization of Sampling](https://github.com/psn-anonymous/PointSamplingNet/blob/master/image/plane1.png "Visualization of Sampling") 63 | ### Grouping 64 | ![Visualization of Grouping](https://github.com/psn-anonymous/PointSamplingNet/blob/master/image/plane2.png "Visualization of Grouping") 65 | 66 | # The Experiment on Deep Learning Networks 67 | The following experiment applies PSN to PointNet++. 68 | ## Environments 69 | This experiment has been tested in the following environments: 70 | ### Software 71 | Canonical Ubuntu 20.04.1 LTS / Microsoft Windows 10 Pro
72 | Python 3.8.5
73 | PyTorch 1.7.0
74 | NVIDIA® CUDA® Toolkit 10.2.89
75 | NVIDIA® CUDA® Deep Neural Network library (cuDNN) 7.6.5
76 | 77 | ### Hardware 78 | Intel® Core™ i9-9900K Processor (16M Cache, up to 5.00 GHz)
79 | 64GB DDR4 RAM
80 | NVIDIA® TITAN RTX™ 81 | 82 | ## Classification 83 | ### Data Preparation 84 | Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `data/modelnet40_normal_resampled/`. 85 | 86 | ### Run 87 | ```shell 88 | python train_cls.py --log_dir [your log dir] 89 | ``` 90 | 91 | ## Part Segmentation 92 | ### Data Preparation 93 | Download the aligned **ShapeNet** dataset [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and save it in `data/shapenetcore_partanno_segmentation_benchmark_v0_normal/`. 94 | ### Run 95 | ```shell 96 | python train_partseg.py --normal --log_dir [your log dir] 97 | ``` 98 | 99 | ## Semantic Segmentation 100 | ### Data Preparation 101 | Download the 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save it in `data/Stanford3dDataset_v1.2_Aligned_Version/`. Then preprocess it: 102 | ```shell 103 | cd data_utils 104 | python collect_indoor3d_data.py 105 | ``` 106 | The processed data will be saved in `data/stanford_indoor3d/`. 107 | ### Run 108 | ```shell 109 | python train_semseg.py --log_dir [your log dir] 110 | python test_semseg.py --log_dir [your log dir] --test_area 5 --visual 111 | ``` 112 | 113 | ## Experiment Reference 114 | The implementation of this experiment heavily references [yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)
115 | Thanks very much! 116 | 117 | # Reference 118 | [1] Qi, Charles Ruizhongtai, et al. “[PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-se).” *Advances in Neural Information Processing Systems*, 2017, pp. 5099–5108. [[PDF](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf)]
119 | [2] Wu, Wenxuan, et al. “[PointConv: Deep Convolutional Networks on 3D Point Clouds](http://openaccess.thecvf.com/content_CVPR_2019/html/Wu_PointConv_Deep_Convolutional_Networks_on_3D_Point_Clouds_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 9621–9630. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_PointConv_Deep_Convolutional_Networks_on_3D_Point_Clouds_CVPR_2019_paper.pdf)]
120 | [3] Liu, Yongcheng, et al. “[Relation-Shape Convolutional Neural Network for Point Cloud Analysis](http://openaccess.thecvf.com/content_CVPR_2019/html/Liu_Relation-Shape_Convolutional_Neural_Network_for_Point_Cloud_Analysis_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 8895–8904. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Relation-Shape_Convolutional_Neural_Network_for_Point_Cloud_Analysis_CVPR_2019_paper.pdf)]
121 | [4] Wang, Lei, et al. “[Graph Attention Convolution for Point Cloud Semantic Segmentation](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Graph_Attention_Convolution_for_Point_Cloud_Semantic_Segmentation_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 10296–10305. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Graph_Attention_Convolution_for_Point_Cloud_Semantic_Segmentation_CVPR_2019_paper.pdf)] -------------------------------------------------------------------------------- /data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | import os 4 | from torch.utils.data import Dataset 5 | warnings.filterwarnings('ignore') 6 | 7 | 8 | 9 | def pc_normalize(pc): 10 | centroid = np.mean(pc, axis=0) 11 | pc = pc - centroid 12 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 13 | pc = pc / m 14 | return pc 15 | 16 | def farthest_point_sample(point, npoint): 17 | """ 18 | Input: 19 | point: pointcloud data, [N, D] 20 | npoint: number of samples 21 | Return: 22 | point: sampled pointcloud data, [npoint, D] 23 | """ 24 | N, D = point.shape 25 | xyz = point[:,:3] 26 | centroids = np.zeros((npoint,)) 27 | distance = np.ones((N,)) * 1e10 28 | farthest = np.random.randint(0, N) 29 | for i in range(npoint): 30 | centroids[i] = farthest 31 | centroid = xyz[farthest, :] 32 | dist = np.sum((xyz - centroid) ** 2, -1) 33 | mask = dist < distance 34 | distance[mask] = dist[mask] 35 | farthest = np.argmax(distance, -1) 36 | point = point[centroids.astype(np.int32)] 37 | return point 38 | 39 | class ModelNetDataLoader(Dataset): 40 | def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000): 41 | self.root = root 42 | self.npoints = npoint 43 | self.uniform = uniform 44 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
45 | 46 | self.cat = [line.rstrip() for line in open(self.catfile)] 47 | self.classes = dict(zip(self.cat, range(len(self.cat)))) 48 | self.normal_channel = normal_channel 49 | 50 | shape_ids = {} 51 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] 52 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 53 | 54 | assert (split == 'train' or split == 'test') 55 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 56 | # list of (shape_name, shape_txt_file_path) tuple 57 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 58 | in range(len(shape_ids[split]))] 59 | print('The size of %s data is %d'%(split,len(self.datapath))) 60 | 61 | self.cache_size = cache_size # how many data points to cache in memory 62 | self.cache = {} # from index to (point_set, cls) tuple 63 | 64 | def __len__(self): 65 | return len(self.datapath) 66 | 67 | def _get_item(self, index): 68 | if index in self.cache: 69 | point_set, cls = self.cache[index] 70 | else: 71 | fn = self.datapath[index] 72 | cls = self.classes[self.datapath[index][0]] 73 | cls = np.array([cls]).astype(np.int32) 74 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 75 | if self.uniform: 76 | point_set = farthest_point_sample(point_set, self.npoints) 77 | else: 78 | point_set = point_set[0:self.npoints,:] 79 | 80 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 81 | 82 | if not self.normal_channel: 83 | point_set = point_set[:, 0:3] 84 | 85 | if len(self.cache) < self.cache_size: 86 | self.cache[index] = (point_set, cls) 87 | 88 | return point_set, cls 89 | 90 | def __getitem__(self, index): 91 | return self._get_item(index) 92 | 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | import torch 98 | 99 | data = ModelNetDataLoader('/data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,) 
100 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) 101 | for point,label in DataLoader: 102 | print(point.shape) 103 | print(label.shape) -------------------------------------------------------------------------------- /data_utils/S3DISDataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class S3DISDataset(Dataset): 7 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None): 8 | super().__init__() 9 | self.num_point = num_point 10 | self.block_size = block_size 11 | self.transform = transform 12 | rooms = sorted(os.listdir(data_root)) 13 | rooms = [room for room in rooms if 'Area_' in room] 14 | if split == 'train': 15 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room] 16 | else: 17 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room] 18 | self.room_points, self.room_labels = [], [] 19 | self.room_coord_min, self.room_coord_max = [], [] 20 | num_point_all = [] 21 | labelweights = np.zeros(13) 22 | for room_name in rooms_split: 23 | room_path = os.path.join(data_root, room_name) 24 | room_data = np.load(room_path) # xyzrgbl, N*7 25 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N 26 | tmp, _ = np.histogram(labels, range(14)) 27 | labelweights += tmp 28 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 29 | self.room_points.append(points), self.room_labels.append(labels) 30 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 31 | num_point_all.append(labels.size) 32 | labelweights = labelweights.astype(np.float32) 33 | labelweights = labelweights / np.sum(labelweights) 34 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 35 | print(self.labelweights) 
36 | sample_prob = num_point_all / np.sum(num_point_all) 37 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point) 38 | room_idxs = [] 39 | for index in range(len(rooms_split)): 40 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter))) 41 | self.room_idxs = np.array(room_idxs) 42 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split)) 43 | 44 | def __getitem__(self, idx): 45 | room_idx = self.room_idxs[idx] 46 | points = self.room_points[room_idx] # N * 6 47 | labels = self.room_labels[room_idx] # N 48 | N_points = points.shape[0] 49 | 50 | while (True): 51 | center = points[np.random.choice(N_points)][:3] 52 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0] 53 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0] 54 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0] 55 | if point_idxs.size > 1024: 56 | break 57 | 58 | if point_idxs.size >= self.num_point: 59 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False) 60 | else: 61 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True) 62 | 63 | # normalize 64 | selected_points = points[selected_point_idxs, :] # num_point * 6 65 | current_points = np.zeros((self.num_point, 9)) # num_point * 9 66 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0] 67 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1] 68 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2] 69 | selected_points[:, 0] = selected_points[:, 0] - center[0] 70 | selected_points[:, 1] = selected_points[:, 1] - center[1] 71 | selected_points[:, 3:6] /= 255.0 72 | current_points[:, 0:6] = selected_points 73 | current_labels = labels[selected_point_idxs] 74 | if self.transform is not None: 75 | current_points, current_labels = 
self.transform(current_points, current_labels) 76 | return current_points, current_labels 77 | 78 | def __len__(self): 79 | return len(self.room_idxs) 80 | 81 | class ScannetDatasetWholeScene(): 82 | # prepare to give prediction on each point 83 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001): 84 | self.block_points = block_points 85 | self.block_size = block_size 86 | self.padding = padding 87 | self.root = root 88 | self.split = split 89 | self.stride = stride 90 | self.scene_points_num = [] 91 | assert split in ['train', 'test'] 92 | if self.split == 'train': 93 | self.file_list = [d for d in os.listdir(root) if 'Area_%d' % test_area not in d] 94 | else: 95 | self.file_list = [d for d in os.listdir(root) if 'Area_%d' % test_area in d] 96 | self.scene_points_list = [] 97 | self.semantic_labels_list = [] 98 | self.room_coord_min, self.room_coord_max = [], [] 99 | for file in self.file_list: 100 | data = np.load(os.path.join(root, file)) 101 | points = data[:, :3] 102 | self.scene_points_list.append(data[:, :6]) 103 | self.semantic_labels_list.append(data[:, 6]) 104 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 105 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 106 | assert len(self.scene_points_list) == len(self.semantic_labels_list) 107 | 108 | labelweights = np.zeros(13) 109 | for seg in self.semantic_labels_list: 110 | tmp, _ = np.histogram(seg, range(14)) 111 | self.scene_points_num.append(seg.shape[0]) 112 | labelweights += tmp 113 | labelweights = labelweights.astype(np.float32) 114 | labelweights = labelweights / np.sum(labelweights) 115 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 116 | 117 | def __getitem__(self, index): 118 | point_set_ini = self.scene_points_list[index] 119 | points = point_set_ini[:,:6] 120 | labels = self.semantic_labels_list[index] 121 | coord_min, coord_max =
np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 122 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1) 123 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1) 124 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([]) 125 | for index_y in range(0, grid_y): 126 | for index_x in range(0, grid_x): 127 | s_x = coord_min[0] + index_x * self.stride 128 | e_x = min(s_x + self.block_size, coord_max[0]) 129 | s_x = e_x - self.block_size 130 | s_y = coord_min[1] + index_y * self.stride 131 | e_y = min(s_y + self.block_size, coord_max[1]) 132 | s_y = e_y - self.block_size 133 | point_idxs = np.where( 134 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & ( 135 | points[:, 1] <= e_y + self.padding))[0] 136 | if point_idxs.size == 0: 137 | continue 138 | num_batch = int(np.ceil(point_idxs.size / self.block_points)) 139 | point_size = int(num_batch * self.block_points) 140 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True 141 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace) 142 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat)) 143 | np.random.shuffle(point_idxs) 144 | data_batch = points[point_idxs, :] 145 | normlized_xyz = np.zeros((point_size, 3)) 146 | normlized_xyz[:, 0] = data_batch[:, 0] / coord_max[0] 147 | normlized_xyz[:, 1] = data_batch[:, 1] / coord_max[1] 148 | normlized_xyz[:, 2] = data_batch[:, 2] / coord_max[2] 149 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0) 150 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0) 151 | data_batch[:, 3:6] /= 255.0 152 | data_batch = np.concatenate((data_batch, normlized_xyz), axis=1) 153 | label_batch = labels[point_idxs].astype(int) 154 | batch_weight = 
self.labelweights[label_batch] 155 | 156 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch 157 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch 158 | sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight 159 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs 160 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1])) 161 | label_room = label_room.reshape((-1, self.block_points)) 162 | sample_weight = sample_weight.reshape((-1, self.block_points)) 163 | index_room = index_room.reshape((-1, self.block_points)) 164 | return data_room, label_room, sample_weight, index_room 165 | 166 | def __len__(self): 167 | return len(self.scene_points_list) 168 | 169 | if __name__ == '__main__': 170 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/' 171 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01 172 | 173 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None) 174 | print('point data size:', point_data.__len__()) 175 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape) 176 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape) 177 | import torch, time, random 178 | manual_seed = 123 179 | random.seed(manual_seed) 180 | np.random.seed(manual_seed) 181 | torch.manual_seed(manual_seed) 182 | torch.cuda.manual_seed_all(manual_seed) 183 | def worker_init_fn(worker_id): 184 | random.seed(manual_seed + worker_id) 185 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn) 186 | for idx in range(4): 187 | end = time.time() 188 | for i, (input, target) in enumerate(train_loader): 189 | print('time: {}/{}--{}'.format(i+1, len(train_loader), time.time() -
end)) 190 | end = time.time() -------------------------------------------------------------------------------- /data_utils/ShapeNetDataLoader.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import json 4 | import warnings 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | warnings.filterwarnings('ignore') 8 | 9 | def pc_normalize(pc): 10 | centroid = np.mean(pc, axis=0) 11 | pc = pc - centroid 12 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 13 | pc = pc / m 14 | return pc 15 | 16 | class PartNormalDataset(Dataset): 17 | def __init__(self,root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False): 18 | self.npoints = npoints 19 | self.root = root 20 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 21 | self.cat = {} 22 | self.normal_channel = normal_channel 23 | 24 | 25 | with open(self.catfile, 'r') as f: 26 | for line in f: 27 | ls = line.strip().split() 28 | self.cat[ls[0]] = ls[1] 29 | self.cat = {k: v for k, v in self.cat.items()} 30 | self.classes_original = dict(zip(self.cat, range(len(self.cat)))) 31 | 32 | if not class_choice is None: 33 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 34 | # print(self.cat) 35 | 36 | self.meta = {} 37 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: 38 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 39 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: 40 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 41 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: 42 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 43 | for item in self.cat: 44 | # print('category', item) 45 | self.meta[item] = [] 46 | dir_point = 
os.path.join(self.root, self.cat[item]) 47 | fns = sorted(os.listdir(dir_point)) 48 | # print(fns[0][0:-4]) 49 | if split == 'trainval': 50 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 51 | elif split == 'train': 52 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 53 | elif split == 'val': 54 | fns = [fn for fn in fns if fn[0:-4] in val_ids] 55 | elif split == 'test': 56 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 57 | else: 58 | print('Unknown split: %s. Exiting..' % (split)) 59 | exit(-1) 60 | 61 | # print(os.path.basename(fns)) 62 | for fn in fns: 63 | token = (os.path.splitext(os.path.basename(fn))[0]) 64 | self.meta[item].append(os.path.join(dir_point, token + '.txt')) 65 | 66 | self.datapath = [] 67 | for item in self.cat: 68 | for fn in self.meta[item]: 69 | self.datapath.append((item, fn)) 70 | 71 | self.classes = {} 72 | for i in self.cat.keys(): 73 | self.classes[i] = self.classes_original[i] 74 | 75 | # Mapping from category ('Chair') to a list of int [12,13,14,15] as segmentation labels 76 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 77 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 78 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 79 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 80 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 81 | 82 | # for cat in sorted(self.seg_classes.keys()): 83 | # print(cat, self.seg_classes[cat]) 84 | 85 | self.cache = {} # from index to (point_set, cls, seg) tuple 86 | self.cache_size = 20000 87 | 88 | 89 | def __getitem__(self, index): 90 | if index in self.cache: 91 | point_set, cls, seg = self.cache[index] 92 | else: 93 | fn = self.datapath[index] 94 | cat = self.datapath[index][0] 95 | cls = self.classes[cat] 96 | cls = np.array([cls]).astype(np.int32) 97 | data = np.loadtxt(fn[1]).astype(np.float32) 98
| if not self.normal_channel: 99 | point_set = data[:, 0:3] 100 | else: 101 | point_set = data[:, 0:6] 102 | seg = data[:, -1].astype(np.int32) 103 | if len(self.cache) < self.cache_size: 104 | self.cache[index] = (point_set, cls, seg) 105 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 106 | 107 | choice = np.random.choice(len(seg), self.npoints, replace=True) 108 | # resample 109 | point_set = point_set[choice, :] 110 | seg = seg[choice] 111 | 112 | return point_set, cls, seg 113 | 114 | def __len__(self): 115 | return len(self.datapath) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/data_utils/__init__.py -------------------------------------------------------------------------------- /data_utils/collect_indoor3d_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('./data_utils') 4 | from indoor3d_util import DATA_PATH, collect_point_label 5 | 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | ROOT_DIR = os.path.dirname(BASE_DIR) 8 | sys.path.append(BASE_DIR) 9 | 10 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))] 11 | anno_paths = [os.path.join(DATA_PATH, p) for p in anno_paths] 12 | 13 | output_folder = os.path.join(ROOT_DIR, 'data/stanford_indoor3d') 14 | if not os.path.exists(output_folder): 15 | os.mkdir(output_folder) 16 | 17 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually. 
18 | for anno_path in anno_paths: 19 | print(anno_path) 20 | try: 21 | elements = anno_path.split('/') 22 | out_filename = elements[-3]+'_'+elements[-2]+'.npy' # Area_1_hallway_1.npy 23 | collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy') 24 | except Exception as e: 25 | print(anno_path, 'ERROR!!', e) 26 | -------------------------------------------------------------------------------- /data_utils/indoor3d_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import sys 5 | 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | ROOT_DIR = os.path.dirname(BASE_DIR) 8 | sys.path.append(BASE_DIR) 9 | 10 | DATA_PATH = os.path.join(ROOT_DIR, 'data', 'Stanford3dDataset_v1.2_Aligned_Version') 11 | g_classes = [x.rstrip() for x in open(os.path.join(BASE_DIR, 'meta/class_names.txt'))] 12 | g_class2label = {cls: i for i,cls in enumerate(g_classes)} 13 | g_class2color = {'ceiling': [0,255,0], 14 | 'floor': [0,0,255], 15 | 'wall': [0,255,255], 16 | 'beam': [255,255,0], 17 | 'column': [255,0,255], 18 | 'window': [100,100,255], 19 | 'door': [200,200,100], 20 | 'table': [170,120,200], 21 | 'chair': [255,0,0], 22 | 'sofa': [200,100,100], 23 | 'bookcase': [10,200,100], 24 | 'board': [200,200,200], 25 | 'clutter': [50,50,50]} 26 | g_easy_view_labels = [7,8,9,10,11,1] 27 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes} 28 | 29 | 30 | # ----------------------------------------------------------------------------- 31 | # CONVERT ORIGINAL DATA TO OUR DATA_LABEL FILES 32 | # ----------------------------------------------------------------------------- 33 | 34 | def collect_point_label(anno_path, out_filename, file_format='txt'): 35 | """ Convert original dataset files to data_label file (each line is XYZRGBL). 36 | We aggregated all the points from each instance in the room. 37 | 38 | Args: 39 | anno_path: path to annotations. e.g.
Area_1/office_2/Annotations/ 40 | out_filename: path to save collected points and labels (each line is XYZRGBL) 41 | file_format: txt or numpy, determines what file format to save. 42 | Returns: 43 | None 44 | Note: 45 | the points are shifted before saving; the most negative point is now at the origin. 46 | """ 47 | points_list = [] 48 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 49 | cls = os.path.basename(f).split('_')[0] 50 | print(f) 51 | if cls not in g_classes: # note: some rooms contain a 'stairs' class, which is mapped to clutter 52 | cls = 'clutter' 53 | 54 | points = np.loadtxt(f) 55 | labels = np.ones((points.shape[0],1)) * g_class2label[cls] 56 | points_list.append(np.concatenate([points, labels], 1)) # Nx7 57 | 58 | data_label = np.concatenate(points_list, 0) 59 | xyz_min = np.amin(data_label, axis=0)[0:3] 60 | data_label[:, 0:3] -= xyz_min 61 | 62 | if file_format=='txt': 63 | fout = open(out_filename, 'w') 64 | for i in range(data_label.shape[0]): 65 | fout.write('%f %f %f %d %d %d %d\n' % \ 66 | (data_label[i,0], data_label[i,1], data_label[i,2], 67 | data_label[i,3], data_label[i,4], data_label[i,5], 68 | data_label[i,6])) 69 | fout.close() 70 | elif file_format=='numpy': 71 | np.save(out_filename, data_label) 72 | else: 73 | print('ERROR!! Unknown file format: %s, please use txt or numpy.'
% \ 74 | (file_format)) 75 | exit() 76 | 77 | def data_to_obj(data,name='example.obj',no_wall=True): 78 | fout = open(name, 'w') 79 | label = data[:, -1].astype(int) 80 | for i in range(data.shape[0]): 81 | if no_wall and ((label[i] == 2) or (label[i]==0)): 82 | continue 83 | fout.write('v %f %f %f %d %d %d\n' % \ 84 | (data[i, 0], data[i, 1], data[i, 2], data[i, 3], data[i, 4], data[i, 5])) 85 | fout.close() 86 | 87 | def point_label_to_obj(input_filename, out_filename, label_color=True, easy_view=False, no_wall=False): 88 | """ For visualization of a room from data_label file, 89 | input_filename: each line is X Y Z R G B L 90 | out_filename: OBJ filename, 91 | visualize input file by coloring point with label color 92 | easy_view: only visualize furniture and floor 93 | """ 94 | data_label = np.loadtxt(input_filename) 95 | data = data_label[:, 0:6] 96 | label = data_label[:, -1].astype(int) 97 | fout = open(out_filename, 'w') 98 | for i in range(data.shape[0]): 99 | color = g_label2color[label[i]] 100 | if easy_view and (label[i] not in g_easy_view_labels): 101 | continue 102 | if no_wall and ((label[i] == 2) or (label[i]==0)): 103 | continue 104 | if label_color: 105 | fout.write('v %f %f %f %d %d %d\n' % \ 106 | (data[i,0], data[i,1], data[i,2], color[0], color[1], color[2])) 107 | else: 108 | fout.write('v %f %f %f %d %d %d\n' % \ 109 | (data[i,0], data[i,1], data[i,2], data[i,3], data[i,4], data[i,5])) 110 | fout.close() 111 | 112 | 113 | 114 | # ----------------------------------------------------------------------------- 115 | # PREPARE BLOCK DATA FOR DEEPNETS TRAINING/TESTING 116 | # ----------------------------------------------------------------------------- 117 | 118 | def sample_data(data, num_sample): 119 | """ data is in N x ... 120 | we want to keep num_sample x C of them. 121 | if N > num_sample, we will randomly keep num_sample of them. 122 | if N < num_sample, we will randomly duplicate samples.
123 | """ 124 | N = data.shape[0] 125 | if (N == num_sample): 126 | return data, range(N) 127 | elif (N > num_sample): 128 | sample = np.random.choice(N, num_sample, replace=False) 129 | return data[sample, ...], sample 130 | else: 131 | sample = np.random.choice(N, num_sample-N) 132 | dup_data = data[sample, ...] 133 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample) 134 | 135 | def sample_data_label(data, label, num_sample): 136 | new_data, sample_indices = sample_data(data, num_sample) 137 | new_label = label[sample_indices] 138 | return new_data, new_label 139 | 140 | def room2blocks(data, label, num_point, block_size=1.0, stride=1.0, 141 | random_sample=False, sample_num=None, sample_aug=1): 142 | """ Prepare block training data. 143 | Args: 144 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 145 | assumes the data is shifted (min point is origin) and aligned 146 | (aligned with XYZ axis) 147 | label: N size uint8 numpy array from 0-12 148 | num_point: int, how many points to sample in each block 149 | block_size: float, physical size of the block in meters 150 | stride: float, stride for block sweeping 151 | random_sample: bool, if True, we will randomly sample blocks in the room 152 | sample_num: int, if random sample, how many blocks to sample 153 | [default: room area] 154 | sample_aug: if random sample, augmentation factor applied to the default sample_num 155 | Returns: 156 | block_datas: K x num_point x 6 np array of XYZRGB, RGB is in [0,1] 157 | block_labels: K x num_point x 1 np array of uint8 labels 158 | 159 | TODO: for this version, blocking is in fixed, non-overlapping pattern.
160 | """ 161 | assert(stride<=block_size) 162 | 163 | limit = np.amax(data, 0)[0:3] 164 | 165 | # Get the corner location for our sampling blocks 166 | xbeg_list = [] 167 | ybeg_list = [] 168 | if not random_sample: 169 | num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1 170 | num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1 171 | for i in range(num_block_x): 172 | for j in range(num_block_y): 173 | xbeg_list.append(i*stride) 174 | ybeg_list.append(j*stride) 175 | else: 176 | num_block_x = int(np.ceil(limit[0] / block_size)) 177 | num_block_y = int(np.ceil(limit[1] / block_size)) 178 | if sample_num is None: 179 | sample_num = num_block_x * num_block_y * sample_aug 180 | for _ in range(sample_num): 181 | xbeg = np.random.uniform(-block_size, limit[0]) 182 | ybeg = np.random.uniform(-block_size, limit[1]) 183 | xbeg_list.append(xbeg) 184 | ybeg_list.append(ybeg) 185 | 186 | # Collect blocks 187 | block_data_list = [] 188 | block_label_list = [] 189 | idx = 0 190 | for idx in range(len(xbeg_list)): 191 | xbeg = xbeg_list[idx] 192 | ybeg = ybeg_list[idx] 193 | xcond = (data[:,0]<=xbeg+block_size) & (data[:,0]>=xbeg) 194 | ycond = (data[:,1]<=ybeg+block_size) & (data[:,1]>=ybeg) 195 | cond = xcond & ycond 196 | if np.sum(cond) < 100: # discard block if it contains fewer than 100 points
197 | continue 198 | 199 | block_data = data[cond, :] 200 | block_label = label[cond] 201 | 202 | # randomly subsample data 203 | block_data_sampled, block_label_sampled = \ 204 | sample_data_label(block_data, block_label, num_point) 205 | block_data_list.append(np.expand_dims(block_data_sampled, 0)) 206 | block_label_list.append(np.expand_dims(block_label_sampled, 0)) 207 | 208 | return np.concatenate(block_data_list, 0), \ 209 | np.concatenate(block_label_list, 0) 210 | 211 | 212 | def room2blocks_plus(data_label, num_point, block_size, stride, 213 | random_sample, sample_num, sample_aug): 214 | """ room2block with input filename and RGB preprocessing. 215 | """ 216 | data = data_label[:,0:6] 217 | data[:,3:6] /= 255.0 218 | label = data_label[:,-1].astype(np.uint8) 219 | 220 | return room2blocks(data, label, num_point, block_size, stride, 221 | random_sample, sample_num, sample_aug) 222 | 223 | def room2blocks_wrapper(data_label_filename, num_point, block_size=1.0, stride=1.0, 224 | random_sample=False, sample_num=None, sample_aug=1): 225 | if data_label_filename[-3:] == 'txt': 226 | data_label = np.loadtxt(data_label_filename) 227 | elif data_label_filename[-3:] == 'npy': 228 | data_label = np.load(data_label_filename) 229 | else: 230 | print('Unknown file type! exiting.') 231 | exit() 232 | return room2blocks_plus(data_label, num_point, block_size, stride, 233 | random_sample, sample_num, sample_aug) 234 | 235 | def room2blocks_plus_normalized(data_label, num_point, block_size, stride, 236 | random_sample, sample_num, sample_aug): 237 | """ room2block, with input filename and RGB preprocessing. 
238 | for each block centralize XYZ, add normalized XYZ as 678 channels 239 | """ 240 | data = data_label[:,0:6] 241 | data[:,3:6] /= 255.0 242 | label = data_label[:,-1].astype(np.uint8) 243 | max_room_x = max(data[:,0]) 244 | max_room_y = max(data[:,1]) 245 | max_room_z = max(data[:,2]) 246 | 247 | data_batch, label_batch = room2blocks(data, label, num_point, block_size, stride, 248 | random_sample, sample_num, sample_aug) 249 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 250 | for b in range(data_batch.shape[0]): 251 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 252 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 253 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 254 | minx = min(data_batch[b, :, 0]) 255 | miny = min(data_batch[b, :, 1]) 256 | data_batch[b, :, 0] -= (minx+block_size/2) 257 | data_batch[b, :, 1] -= (miny+block_size/2) 258 | new_data_batch[:, :, 0:6] = data_batch 259 | return new_data_batch, label_batch 260 | 261 | 262 | def room2blocks_wrapper_normalized(data_label_filename, num_point, block_size=1.0, stride=1.0, 263 | random_sample=False, sample_num=None, sample_aug=1): 264 | if data_label_filename[-3:] == 'txt': 265 | data_label = np.loadtxt(data_label_filename) 266 | elif data_label_filename[-3:] == 'npy': 267 | data_label = np.load(data_label_filename) 268 | else: 269 | print('Unknown file type! exiting.') 270 | exit() 271 | return room2blocks_plus_normalized(data_label, num_point, block_size, stride, 272 | random_sample, sample_num, sample_aug) 273 | 274 | def room2samples(data, label, sample_num_point): 275 | """ Prepare whole room samples. 
276 | 277 | Args: 278 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1] 279 | assumes the data is shifted (min point is origin) and 280 | aligned (aligned with XYZ axis) 281 | label: N size uint8 numpy array from 0-12 282 | sample_num_point: int, how many points to sample in each sample 283 | Returns: 284 | sample_datas: K x sample_num_point x 6 285 | numpy array of XYZRGB, RGB is in [0,1] 286 | sample_labels: K x sample_num_point x 1 np array of uint8 labels 287 | """ 288 | N = data.shape[0] 289 | order = np.arange(N) 290 | np.random.shuffle(order) 291 | data = data[order, :] 292 | label = label[order] 293 | 294 | batch_num = int(np.ceil(N / float(sample_num_point))) 295 | sample_datas = np.zeros((batch_num, sample_num_point, 6)) 296 | sample_labels = np.zeros((batch_num, sample_num_point, 1)) 297 | 298 | for i in range(batch_num): 299 | beg_idx = i*sample_num_point 300 | end_idx = min((i+1)*sample_num_point, N) 301 | num = end_idx - beg_idx 302 | sample_datas[i,0:num,:] = data[beg_idx:end_idx, :] 303 | sample_labels[i,0:num,0] = label[beg_idx:end_idx] 304 | if num < sample_num_point: 305 | makeup_indices = np.random.choice(N, sample_num_point - num) 306 | sample_datas[i,num:,:] = data[makeup_indices, :] 307 | sample_labels[i,num:,0] = label[makeup_indices] 308 | return sample_datas, sample_labels 309 | 310 | def room2samples_plus_normalized(data_label, num_point): 311 | """ room2sample, with input filename and RGB preprocessing.
312 | for each block centralize XYZ, add normalized XYZ as 678 channels 313 | """ 314 | data = data_label[:,0:6] 315 | data[:,3:6] /= 255.0 316 | label = data_label[:,-1].astype(np.uint8) 317 | max_room_x = max(data[:,0]) 318 | max_room_y = max(data[:,1]) 319 | max_room_z = max(data[:,2]) 320 | #print(max_room_x, max_room_y, max_room_z) 321 | 322 | data_batch, label_batch = room2samples(data, label, num_point) 323 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9)) 324 | for b in range(data_batch.shape[0]): 325 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x 326 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y 327 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z 328 | #minx = min(data_batch[b, :, 0]) 329 | #miny = min(data_batch[b, :, 1]) 330 | #data_batch[b, :, 0] -= (minx+block_size/2) 331 | #data_batch[b, :, 1] -= (miny+block_size/2) 332 | new_data_batch[:, :, 0:6] = data_batch 333 | return new_data_batch, label_batch 334 | 335 | 336 | def room2samples_wrapper_normalized(data_label_filename, num_point): 337 | if data_label_filename[-3:] == 'txt': 338 | data_label = np.loadtxt(data_label_filename) 339 | elif data_label_filename[-3:] == 'npy': 340 | data_label = np.load(data_label_filename) 341 | else: 342 | print('Unknown file type! exiting.') 343 | exit() 344 | return room2samples_plus_normalized(data_label, num_point) 345 | 346 | 347 | # ----------------------------------------------------------------------------- 348 | # EXTRACT INSTANCE BBOX FROM ORIGINAL DATA (for detection evaluation) 349 | # ----------------------------------------------------------------------------- 350 | 351 | def collect_bounding_box(anno_path, out_filename): 352 | """ Compute bounding boxes from each instance in original dataset files on 353 | one room. **We assume the bbox is aligned with XYZ coordinate.** 354 | 355 | Args: 356 | anno_path: path to annotations. e.g. 
Area_1/office_2/Annotations/ 357 | out_filename: path to save instance bounding boxes for that room. 358 | each line is x1 y1 z1 x2 y2 z2 label, 359 | where (x1,y1,z1) is the point on the diagonal closer to origin 360 | Returns: 361 | None 362 | Note: 363 | room points are shifted, the most negative point is now at origin. 364 | """ 365 | bbox_label_list = [] 366 | 367 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 368 | cls = os.path.basename(f).split('_')[0] 369 | if cls not in g_classes: # note: some rooms contain a 'stairs' class, which is mapped to clutter 370 | cls = 'clutter' 371 | points = np.loadtxt(f) 372 | label = g_class2label[cls] 373 | # Compute tightest axis aligned bounding box 374 | xyz_min = np.amin(points[:, 0:3], axis=0) 375 | xyz_max = np.amax(points[:, 0:3], axis=0) 376 | ins_bbox_label = np.expand_dims( 377 | np.concatenate([xyz_min, xyz_max, np.array([label])], 0), 0) 378 | bbox_label_list.append(ins_bbox_label) 379 | 380 | bbox_label = np.concatenate(bbox_label_list, 0) 381 | room_xyz_min = np.amin(bbox_label[:, 0:3], axis=0) 382 | bbox_label[:, 0:3] -= room_xyz_min 383 | bbox_label[:, 3:6] -= room_xyz_min 384 | 385 | fout = open(out_filename, 'w') 386 | for i in range(bbox_label.shape[0]): 387 | fout.write('%f %f %f %f %f %f %d\n' % \ 388 | (bbox_label[i,0], bbox_label[i,1], bbox_label[i,2], 389 | bbox_label[i,3], bbox_label[i,4], bbox_label[i,5], 390 | bbox_label[i,6])) 391 | fout.close() 392 | 393 | def bbox_label_to_obj(input_filename, out_filename_prefix, easy_view=False): 394 | """ Visualization of bounding boxes.
395 | 396 | Args: 397 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 398 | out_filename_prefix: OBJ filename prefix, 399 | visualize object by g_label2color 400 | easy_view: if True, only visualize furniture and floor 401 | Returns: 402 | output a list of OBJ file and MTL files with the same prefix 403 | """ 404 | bbox_label = np.loadtxt(input_filename) 405 | bbox = bbox_label[:, 0:6] 406 | label = bbox_label[:, -1].astype(int) 407 | v_cnt = 0 # count vertex 408 | ins_cnt = 0 # count instance 409 | for i in range(bbox.shape[0]): 410 | if easy_view and (label[i] not in g_easy_view_labels): 411 | continue 412 | obj_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.obj' 413 | mtl_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.mtl' 414 | fout_obj = open(obj_filename, 'w') 415 | fout_mtl = open(mtl_filename, 'w') 416 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 417 | 418 | length = bbox[i, 3:6] - bbox[i, 0:3] 419 | a = length[0] 420 | b = length[1] 421 | c = length[2] 422 | x = bbox[i, 0] 423 | y = bbox[i, 1] 424 | z = bbox[i, 2] 425 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 426 | 427 | material = 'material%d' % (ins_cnt) 428 | fout_obj.write('usemtl %s\n' % (material)) 429 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 430 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 431 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 432 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 433 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 434 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 435 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 436 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 437 | fout_obj.write('g default\n') 438 | v_cnt = 0 # for individual box 439 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 440 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 441 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 
442 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 443 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 444 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 445 | fout_obj.write('\n') 446 | 447 | fout_mtl.write('newmtl %s\n' % (material)) 448 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 449 | fout_mtl.write('\n') 450 | fout_obj.close() 451 | fout_mtl.close() 452 | 453 | v_cnt += 8 454 | ins_cnt += 1 455 | 456 | def bbox_label_to_obj_room(input_filename, out_filename_prefix, easy_view=False, permute=None, center=False, exclude_table=False): 457 | """ Visualization of bounding boxes. 458 | 459 | Args: 460 | input_filename: each line is x1 y1 z1 x2 y2 z2 label 461 | out_filename_prefix: OBJ filename prefix, 462 | visualize object by g_label2color 463 | easy_view: if True, only visualize furniture and floor 464 | permute: if not None, permute XYZ for rendering, e.g. [0 2 1] 465 | center: if True, move obj to have zero origin 466 | Returns: 467 | output a list of OBJ file and MTL files with the same prefix 468 | """ 469 | bbox_label = np.loadtxt(input_filename) 470 | bbox = bbox_label[:, 0:6] 471 | if permute is not None: 472 | assert(len(permute)==3) 473 | permute = np.array(permute) 474 | bbox[:,0:3] = bbox[:,permute] 475 | bbox[:,3:6] = bbox[:,permute+3] 476 | if center: 477 | xyz_max = np.amax(bbox[:,3:6], 0) 478 | bbox[:,0:3] -= (xyz_max/2.0) 479 | bbox[:,3:6] -= (xyz_max/2.0) 480 | bbox /= np.max(xyz_max/2.0) 481 | label = bbox_label[:, -1].astype(int) 482 | obj_filename = out_filename_prefix+'.obj' 483 | mtl_filename = out_filename_prefix+'.mtl' 484 | 485 | fout_obj = open(obj_filename, 'w') 486 | fout_mtl = open(mtl_filename, 'w') 487 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename))) 488 | v_cnt = 0 # count vertex 489 | ins_cnt = 0 # count instance 490 | for i in range(bbox.shape[0]): 491 | if easy_view and (label[i] not in g_easy_view_labels): 
492 | continue 493 | if exclude_table and label[i] == g_classes.index('table'): 494 | continue 495 | 496 | length = bbox[i, 3:6] - bbox[i, 0:3] 497 | a = length[0] 498 | b = length[1] 499 | c = length[2] 500 | x = bbox[i, 0] 501 | y = bbox[i, 1] 502 | z = bbox[i, 2] 503 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0 504 | 505 | material = 'material%d' % (ins_cnt) 506 | fout_obj.write('usemtl %s\n' % (material)) 507 | fout_obj.write('v %f %f %f\n' % (x,y,z+c)) 508 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c)) 509 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c)) 510 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c)) 511 | fout_obj.write('v %f %f %f\n' % (x,y,z)) 512 | fout_obj.write('v %f %f %f\n' % (x,y+b,z)) 513 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z)) 514 | fout_obj.write('v %f %f %f\n' % (x+a,y,z)) 515 | fout_obj.write('g default\n') 516 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt)) 517 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt)) 518 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt)) 519 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt)) 520 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt)) 521 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt)) 522 | fout_obj.write('\n') 523 | 524 | fout_mtl.write('newmtl %s\n' % (material)) 525 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2])) 526 | fout_mtl.write('\n') 527 | 528 | v_cnt += 8 529 | ins_cnt += 1 530 | 531 | fout_obj.close() 532 | fout_mtl.close() 533 | 534 | 535 | def collect_point_bounding_box(anno_path, out_filename, file_format): 536 | """ Compute bounding boxes from each instance in original dataset files on 537 | one room. **We assume the bbox is aligned with XYZ coordinate.** 538 | Save both the point XYZRGB and the bounding box for the point's 539 | parent element. 
540 | 541 | Args: 542 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/ 543 | out_filename: path to save instance bounding boxes for each point, 544 | plus the point's XYZRGBL 545 | each line is XYZRGBL offsetX offsetY offsetZ a b c, 546 | where cx = X+offsetX, cy = Y+offsetY, cz = Z+offsetZ 547 | where (cx,cy,cz) is center of the box, a,b,c are distances from center 548 | to the surfaces of the box, i.e. x1 = cx-a, x2 = cx+a, y1=cy-b etc. 549 | file_format: output file format, txt or numpy 550 | Returns: 551 | None 552 | 553 | Note: 554 | room points are shifted, the most negative point is now at origin. 555 | """ 556 | point_bbox_list = [] 557 | 558 | for f in glob.glob(os.path.join(anno_path, '*.txt')): 559 | cls = os.path.basename(f).split('_')[0] 560 | if cls not in g_classes: # note: some rooms contain a 'stairs' class, which is mapped to clutter 561 | cls = 'clutter' 562 | points = np.loadtxt(f) # Nx6 563 | label = g_class2label[cls] # scalar class index 564 | # Compute tightest axis aligned bounding box 565 | xyz_min = np.amin(points[:, 0:3], axis=0) # 3, 566 | xyz_max = np.amax(points[:, 0:3], axis=0) # 3, 567 | xyz_center = (xyz_min + xyz_max) / 2 568 | dimension = (xyz_max - xyz_min) / 2 569 | 570 | xyz_offsets = xyz_center - points[:,0:3] # Nx3 571 | dimensions = np.ones((points.shape[0],3)) * dimension # Nx3 572 | labels = np.ones((points.shape[0],1)) * label # Nx1 573 | point_bbox_list.append(np.concatenate([points, labels, 574 | xyz_offsets, dimensions], 1)) # Nx13 575 | 576 | point_bbox = np.concatenate(point_bbox_list, 0) # (total points in room) x 13 577 | room_xyz_min = np.amin(point_bbox[:, 0:3], axis=0) 578 | point_bbox[:, 0:3] -= room_xyz_min 579 | 580 | if file_format == 'txt': 581 | fout = open(out_filename, 'w') 582 | for i in range(point_bbox.shape[0]): 583 | fout.write('%f %f %f %d %d %d %d %f %f %f %f %f %f\n' % \ 584 | (point_bbox[i,0], point_bbox[i,1], point_bbox[i,2], 585 | point_bbox[i,3], point_bbox[i,4], point_bbox[i,5], 586 | point_bbox[i,6], 587 | point_bbox[i,7], 
point_bbox[i,8], point_bbox[i,9], 588 | point_bbox[i,10], point_bbox[i,11], point_bbox[i,12])) 589 | 590 | fout.close() 591 | elif file_format == 'numpy': 592 | np.save(out_filename, point_bbox) 593 | else: 594 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \ 595 | (file_format)) 596 | exit() 597 | 598 | 599 | -------------------------------------------------------------------------------- /data_utils/meta/anno_paths.txt: -------------------------------------------------------------------------------- 1 | Area_1/conferenceRoom_1/Annotations 2 | Area_1/conferenceRoom_2/Annotations 3 | Area_1/copyRoom_1/Annotations 4 | Area_1/hallway_1/Annotations 5 | Area_1/hallway_2/Annotations 6 | Area_1/hallway_3/Annotations 7 | Area_1/hallway_4/Annotations 8 | Area_1/hallway_5/Annotations 9 | Area_1/hallway_6/Annotations 10 | Area_1/hallway_7/Annotations 11 | Area_1/hallway_8/Annotations 12 | Area_1/office_10/Annotations 13 | Area_1/office_11/Annotations 14 | Area_1/office_12/Annotations 15 | Area_1/office_13/Annotations 16 | Area_1/office_14/Annotations 17 | Area_1/office_15/Annotations 18 | Area_1/office_16/Annotations 19 | Area_1/office_17/Annotations 20 | Area_1/office_18/Annotations 21 | Area_1/office_19/Annotations 22 | Area_1/office_1/Annotations 23 | Area_1/office_20/Annotations 24 | Area_1/office_21/Annotations 25 | Area_1/office_22/Annotations 26 | Area_1/office_23/Annotations 27 | Area_1/office_24/Annotations 28 | Area_1/office_25/Annotations 29 | Area_1/office_26/Annotations 30 | Area_1/office_27/Annotations 31 | Area_1/office_28/Annotations 32 | Area_1/office_29/Annotations 33 | Area_1/office_2/Annotations 34 | Area_1/office_30/Annotations 35 | Area_1/office_31/Annotations 36 | Area_1/office_3/Annotations 37 | Area_1/office_4/Annotations 38 | Area_1/office_5/Annotations 39 | Area_1/office_6/Annotations 40 | Area_1/office_7/Annotations 41 | Area_1/office_8/Annotations 42 | Area_1/office_9/Annotations 43 | Area_1/pantry_1/Annotations 44 | 
Area_1/WC_1/Annotations 45 | Area_2/auditorium_1/Annotations 46 | Area_2/auditorium_2/Annotations 47 | Area_2/conferenceRoom_1/Annotations 48 | Area_2/hallway_10/Annotations 49 | Area_2/hallway_11/Annotations 50 | Area_2/hallway_12/Annotations 51 | Area_2/hallway_1/Annotations 52 | Area_2/hallway_2/Annotations 53 | Area_2/hallway_3/Annotations 54 | Area_2/hallway_4/Annotations 55 | Area_2/hallway_5/Annotations 56 | Area_2/hallway_6/Annotations 57 | Area_2/hallway_7/Annotations 58 | Area_2/hallway_8/Annotations 59 | Area_2/hallway_9/Annotations 60 | Area_2/office_10/Annotations 61 | Area_2/office_11/Annotations 62 | Area_2/office_12/Annotations 63 | Area_2/office_13/Annotations 64 | Area_2/office_14/Annotations 65 | Area_2/office_1/Annotations 66 | Area_2/office_2/Annotations 67 | Area_2/office_3/Annotations 68 | Area_2/office_4/Annotations 69 | Area_2/office_5/Annotations 70 | Area_2/office_6/Annotations 71 | Area_2/office_7/Annotations 72 | Area_2/office_8/Annotations 73 | Area_2/office_9/Annotations 74 | Area_2/storage_1/Annotations 75 | Area_2/storage_2/Annotations 76 | Area_2/storage_3/Annotations 77 | Area_2/storage_4/Annotations 78 | Area_2/storage_5/Annotations 79 | Area_2/storage_6/Annotations 80 | Area_2/storage_7/Annotations 81 | Area_2/storage_8/Annotations 82 | Area_2/storage_9/Annotations 83 | Area_2/WC_1/Annotations 84 | Area_2/WC_2/Annotations 85 | Area_3/conferenceRoom_1/Annotations 86 | Area_3/hallway_1/Annotations 87 | Area_3/hallway_2/Annotations 88 | Area_3/hallway_3/Annotations 89 | Area_3/hallway_4/Annotations 90 | Area_3/hallway_5/Annotations 91 | Area_3/hallway_6/Annotations 92 | Area_3/lounge_1/Annotations 93 | Area_3/lounge_2/Annotations 94 | Area_3/office_10/Annotations 95 | Area_3/office_1/Annotations 96 | Area_3/office_2/Annotations 97 | Area_3/office_3/Annotations 98 | Area_3/office_4/Annotations 99 | Area_3/office_5/Annotations 100 | Area_3/office_6/Annotations 101 | Area_3/office_7/Annotations 102 | Area_3/office_8/Annotations 103 | 
Area_3/office_9/Annotations 104 | Area_3/storage_1/Annotations 105 | Area_3/storage_2/Annotations 106 | Area_3/WC_1/Annotations 107 | Area_3/WC_2/Annotations 108 | Area_4/conferenceRoom_1/Annotations 109 | Area_4/conferenceRoom_2/Annotations 110 | Area_4/conferenceRoom_3/Annotations 111 | Area_4/hallway_10/Annotations 112 | Area_4/hallway_11/Annotations 113 | Area_4/hallway_12/Annotations 114 | Area_4/hallway_13/Annotations 115 | Area_4/hallway_14/Annotations 116 | Area_4/hallway_1/Annotations 117 | Area_4/hallway_2/Annotations 118 | Area_4/hallway_3/Annotations 119 | Area_4/hallway_4/Annotations 120 | Area_4/hallway_5/Annotations 121 | Area_4/hallway_6/Annotations 122 | Area_4/hallway_7/Annotations 123 | Area_4/hallway_8/Annotations 124 | Area_4/hallway_9/Annotations 125 | Area_4/lobby_1/Annotations 126 | Area_4/lobby_2/Annotations 127 | Area_4/office_10/Annotations 128 | Area_4/office_11/Annotations 129 | Area_4/office_12/Annotations 130 | Area_4/office_13/Annotations 131 | Area_4/office_14/Annotations 132 | Area_4/office_15/Annotations 133 | Area_4/office_16/Annotations 134 | Area_4/office_17/Annotations 135 | Area_4/office_18/Annotations 136 | Area_4/office_19/Annotations 137 | Area_4/office_1/Annotations 138 | Area_4/office_20/Annotations 139 | Area_4/office_21/Annotations 140 | Area_4/office_22/Annotations 141 | Area_4/office_2/Annotations 142 | Area_4/office_3/Annotations 143 | Area_4/office_4/Annotations 144 | Area_4/office_5/Annotations 145 | Area_4/office_6/Annotations 146 | Area_4/office_7/Annotations 147 | Area_4/office_8/Annotations 148 | Area_4/office_9/Annotations 149 | Area_4/storage_1/Annotations 150 | Area_4/storage_2/Annotations 151 | Area_4/storage_3/Annotations 152 | Area_4/storage_4/Annotations 153 | Area_4/WC_1/Annotations 154 | Area_4/WC_2/Annotations 155 | Area_4/WC_3/Annotations 156 | Area_4/WC_4/Annotations 157 | Area_5/conferenceRoom_1/Annotations 158 | Area_5/conferenceRoom_2/Annotations 159 | Area_5/conferenceRoom_3/Annotations 160 | 
Area_5/hallway_10/Annotations 161 | Area_5/hallway_11/Annotations 162 | Area_5/hallway_12/Annotations 163 | Area_5/hallway_13/Annotations 164 | Area_5/hallway_14/Annotations 165 | Area_5/hallway_15/Annotations 166 | Area_5/hallway_1/Annotations 167 | Area_5/hallway_2/Annotations 168 | Area_5/hallway_3/Annotations 169 | Area_5/hallway_4/Annotations 170 | Area_5/hallway_5/Annotations 171 | Area_5/hallway_6/Annotations 172 | Area_5/hallway_7/Annotations 173 | Area_5/hallway_8/Annotations 174 | Area_5/hallway_9/Annotations 175 | Area_5/lobby_1/Annotations 176 | Area_5/office_10/Annotations 177 | Area_5/office_11/Annotations 178 | Area_5/office_12/Annotations 179 | Area_5/office_13/Annotations 180 | Area_5/office_14/Annotations 181 | Area_5/office_15/Annotations 182 | Area_5/office_16/Annotations 183 | Area_5/office_17/Annotations 184 | Area_5/office_18/Annotations 185 | Area_5/office_19/Annotations 186 | Area_5/office_1/Annotations 187 | Area_5/office_20/Annotations 188 | Area_5/office_21/Annotations 189 | Area_5/office_22/Annotations 190 | Area_5/office_23/Annotations 191 | Area_5/office_24/Annotations 192 | Area_5/office_25/Annotations 193 | Area_5/office_26/Annotations 194 | Area_5/office_27/Annotations 195 | Area_5/office_28/Annotations 196 | Area_5/office_29/Annotations 197 | Area_5/office_2/Annotations 198 | Area_5/office_30/Annotations 199 | Area_5/office_31/Annotations 200 | Area_5/office_32/Annotations 201 | Area_5/office_33/Annotations 202 | Area_5/office_34/Annotations 203 | Area_5/office_35/Annotations 204 | Area_5/office_36/Annotations 205 | Area_5/office_37/Annotations 206 | Area_5/office_38/Annotations 207 | Area_5/office_39/Annotations 208 | Area_5/office_3/Annotations 209 | Area_5/office_40/Annotations 210 | Area_5/office_41/Annotations 211 | Area_5/office_42/Annotations 212 | Area_5/office_4/Annotations 213 | Area_5/office_5/Annotations 214 | Area_5/office_6/Annotations 215 | Area_5/office_7/Annotations 216 | Area_5/office_8/Annotations 217 | 
Area_5/office_9/Annotations 218 | Area_5/pantry_1/Annotations 219 | Area_5/storage_1/Annotations 220 | Area_5/storage_2/Annotations 221 | Area_5/storage_3/Annotations 222 | Area_5/storage_4/Annotations 223 | Area_5/WC_1/Annotations 224 | Area_5/WC_2/Annotations 225 | Area_6/conferenceRoom_1/Annotations 226 | Area_6/copyRoom_1/Annotations 227 | Area_6/hallway_1/Annotations 228 | Area_6/hallway_2/Annotations 229 | Area_6/hallway_3/Annotations 230 | Area_6/hallway_4/Annotations 231 | Area_6/hallway_5/Annotations 232 | Area_6/hallway_6/Annotations 233 | Area_6/lounge_1/Annotations 234 | Area_6/office_10/Annotations 235 | Area_6/office_11/Annotations 236 | Area_6/office_12/Annotations 237 | Area_6/office_13/Annotations 238 | Area_6/office_14/Annotations 239 | Area_6/office_15/Annotations 240 | Area_6/office_16/Annotations 241 | Area_6/office_17/Annotations 242 | Area_6/office_18/Annotations 243 | Area_6/office_19/Annotations 244 | Area_6/office_1/Annotations 245 | Area_6/office_20/Annotations 246 | Area_6/office_21/Annotations 247 | Area_6/office_22/Annotations 248 | Area_6/office_23/Annotations 249 | Area_6/office_24/Annotations 250 | Area_6/office_25/Annotations 251 | Area_6/office_26/Annotations 252 | Area_6/office_27/Annotations 253 | Area_6/office_28/Annotations 254 | Area_6/office_29/Annotations 255 | Area_6/office_2/Annotations 256 | Area_6/office_30/Annotations 257 | Area_6/office_31/Annotations 258 | Area_6/office_32/Annotations 259 | Area_6/office_33/Annotations 260 | Area_6/office_34/Annotations 261 | Area_6/office_35/Annotations 262 | Area_6/office_36/Annotations 263 | Area_6/office_37/Annotations 264 | Area_6/office_3/Annotations 265 | Area_6/office_4/Annotations 266 | Area_6/office_5/Annotations 267 | Area_6/office_6/Annotations 268 | Area_6/office_7/Annotations 269 | Area_6/office_8/Annotations 270 | Area_6/office_9/Annotations 271 | Area_6/openspace_1/Annotations 272 | Area_6/pantry_1/Annotations 273 | 
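The annotation paths listed in `data_utils/meta/anno_paths.txt` are consumed by `data_utils/collect_indoor3d_data.py`, which names each output file after the area and room components of the path (its own comment gives `Area_1_hallway_1.npy` as an example). The sketch below reproduces that naming rule under stated assumptions: the helper name `anno_path_to_npy_name` is hypothetical and does not exist in the repository, and it uses `os.path` instead of `split('/')` so that a trailing separator or an absolute `DATA_PATH` prefix does not shift the indices.

```python
import os


def anno_path_to_npy_name(anno_path: str) -> str:
    """Hypothetical helper: map '<Area>/<room>/Annotations' to '<Area>_<room>.npy'.

    Intended to match the elements[-3] + '_' + elements[-2] rule in
    collect_indoor3d_data.py, while tolerating a trailing slash.
    """
    annotations_dir = os.path.normpath(anno_path)  # drops any trailing separator
    room_dir = os.path.dirname(annotations_dir)    # e.g. .../Area_1/office_2
    area_dir = os.path.dirname(room_dir)           # e.g. .../Area_1
    room = os.path.basename(room_dir)
    area = os.path.basename(area_dir)
    return f"{area}_{room}.npy"


if __name__ == "__main__":
    for p in ["Area_1/office_2/Annotations", "Area_5/hallway_6/Annotations/"]:
        print(p, "->", anno_path_to_npy_name(p))
```

Because the name is derived from the last two directory components only, the same helper works whether or not the paths have already been joined with `DATA_PATH`.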
--------------------------------------------------------------------------------
/data_utils/meta/class_names.txt:
--------------------------------------------------------------------------------
ceiling
floor
wall
beam
column
window
door
table
chair
sofa
bookcase
board
clutter
--------------------------------------------------------------------------------
/image/plane1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/plane1.png
--------------------------------------------------------------------------------
/image/plane2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/plane2.png
--------------------------------------------------------------------------------
/image/psn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/psn.png
--------------------------------------------------------------------------------
/models/PointSamplingNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import Tensor
from typing import List, Tuple


class PointSamplingNet(nn.Module):
    """
    Point Sampling Net PyTorch Module.
    Attributes:
        num_to_sample: the number of points to sample, int
        max_local_num: the maximum number of points in each local area, int
        mlp: the output channels of each layer of the feature transform function, List[int]
        global_feature: whether to enable the global feature, bool
    """

    def __init__(self, num_to_sample: int = 512, max_local_num: int = 32, mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
        """
        Initialization of Point Sampling Net.
        """
        super(PointSamplingNet, self).__init__()

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"

        self.mlp_convs.append(
            nn.Conv1d(in_channels=6, out_channels=mlp[0], kernel_size=1))
        self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))

        for i in range(len(mlp) - 1):
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i + 1], kernel_size=1))

        for i in range(len(mlp) - 1):
            self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i + 1]))

        self.global_feature = global_feature

        # The last layer outputs one score per (sampled point, input point) pair: [B, s, m]
        if self.global_feature:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
        else:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))

        self.s = num_to_sample
        self.n = max_local_num

    def forward(self, coordinate: Tensor, feature: Tensor, train: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Forward propagation of Point Sampling Net

        Args:
            coordinate: input points position data, [B, m, 3]
            feature: input points feature, [B, m, d], or None
            train: if True, sample with the Gumbel-softmax relaxation so that gradients
                reach the sampling scores; otherwise take the top-scoring point of each group, bool
        Returns:
            sampled_points: the coordinates of the sampled points, [B, s, 3]
            grouped_points: the coordinates of the grouped points, [B, s, n, 3]
            sampled_feature: the features of the sampled points, [B, s, d] (None if feature is None)
            grouped_feature: the features of the grouped points, [B, s, n, d] (None if feature is None)
        """
        _, m, _ = coordinate.size()

        assert self.s < m, "The number of points to sample must be less than the number of input points!"
        # Augment the Cartesian coordinates with spherical coordinates (r, theta, phi),
        # giving the 6 input channels expected by the first convolution.
        # `coordinate` itself is left unchanged ([B, m, 3]) so that the gathers below return xyz only.
        r = torch.sqrt(torch.pow(coordinate[:, :, 0], 2) + torch.pow(coordinate[:, :, 1], 2) + torch.pow(coordinate[:, :, 2], 2))  # [B, m]
        th = torch.acos(torch.clamp(coordinate[:, :, 2] / (r + 1e-8), -1.0, 1.0))  # polar angle, [B, m]
        fi = torch.atan2(coordinate[:, :, 1], coordinate[:, :, 0])  # azimuth, [B, m]

        x = torch.cat([coordinate, r.unsqueeze(2), th.unsqueeze(2), fi.unsqueeze(2)], -1)  # [B, m, 6]

        x = x.transpose(2, 1)  # Channel First, [B, 6, m]

        for i in range(len(self.mlp_convs) - 1):
            x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))

        if self.global_feature:
            max_feature = torch.max(x, 2, keepdim=True)[0]
            max_feature = max_feature.repeat(1, 1, m)  # [B, mlp[-1], m]
            x = torch.cat([x, max_feature], 1)  # [B, mlp[-1] * 2, m]

        x = self.mlp_convs[-1](x)  # [B, s, m]

        Q = torch.sigmoid(x)  # [B, s, m]

        # For each sampled point, rank all input points by score; the top n form its group
        _, indices = torch.sort(input=Q, dim=2, descending=True)  # [B, s, m]
        grouped_indices = indices[:, :, :self.n]
        grouped_points = index_points(coordinate, grouped_indices)[:, :, :self.n, :]  # [B, s, n, 3]
        if feature is not None:
            grouped_feature = index_points(feature, grouped_indices)[:, :, :self.n, :]  # [B, s, n, d]
            if not train:
                sampled_points = grouped_points[:, :, 0, :]  # [B, s, 3]
                sampled_feature = grouped_feature[:, :, 0, :]  # [B, s, d]
            else:
                Q = gumbel_softmax_sample(Q)  # [B, s, m]
                sampled_points = torch.matmul(Q, coordinate)  # [B, s, 3]
                sampled_feature = torch.matmul(Q, feature)  # [B, s, d]
                grouped_feature[:, :, 0, :] = sampled_feature
        else:
            if not train:
                sampled_points = grouped_points[:, :, 0, :]  # [B, s, 3]
                sampled_feature = None
            else:
                Q = gumbel_softmax_sample(Q)  # [B, s, m]
                sampled_points = torch.matmul(Q, coordinate)  # [B, s, 3]
                sampled_feature = None
            grouped_feature = None

        return sampled_points, grouped_points, sampled_feature, grouped_feature


class PointSamplingNetRadius(nn.Module):
    """
    Point Sampling Net with heuristic condition PyTorch Module.
    This example uses a radius query as the heuristic condition.
    You may replace the function C(x) with your own function.

    Attributes:
        num_to_sample: the number of points to sample, int
        radius: radius to query, float
        max_local_num: the maximum number of points in each local area, int
        mlp: the output channels of each layer of the feature transform function, List[int]
        global_feature: whether to enable the global feature, bool
    """

    def __init__(self, num_to_sample: int = 512, radius: float = 1.0, max_local_num: int = 32, mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
        """
        Initialization of Point Sampling Net.
        """
        super(PointSamplingNetRadius, self).__init__()

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        self.radius = radius

        assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"

        self.mlp_convs.append(
            nn.Conv1d(in_channels=3, out_channels=mlp[0], kernel_size=1))
        self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))

        for i in range(len(mlp) - 1):
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i + 1], kernel_size=1))

        for i in range(len(mlp) - 1):
            self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i + 1]))

        self.global_feature = global_feature

        if self.global_feature:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
        else:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))

        self.softmax = nn.Softmax(dim=1)

        self.s = num_to_sample
        self.n = max_local_num

    def forward(self, coordinate: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Forward propagation of Point Sampling Net

        Args:
            coordinate: input points position data, [B, m, 3]
        Returns:
            sampled_indices: the indices of sampled points, [B, s]
            grouped_indices: the indices of grouped points, [B, s, n]
        """
        _, m, _ = coordinate.size()

        assert self.s < m, "The number of points to sample must be less than the number of input points!"

        x = coordinate.transpose(2, 1)  # Channel First

        for i in range(len(self.mlp_convs) - 1):
            x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))

        if self.global_feature:
            max_feature = torch.max(x, 2, keepdim=True)[0]
            max_feature = max_feature.repeat(1, 1, m)  # [B, mlp[-1], m]
            x = torch.cat([x, max_feature], 1)  # [B, mlp[-1] * 2, m]

        x = self.mlp_convs[-1](x)  # [B, s, m]

        Q = self.softmax(x)  # [B, s, m]

        _, indices = torch.sort(input=Q, dim=2, descending=True)  # [B, s, m]

        grouped_indices = indices[:, :, 0:self.n]  # [B, s, n]

        sampled_indices = indices[:, :, 0]  # [B, s]

        # function C(x)
        # you may replace C(x) with your own heuristic condition
        sampled_coordinate = torch.unsqueeze(index_points(coordinate, sampled_indices), dim=2)  # [B, s, 1, 3]
        grouped_coordinate = index_points(coordinate, grouped_indices)  # [B, s, n, 3]

        diff = grouped_coordinate - sampled_coordinate
        diff = diff ** 2
        diff = torch.sum(diff, dim=3)  # [B, s, n]
        mask = diff > self.radius ** 2

        # Neighbours outside the query radius are replaced by the sampled point itself
        sampled_indices_expand = torch.unsqueeze(sampled_indices, dim=2).repeat(1, 1, self.n)  # [B, s, n]
        grouped_indices[mask] = sampled_indices_expand[mask]
        # function C(x) end

        return sampled_indices, grouped_indices


class PointSamplingNetMSG(nn.Module):
    """
    Point Sampling Net with Multi-scale Grouping PyTorch Module.
    Attributes:
        num_to_sample: the number of points to sample, int
        msg_n: the list of multi-scale grouping n values, List[int]
        mlp: the output channels of each layer of the feature transform function, List[int]
        global_feature: whether to enable the global feature, bool
    """

    def __init__(self, num_to_sample: int = 512, msg_n: List[int] = [32, 64], mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
        """
        Initialization of Point Sampling Net.
        """
        super(PointSamplingNetMSG, self).__init__()

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"

        self.mlp_convs.append(
            nn.Conv1d(in_channels=3, out_channels=mlp[0], kernel_size=1))
        self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))

        for i in range(len(mlp) - 1):
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i + 1], kernel_size=1))

        for i in range(len(mlp) - 1):
            self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i + 1]))

        self.global_feature = global_feature

        if self.global_feature:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
        else:
            self.mlp_convs.append(
                nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))

        self.softmax = nn.Softmax(dim=1)

        self.s = num_to_sample
        self.msg_n = msg_n

    def forward(self, coordinate: Tensor, feature: Tensor, train: bool = False) -> Tuple[Tensor, List[Tensor], Tensor, List[Tensor]]:
        """
        Forward propagation of Point Sampling Net

        Args:
            coordinate: input points position data, [B, m, 3]
            feature: input points feature, [B, m, d], or None
            train: if True, sample with the Gumbel-softmax relaxation, bool
        Returns:
            sampled_points: the coordinates of the sampled points, [B, s, 3]
            grouped_points_msg: the multi-scale grouped coordinates, List[Tensor] of [B, s, n, 3]
            sampled_feature: the features of the sampled points, [B, s, d] (None if feature is None)
            grouped_feature_msg: the multi-scale grouped features, List[Tensor] of [B, s, n, d] (None if feature is None)
        """
        _, m, _ = coordinate.size()

        assert self.s < m, "The number of points to sample must be less than the number of input points!"

        x = coordinate.transpose(2, 1)  # Channel First

        for i in range(len(self.mlp_convs) - 1):
            x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))

        if self.global_feature:
            max_feature = torch.max(x, 2, keepdim=True)[0]
            max_feature = max_feature.repeat(1, 1, m)  # [B, mlp[-1], m]
            x = torch.cat([x, max_feature], 1)  # [B, mlp[-1] * 2, m]

        x = self.mlp_convs[-1](x)  # [B, s, m]

        Q = torch.sigmoid(x)  # [B, s, m]

        _, indices = torch.sort(input=Q, dim=2, descending=True)  # [B, s, m]
        # Keep enough neighbours for the largest scale; smaller scales take a prefix
        grouped_indices = indices[:, :, :max(self.msg_n)]
        grouped_points_msg = []
        for n in self.msg_n:
            grouped_points_msg.append(index_points(coordinate, grouped_indices)[:, :, :n, :])  # [B, s, n, 3]
        if feature is not None:
            grouped_feature_msg = []
            for n in self.msg_n:
                grouped_feature_msg.append(index_points(feature, grouped_indices)[:, :, :n, :])  # [B, s, n, d]
            if not train:
                sampled_points = grouped_points_msg[0][:, :, 0, :]  # [B, s, 3]
                sampled_feature = grouped_feature_msg[-1][:, :, 0, :]  # [B, s, d]
            else:
                Q = gumbel_softmax_sample(Q)  # [B, s, m]
                sampled_points = torch.matmul(Q, coordinate)  # [B, s, 3]
                sampled_feature = torch.matmul(Q, feature)  # [B, s, d]
                for grouped_feature in grouped_feature_msg:
                    grouped_feature[:, :, 0, :] = sampled_feature
        else:
            if not train:
                sampled_points = grouped_points_msg[0][:, :, 0, :]  # [B, s, 3]
                sampled_feature = None
            else:
                Q = gumbel_softmax_sample(Q)  # [B, s, m]
                sampled_points = torch.matmul(Q, coordinate)  # [B, s, 3]
                sampled_feature = None
            grouped_feature_msg = None

        return sampled_points, grouped_points_msg, sampled_feature, grouped_feature_msg


def index_points(points, idx):
    """
    Gather points by index.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(
        device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def sample_gumbel(shape, eps=1e-20, device=None):
    # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0, 1), created on `device`
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, dim=-1, temperature=0.001):
    y = logits + sample_gumbel(logits.size(), device=logits.device)
    return F.softmax(y / temperature, dim=dim)


def gumbel_softmax(logits, temperature=1.0, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature=temperature)
    if hard:
        y_hard = onehot_from_logits(y)
        # Straight-through estimator: one-hot in the forward pass, soft-sample gradient in the backward pass
        y = (y_hard - y).detach() + y
    return y


def onehot_from_logits(logits, eps=0.0):
    """
    Given a batch of logits, return the one-hot argmax along dim 1.
    The epsilon-greedy branch is not implemented: for eps != 0.0 this returns None.
    """
    # get best (according to current policy) actions in one-hot form
    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
    if eps == 0.0:
        return argmax_acs
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/models/__init__.py
--------------------------------------------------------------------------------
/models/pointnet2_cls_ssg_psn.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from models.pointnet_util_psn import PointNetSetAbstraction


class get_model(nn.Module):
    def __init__(self, num_class, normal_channel=True):
        super(get_model, self).__init__()
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None,
                                          in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz, train):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
        l1_xyz, l1_points = self.sa1(xyz, norm, train)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)

        return x, l3_points


class get_loss(nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    def forward(self, pred, target, trans_feat):
        total_loss = F.nll_loss(pred, target)

        return total_loss
--------------------------------------------------------------------------------
/models/pointnet2_part_seg_ssg_psn.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F
from models.pointnet_util_psn import PointNetSetAbstraction, PointNetFeaturePropagation


class get_model(nn.Module):
    def __init__(self, num_classes, normal_channel=False):
        super(get_model, self).__init__()
        if normal_channel:
            additional_channel = 3
        else:
            additional_channel = 0
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=6+additional_channel, mlp=[64, 64, 128], group_all=False)  # modified: nsample was originally 32
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4,
                                          nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
        self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
        self.fp1 = PointNetFeaturePropagation(in_channel=128+16+6+additional_channel, mlp=[128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label, train):
        # Set Abstraction layers
        B, C, N = xyz.shape
        if self.normal_channel:
            l0_points = xyz
            l0_xyz = xyz[:, :3, :]
        else:
            l0_points = xyz
            l0_xyz = xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points, train)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        cls_label_one_hot = cls_label.view(B, 16, 1).repeat(1, 1, N)
        l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot, l0_xyz, l0_points], 1), l1_points)
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        # diff = diff1 + diff2 + diff3
        return x, l3_points


class get_loss(nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    def forward(self, pred, target, trans_feat):
        total_loss = F.nll_loss(pred, target)

        return total_loss
--------------------------------------------------------------------------------
/models/pointnet2_sem_seg_psn.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from models.pointnet_util_psn import PointNetSetAbstraction, PointNetFeaturePropagation


class get_model(nn.Module):
    def __init__(self, num_classes):
        super(get_model, self).__init__()
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, train):
        l0_points = xyz
        l0_xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points, train)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points, train)

        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l4_points


class get_loss(nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    def forward(self, pred, target,
                trans_feat, weight):
        total_loss = F.nll_loss(pred, target, weight=weight)

        return total_loss


if __name__ == '__main__':
    import torch
    model = get_model(13)
    xyz = torch.rand(6, 9, 2048)
    pred, _ = model(xyz, False)  # train=False: no Gumbel noise in the sampling step
    print(pred.shape)
--------------------------------------------------------------------------------
/models/pointnet_util_psn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
from models import PointSamplingNet as psn


def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()


def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between each pair of points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point squared distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Gather points by index.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(
        device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def sample_and_group_psn(npoint, sampled_points, grouped_points, sampled_feature, grouped_feature, nsample):
    """
    Sampling and grouping point cloud with PSN.
    Input:
        npoint: the number of sampled points s, int
        sampled_points: sampled points by PSN, [B, s, 3]
        grouped_points: grouped points by PSN, [B, s, n, 3]
        sampled_feature: sampled feature, [B, s, d]
        grouped_feature: grouped feature, [B, s, n, d], or None
        nsample: the maximum number of points in each local area n, int
    Output:
        new_xyz: sampled points coordinate, [B, s, 3]
        new_points: local (centred) coordinates concatenated with grouped features, [B, s, n, 3+d]
    """
    B, _, _ = sampled_points.shape
    S = npoint
    # Express each group relative to its sampled centre point
    grouped_xyz_norm = grouped_points - sampled_points.view(B, S, 1, 3).repeat([1, 1, nsample, 1])  # [B, s, n, 3]

    torch.cuda.empty_cache()
    if grouped_feature is not None:
        new_points = torch.cat([grouped_xyz_norm, grouped_feature], dim=-1)
    else:
        new_points = grouped_xyz_norm
    return sampled_points, new_points


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(
        device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
""" 132 | device = xyz.device 133 | B, N, C = xyz.shape 134 | new_xyz = torch.zeros(B, 1, C).to(device) 135 | grouped_xyz = xyz.view(B, 1, N, C) 136 | if points is not None: 137 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 138 | else: 139 | new_points = grouped_xyz 140 | return new_xyz, new_points 141 | 142 | 143 | class PointNetSetAbstraction(nn.Module): 144 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 145 | super(PointNetSetAbstraction, self).__init__() 146 | self.npoint = npoint 147 | self.radius = radius 148 | self.nsample = nsample 149 | self.mlp_convs = nn.ModuleList() 150 | self.mlp_bns = nn.ModuleList() 151 | last_channel = in_channel 152 | for out_channel in mlp: 153 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 154 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 155 | last_channel = out_channel 156 | self.group_all = group_all 157 | if not group_all: 158 | # self.sampling = pointsampling.get_model(npoint) 159 | self.sampling = psn.PointSamplingNet(npoint, nsample, [64,128,256], global_feature=True) 160 | 161 | def forward(self, xyz, points, train): 162 | """ 163 | Input: 164 | xyz: input points position data, [B, C, N] 165 | points: input points data, [B, D, N] 166 | Return: 167 | new_xyz: sampled points position data, [B, C, S] 168 | new_points_concat: sample points feature data, [B, D', S] 169 | """ 170 | xyz = xyz.permute(0, 2, 1) 171 | if points is not None: 172 | points = points.permute(0, 2, 1) 173 | 174 | if self.group_all: 175 | new_xyz, new_points = sample_and_group_all(xyz, points) 176 | else: 177 | sampled_points, grouped_points, sampled_feature, grouped_feature = self.sampling(xyz,points,train) 178 | new_xyz, new_points = sample_and_group_psn( 179 | self.npoint, sampled_points, grouped_points, sampled_feature, grouped_feature, self.nsample, xyz, points) 180 | # new_xyz: sampled points position data, [B, npoint, C] 181 | # new_points: sampled points data, [B, 
npoint, nsample, C+D] 182 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 183 | for i, conv in enumerate(self.mlp_convs): 184 | bn = self.mlp_bns[i] 185 | new_points = F.relu(bn(conv(new_points))) 186 | 187 | new_points = torch.max(new_points, 2)[0] 188 | new_xyz = new_xyz.permute(0, 2, 1) 189 | return new_xyz, new_points 190 | 191 | 192 | class PointNetFeaturePropagation(nn.Module): 193 | def __init__(self, in_channel, mlp): 194 | super(PointNetFeaturePropagation, self).__init__() 195 | self.mlp_convs = nn.ModuleList() 196 | self.mlp_bns = nn.ModuleList() 197 | last_channel = in_channel 198 | for out_channel in mlp: 199 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 200 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 201 | last_channel = out_channel 202 | 203 | def forward(self, xyz1, xyz2, points1, points2): 204 | """ 205 | Input: 206 | xyz1: input points position data, [B, C, N] 207 | xyz2: sampled input points position data, [B, C, S] 208 | points1: input points data, [B, D, N] 209 | points2: input points data, [B, D, S] 210 | Return: 211 | new_points: upsampled points data, [B, D', N] 212 | """ 213 | xyz1 = xyz1.permute(0, 2, 1) 214 | xyz2 = xyz2.permute(0, 2, 1) 215 | 216 | points2 = points2.permute(0, 2, 1) 217 | B, N, C = xyz1.shape 218 | _, S, _ = xyz2.shape 219 | 220 | if S == 1: 221 | interpolated_points = points2.repeat(1, N, 1) 222 | else: 223 | dists = square_distance(xyz1, xyz2) 224 | dists, idx = dists.sort(dim=-1) 225 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 226 | 227 | dist_recip = 1.0 / (dists + 1e-8) 228 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 229 | weight = dist_recip / norm 230 | interpolated_points = torch.sum(index_points( 231 | points2, idx) * weight.view(B, N, 3, 1), dim=2) 232 | 233 | if points1 is not None: 234 | points1 = points1.permute(0, 2, 1) 235 | new_points = torch.cat([points1, interpolated_points], dim=-1) 236 | else: 237 | new_points = 
interpolated_points 238 | 239 | new_points = new_points.permute(0, 2, 1) 240 | for i, conv in enumerate(self.mlp_convs): 241 | bn = self.mlp_bns[i] 242 | new_points = F.relu(bn(conv(new_points))) 243 | return new_points 244 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data, use coordinates of the block centered at origin, 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch. 
37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augment the dataset 48 | rotation is per shape, about the up (y) axis 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augment the dataset 68 | rotation is per shape, about the z axis 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud.
88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are normals 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx6 array, rotated batch of point clouds and point normals 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, 
rotation_angle): 134 | """ Rotate the point
cloud about the up axis by a given angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud about the up axis by a given angle. 154 | Input: 155 | BxNx6 array, original batch of point clouds with normals 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds with normals 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. jittering is per point. 203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 | 228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud. 
230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNx3 ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc 249 | 250 | 251 | 252 | -------------------------------------------------------------------------------- /test_cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 6 | import argparse 7 | import numpy as np 8 | import os 9 | import torch 10 | import logging 11 | from tqdm import tqdm 12 | import sys 13 | import importlib 14 | 15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | ROOT_DIR = BASE_DIR 17 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 18 | 19 | 20 | def parse_args(): 21 | '''PARAMETERS''' 22 | parser = argparse.ArgumentParser('PointNet') 23 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing [default: 24]') 24 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 25 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]') 26 | parser.add_argument('--log_dir', type=str, default='pointnet2_ssg_normal', help='Experiment root') 27 | parser.add_argument('--normal', action='store_true', default=True, help='Whether to use normal information [default: True]')
28 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting [default: 3]') 29 | return parser.parse_args() 30 | 31 | def test(model, loader, num_class=40, vote_num=1): 32 | mean_correct = [] 33 | class_acc = np.zeros((num_class,3)) 34 | for j, data in tqdm(enumerate(loader), total=len(loader)): 35 | points, target = data 36 | target = target[:, 0] 37 | points = points.transpose(2, 1) 38 | points, target = points.cuda(), target.cuda() 39 | classifier = model.eval() 40 | vote_pool = torch.zeros(target.size()[0],num_class).cuda() 41 | for _ in range(vote_num): 42 | pred, _ = classifier(points, False) 43 | vote_pool += pred 44 | pred = vote_pool/vote_num 45 | pred_choice = pred.data.max(1)[1] 46 | for cat in np.unique(target.cpu()): 47 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 48 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 49 | class_acc[cat,1]+=1 50 | correct = pred_choice.eq(target.long().data).cpu().sum() 51 | mean_correct.append(correct.item()/float(points.size()[0])) 52 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1] 53 | class_acc = np.mean(class_acc[:,2]) 54 | instance_acc = np.mean(mean_correct) 55 | return instance_acc, class_acc 56 | 57 | 58 | def main(args): 59 | def log_string(str): 60 | logger.info(str) 61 | print(str) 62 | 63 | '''HYPER PARAMETER''' 64 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 65 | 66 | '''CREATE DIR''' 67 | experiment_dir = 'log/classification/' + args.log_dir 68 | 69 | '''LOG''' 70 | args = parse_args() 71 | logger = logging.getLogger("Model") 72 | logger.setLevel(logging.INFO) 73 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 74 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) 75 | file_handler.setLevel(logging.INFO) 76 | file_handler.setFormatter(formatter) 77 | logger.addHandler(file_handler) 78 | log_string('PARAMETER ...') 79 | 
log_string(args) 80 | 81 | '''DATA LOADING''' 82 | log_string('Load dataset ...') 83 | DATA_PATH = 'data/modelnet40_normal_resampled/' 84 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal) 85 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4) 86 | 87 | '''MODEL LOADING''' 88 | num_class = 40 89 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0] 90 | MODEL = importlib.import_module(model_name) 91 | 92 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() 93 | 94 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 95 | classifier.load_state_dict(checkpoint['model_state_dict']) 96 | 97 | with torch.no_grad(): 98 | instance_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes) 99 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc)) 100 | 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | main(args) 106 | -------------------------------------------------------------------------------- /test_partseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | import argparse 6 | import os 7 | from data_utils.ShapeNetDataLoader import PartNormalDataset 8 | import torch 9 | import logging 10 | import sys 11 | import importlib 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | ROOT_DIR = BASE_DIR 17 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 18 | 19 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 
48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 20 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 21 | for cat in seg_classes.keys(): 22 | for label in seg_classes[cat]: 23 | seg_label_to_cat[label] = cat 24 | 25 | def to_categorical(y, num_classes): 26 | """ 1-hot encodes a tensor """ 27 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] 28 | if (y.is_cuda): 29 | return new_y.cuda() 30 | return new_y 31 | 32 | 33 | def parse_args(): 34 | '''PARAMETERS''' 35 | parser = argparse.ArgumentParser('PointNet') 36 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing [default: 24]') 37 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device [default: 0]') 38 | parser.add_argument('--num_point', type=int, default=2048, help='Point Number [default: 2048]') 39 | parser.add_argument('--log_dir', type=str, default='pointnet2_part_seg_ssg', help='Experiment root') 40 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 41 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate segmentation scores with voting [default: 3]') 42 | return parser.parse_args() 43 | 44 | def main(args): 45 | def log_string(str): 46 | logger.info(str) 47 | print(str) 48 | 49 | '''HYPER PARAMETER''' 50 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 51 | experiment_dir = 'log/part_seg/' + args.log_dir 52 | 53 | '''LOG''' 54 | args = parse_args() 55 | logger = logging.getLogger("Model") 56 | logger.setLevel(logging.INFO) 57 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 58 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) 59 | file_handler.setLevel(logging.INFO) 60 | file_handler.setFormatter(formatter) 61 | logger.addHandler(file_handler) 62 | log_string('PARAMETER ...') 63 | log_string(args) 64 | 65 | root = 
'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/' 66 | 67 | TEST_DATASET = PartNormalDataset(root = root, npoints=args.num_point, split='test', normal_channel=args.normal) 68 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size,shuffle=False, num_workers=4) 69 | log_string("The number of test data is: %d" % len(TEST_DATASET)) 70 | num_classes = 16 71 | num_part = 50 72 | 73 | '''MODEL LOADING''' 74 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0] 75 | MODEL = importlib.import_module(model_name) 76 | classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda() 77 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 78 | classifier.load_state_dict(checkpoint['model_state_dict']) 79 | 80 | 81 | with torch.no_grad(): 82 | test_metrics = {} 83 | total_correct = 0 84 | total_seen = 0 85 | total_seen_class = [0 for _ in range(num_part)] 86 | total_correct_class = [0 for _ in range(num_part)] 87 | shape_ious = {cat: [] for cat in seg_classes.keys()} 88 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 89 | for cat in seg_classes.keys(): 90 | for label in seg_classes[cat]: 91 | seg_label_to_cat[label] = cat 92 | 93 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 94 | batchsize, num_point, _ = points.size() 95 | cur_batch_size, NUM_POINT, _ = points.size() 96 | points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda() 97 | points = points.transpose(2, 1) 98 | classifier = classifier.eval() 99 | vote_pool = torch.zeros(target.size()[0], target.size()[1], num_part).cuda() 100 | for _ in range(args.num_votes): 101 | seg_pred, _ = classifier(points, to_categorical(label, num_classes), False) 102 | vote_pool += seg_pred 103 | seg_pred = vote_pool / args.num_votes 104 | cur_pred_val = seg_pred.cpu().data.numpy() 105 | cur_pred_val_logits = cur_pred_val 106 | 
cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32) 107 | target = target.cpu().data.numpy() 108 | for i in range(cur_batch_size): 109 | cat = seg_label_to_cat[target[i, 0]] 110 | logits = cur_pred_val_logits[i, :, :] 111 | cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] 112 | correct = np.sum(cur_pred_val == target) 113 | total_correct += correct 114 | total_seen += (cur_batch_size * NUM_POINT) 115 | 116 | for l in range(num_part): 117 | total_seen_class[l] += np.sum(target == l) 118 | total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l))) 119 | 120 | for i in range(cur_batch_size): 121 | segp = cur_pred_val[i, :] 122 | segl = target[i, :] 123 | cat = seg_label_to_cat[segl[0]] 124 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 125 | for l in seg_classes[cat]: 126 | if (np.sum(segl == l) == 0) and ( 127 | np.sum(segp == l) == 0): # part is not present, no prediction as well 128 | part_ious[l - seg_classes[cat][0]] = 1.0 129 | else: 130 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( 131 | np.sum((segl == l) | (segp == l))) 132 | shape_ious[cat].append(np.mean(part_ious)) 133 | 134 | all_shape_ious = [] 135 | for cat in shape_ious.keys(): 136 | for iou in shape_ious[cat]: 137 | all_shape_ious.append(iou) 138 | shape_ious[cat] = np.mean(shape_ious[cat]) 139 | mean_shape_ious = np.mean(list(shape_ious.values())) 140 | test_metrics['accuracy'] = total_correct / float(total_seen) 141 | test_metrics['class_avg_accuracy'] = np.mean( 142 | np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)) 143 | for cat in sorted(shape_ious.keys()): 144 | log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat])) 145 | test_metrics['class_avg_iou'] = mean_shape_ious 146 | test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) 147 | 148 | 149 | log_string('Accuracy is: %.5f'%test_metrics['accuracy']) 150 | 
log_string('Class
avg accuracy is: %.5f'%test_metrics['class_avg_accuracy']) 151 | log_string('Class avg mIoU is: %.5f'%test_metrics['class_avg_iou']) 152 | log_string('Instance avg mIoU is: %.5f'%test_metrics['inctance_avg_iou']) 153 | 154 | if __name__ == '__main__': 155 | args = parse_args() 156 | main(args) 157 | 158 | -------------------------------------------------------------------------------- /test_semseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | import argparse 6 | import os 7 | from data_utils.S3DISDataLoader import ScannetDatasetWholeScene 8 | from data_utils.indoor3d_util import g_label2color 9 | import torch 10 | import logging 11 | from pathlib import Path 12 | import sys 13 | import importlib 14 | from tqdm import tqdm 15 | import provider 16 | import numpy as np 17 | 18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 19 | ROOT_DIR = BASE_DIR 20 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 21 | 22 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 23 | class2label = {cls: i for i,cls in enumerate(classes)} 24 | seg_classes = class2label 25 | seg_label_to_cat = {} 26 | for i,cat in enumerate(seg_classes.keys()): 27 | seg_label_to_cat[i] = cat 28 | 29 | def parse_args(): 30 | '''PARAMETERS''' 31 | parser = argparse.ArgumentParser('Model') 32 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]') 33 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 34 | parser.add_argument('--num_point', type=int, default=4096, help='Point Number [default: 4096]') 35 | parser.add_argument('--log_dir', type=str, default='pointnet2_sem_seg', help='Experiment root') 36 | parser.add_argument('--visual', action='store_true', default=False, help='Whether to visualize results [default: False]') 37 | parser.add_argument('--test_area',
type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]') 38 | parser.add_argument('--num_votes', type=int, default=5, help='Aggregate segmentation scores with voting [default: 5]') 39 | return parser.parse_args() 40 | 41 | def add_vote(vote_label_pool, point_idx, pred_label, weight): 42 | B = pred_label.shape[0] 43 | N = pred_label.shape[1] 44 | for b in range(B): 45 | for n in range(N): 46 | if weight[b,n]: 47 | vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1 48 | return vote_label_pool 49 | 50 | def main(args): 51 | def log_string(str): 52 | logger.info(str) 53 | print(str) 54 | 55 | '''HYPER PARAMETER''' 56 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 57 | experiment_dir = 'log/sem_seg/' + args.log_dir 58 | visual_dir = experiment_dir + '/visual/' 59 | visual_dir = Path(visual_dir) 60 | visual_dir.mkdir(exist_ok=True) 61 | 62 | '''LOG''' 63 | args = parse_args() 64 | logger = logging.getLogger("Model") 65 | logger.setLevel(logging.INFO) 66 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 67 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) 68 | file_handler.setLevel(logging.INFO) 69 | file_handler.setFormatter(formatter) 70 | logger.addHandler(file_handler) 71 | log_string('PARAMETER ...') 72 | log_string(args) 73 | 74 | NUM_CLASSES = 13 75 | BATCH_SIZE = args.batch_size 76 | NUM_POINT = args.num_point 77 | 78 | root = 'data/stanford_indoor3d/' 79 | 80 | TEST_DATASET_WHOLE_SCENE = ScannetDatasetWholeScene(root, split='test', test_area=args.test_area, block_points=NUM_POINT) 81 | log_string("The number of test data is: %d" % len(TEST_DATASET_WHOLE_SCENE)) 82 | 83 | '''MODEL LOADING''' 84 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0] 85 | MODEL = importlib.import_module(model_name) 86 | classifier = MODEL.get_model(NUM_CLASSES).cuda() 87 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 88 | 
classifier.load_state_dict(checkpoint['model_state_dict']) 89 | 90 | with torch.no_grad(): 91 | scene_id = TEST_DATASET_WHOLE_SCENE.file_list 92 | scene_id = [x[:-4] for x in scene_id] 93 | num_batches = len(TEST_DATASET_WHOLE_SCENE) 94 | 95 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 96 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 97 | total_iou_deno_class = [0 for _ in range(NUM_CLASSES)] 98 | 99 | log_string('---- EVALUATION WHOLE SCENE----') 100 | 101 | for batch_idx in range(num_batches): 102 | print("visualize [%d/%d] %s ..." % (batch_idx+1, num_batches, scene_id[batch_idx])) 103 | total_seen_class_tmp = [0 for _ in range(NUM_CLASSES)] 104 | total_correct_class_tmp = [0 for _ in range(NUM_CLASSES)] 105 | total_iou_deno_class_tmp = [0 for _ in range(NUM_CLASSES)] 106 | if args.visual: 107 | fout = open(os.path.join(visual_dir, scene_id[batch_idx] + '_pred.obj'), 'w') 108 | fout_gt = open(os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj'), 'w') 109 | 110 | whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[batch_idx] 111 | whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx] 112 | vote_label_pool = np.zeros((whole_scene_label.shape[0], NUM_CLASSES)) 113 | for _ in tqdm(range(args.num_votes), total=args.num_votes): 114 | scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx] 115 | num_blocks = scene_data.shape[0] 116 | s_batch_num = (num_blocks + BATCH_SIZE - 1) // BATCH_SIZE 117 | batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 9)) 118 | 119 | batch_label = np.zeros((BATCH_SIZE, NUM_POINT)) 120 | batch_point_index = np.zeros((BATCH_SIZE, NUM_POINT)) 121 | batch_smpw = np.zeros((BATCH_SIZE, NUM_POINT)) 122 | for sbatch in range(s_batch_num): 123 | start_idx = sbatch * BATCH_SIZE 124 | end_idx = min((sbatch + 1) * BATCH_SIZE, num_blocks) 125 | real_batch_size = end_idx - start_idx 126 | batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...] 
127 | batch_label[0:real_batch_size, ...] = scene_label[start_idx:end_idx, ...] 128 | batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...] 129 | batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...] 130 | batch_data[:, :, 3:6] /= 1.0 131 | 132 | torch_data = torch.Tensor(batch_data) 133 | torch_data= torch_data.float().cuda() 134 | torch_data = torch_data.transpose(2, 1) 135 | seg_pred, _ = classifier(torch_data, False) 136 | batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy() 137 | 138 | vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...], 139 | batch_pred_label[0:real_batch_size, ...], 140 | batch_smpw[0:real_batch_size, ...]) 141 | 142 | pred_label = np.argmax(vote_label_pool, 1) 143 | 144 | for l in range(NUM_CLASSES): 145 | total_seen_class_tmp[l] += np.sum((whole_scene_label == l)) 146 | total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l)) 147 | total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l))) 148 | total_seen_class[l] += total_seen_class_tmp[l] 149 | total_correct_class[l] += total_correct_class_tmp[l] 150 | total_iou_deno_class[l] += total_iou_deno_class_tmp[l] 151 | 152 | iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=np.float64) + 1e-6) 153 | print(iou_map) 154 | arr = np.array(total_seen_class_tmp) 155 | tmp_iou = np.mean(iou_map[arr != 0]) 156 | log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou)) 157 | print('----------------------------') 158 | 159 | filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt') 160 | with open(filename, 'w') as pl_save: 161 | for i in pred_label: 162 | pl_save.write(str(int(i)) + '\n') 163 | pl_save.close() 164 | for i in range(whole_scene_label.shape[0]): 165 | color = g_label2color[pred_label[i]] 166 | color_gt = g_label2color[whole_scene_label[i]] 167 | if args.visual: 168 | fout.write('v %f
%f %f %d %d %d\n' % ( 169 | whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color[0], color[1], 170 | color[2])) 171 | fout_gt.write( 172 | 'v %f %f %f %d %d %d\n' % ( 173 | whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color_gt[0], 174 | color_gt[1], color_gt[2])) 175 | if args.visual: 176 | fout.close() 177 | fout_gt.close() 178 | 179 | IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6) 180 | iou_per_class_str = '------- IoU --------\n' 181 | for l in range(NUM_CLASSES): 182 | iou_per_class_str += 'class %s, IoU: %.3f \n' % ( 183 | seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), 184 | total_correct_class[l] / float(total_iou_deno_class[l])) 185 | log_string(iou_per_class_str) 186 | log_string('eval point avg class IoU: %f' % np.mean(IoU)) 187 | log_string('eval whole scene point avg class acc: %f' % ( 188 | np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6)))) 189 | log_string('eval whole scene point accuracy: %f' % ( 190 | np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6))) 191 | 192 | print("Done!") 193 | 194 | if __name__ == '__main__': 195 | args = parse_args() 196 | main(args) 197 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 6 | import argparse 7 | import numpy as np 8 | import os 9 | import torch 10 | import datetime 11 | import logging 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | import sys 15 | import provider 16 | import importlib 17 | import shutil 18 | 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | ROOT_DIR = BASE_DIR 21 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 22 | 23 | 24 | def 
parse_args(): 25 |
'''PARAMETERS''' 26 | parser = argparse.ArgumentParser('PointNet') 27 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in training [default: 32]') 28 | parser.add_argument('--model', default='pointnet2_cls_ssg_psn', help='model name [default: pointnet2_cls_ssg_psn]') 29 | parser.add_argument('--epoch', default=300, type=int, help='number of epochs in training [default: 300]') 30 | parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training [default: 0.001]') 31 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device [default: 0]') 32 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]') 33 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training [default: Adam]') 34 | parser.add_argument('--log_dir', type=str, default=None, help='experiment root') 35 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay rate [default: 1e-4]') 36 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 37 | return parser.parse_args() 38 | 39 | def test(model, loader, num_class=40): 40 | mean_correct = [] 41 | class_acc = np.zeros((num_class,3)) 42 | for j, data in tqdm(enumerate(loader), total=len(loader)): 43 | points, target = data 44 | target = target[:, 0] 45 | points = points.transpose(2, 1) 46 | points, target = points.cuda(), target.cuda() 47 | classifier = model.eval() 48 | pred, _ = classifier(points, False) 49 | pred_choice = pred.data.max(1)[1] 50 | for cat in np.unique(target.cpu()): 51 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum() 52 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0]) 53 | class_acc[cat,1]+=1 54 | correct = pred_choice.eq(target.long().data).cpu().sum() 55 | mean_correct.append(correct.item()/float(points.size()[0])) 56 | 
class_acc[:,2]
= class_acc[:,0]/ class_acc[:,1] 57 | class_acc = np.mean(class_acc[:,2]) 58 | instance_acc = np.mean(mean_correct) 59 | return instance_acc, class_acc 60 | 61 | 62 | def main(args): 63 | def log_string(str): 64 | logger.info(str) 65 | print(str) 66 | 67 | '''HYPER PARAMETER''' 68 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 69 | 70 | '''CREATE DIR''' 71 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 72 | experiment_dir = Path('./log/') 73 | experiment_dir.mkdir(exist_ok=True) 74 | experiment_dir = experiment_dir.joinpath('classification') 75 | experiment_dir.mkdir(exist_ok=True) 76 | if args.log_dir is None: 77 | experiment_dir = experiment_dir.joinpath(timestr) 78 | else: 79 | experiment_dir = experiment_dir.joinpath(args.log_dir) 80 | experiment_dir.mkdir(exist_ok=True) 81 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 82 | checkpoints_dir.mkdir(exist_ok=True) 83 | log_dir = experiment_dir.joinpath('logs/') 84 | log_dir.mkdir(exist_ok=True) 85 | 86 | '''LOG''' 87 | args = parse_args() 88 | logger = logging.getLogger("Model") 89 | logger.setLevel(logging.INFO) 90 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 91 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 92 | file_handler.setLevel(logging.INFO) 93 | file_handler.setFormatter(formatter) 94 | logger.addHandler(file_handler) 95 | log_string('PARAMETER ...') 96 | log_string(args) 97 | 98 | '''DATA LOADING''' 99 | log_string('Load dataset ...') 100 | DATA_PATH = 'data/modelnet40_normal_resampled/' 101 | 102 | TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', 103 | normal_channel=args.normal) 104 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', 105 | normal_channel=args.normal) 106 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4) 107 | testDataLoader = 
torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4) 108 | 109 | '''MODEL LOADING''' 110 | num_class = 40 111 | MODEL = importlib.import_module(args.model) 112 | shutil.copy('./models/%s.py' % args.model, str(experiment_dir)) 113 | shutil.copy('./models/pointnet_util_psn.py', str(experiment_dir)) 114 | 115 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda() 116 | criterion = MODEL.get_loss().cuda() 117 | 118 | try: 119 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 120 | start_epoch = checkpoint['epoch'] 121 | classifier.load_state_dict(checkpoint['model_state_dict']) 122 | log_string('Use pretrain model') 123 | except: 124 | log_string('No existing model, starting training from scratch...') 125 | start_epoch = 0 126 | 127 | 128 | if args.optimizer == 'Adam': 129 | optimizer = torch.optim.Adam( 130 | classifier.parameters(), 131 | lr=args.learning_rate, 132 | betas=(0.9, 0.999), 133 | eps=1e-08, 134 | weight_decay=args.decay_rate 135 | ) 136 | else: 137 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 138 | 139 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7) 140 | global_epoch = 0 141 | global_step = 0 142 | best_instance_acc = 0.0 143 | best_class_acc = 0.0 144 | mean_correct = [] 145 | best_epoch = 0 146 | 147 | '''TRAINING''' 148 | logger.info('Start training...') 149 | for epoch in range(start_epoch,args.epoch): 150 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 151 | mean_correct = [] 152 | 153 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 154 | points, target = data 155 | points = points.data.numpy() 156 | points = provider.random_point_dropout(points) 157 | points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3]) 158 | points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3]) 159 | points = 
torch.Tensor(points) 160 | target = target[:, 0] 161 | 162 | points = points.transpose(2, 1) 163 | points, target = points.cuda(), target.cuda() 164 | optimizer.zero_grad() 165 | 166 | classifier = classifier.train() 167 | pred, trans_feat = classifier(points, False) 168 | loss = criterion(pred, target.long(), trans_feat) 169 | pred_choice = pred.data.max(1)[1] 170 | correct = pred_choice.eq(target.long().data).cpu().sum() 171 | mean_correct.append(correct.item() / float(points.size()[0])) 172 | loss.backward() 173 | optimizer.step() 174 | global_step += 1 175 | 176 | train_instance_acc = np.mean(mean_correct) 177 | log_string('Train Instance Accuracy: %f' % train_instance_acc) 178 | scheduler.step() 179 | 180 | with torch.no_grad(): 181 | instance_acc, class_acc = test(classifier.eval(), testDataLoader) 182 | 183 | if (instance_acc >= best_instance_acc): 184 | best_instance_acc = instance_acc 185 | best_epoch = epoch + 1 186 | 187 | if (class_acc >= best_class_acc): 188 | best_class_acc = class_acc 189 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc)) 190 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc)) 191 | 192 | if (instance_acc >= best_instance_acc): 193 | logger.info('Save model...') 194 | savepath = str(checkpoints_dir) + '/best_model.pth' 195 | log_string('Saving at %s'% savepath) 196 | state = { 197 | 'epoch': best_epoch, 198 | 'instance_acc': instance_acc, 199 | 'class_acc': class_acc, 200 | 'model_state_dict': classifier.state_dict(), 201 | 'optimizer_state_dict': optimizer.state_dict(), 202 | } 203 | torch.save(state, savepath) 204 | global_epoch += 1 205 | 206 | logger.info('End of training...') 207 | 208 | if __name__ == '__main__': 209 | args = parse_args() 210 | main(args) 211 | -------------------------------------------------------------------------------- /train_partseg.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Author: 
Benny 3 | Date: Nov 2019 4 | """ 5 | import argparse 6 | import os 7 | from data_utils.ShapeNetDataLoader import PartNormalDataset 8 | import torch 9 | import datetime 10 | import logging 11 | from pathlib import Path 12 | import sys 13 | import importlib 14 | import shutil 15 | from tqdm import tqdm 16 | import provider 17 | import numpy as np 18 | 19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 20 | ROOT_DIR = BASE_DIR 21 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 22 | torch.autograd.set_detect_anomaly(True) 23 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 24 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 25 | for cat in seg_classes.keys(): 26 | for label in seg_classes[cat]: 27 | seg_label_to_cat[label] = cat 28 | 29 | def to_categorical(y, num_classes): 30 | """ 1-hot encodes a tensor """ 31 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] 32 | if (y.is_cuda): 33 | return new_y.cuda() 34 | return new_y 35 | 36 | 37 | def parse_args(): 38 | parser = argparse.ArgumentParser('Model') 39 | parser.add_argument('--model', type=str, default='pointnet2_part_seg_ssg_psn', help='model name [default: pointnet2_part_seg_ssg_psn]') 40 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]') 41 | parser.add_argument('--epoch', default=251, type=int, help='Epoch to run [default: 251]') 42 | parser.add_argument('--learning_rate', default=0.001, type=float, help='Initial learning rate [default: 0.001]') 43 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]') 44 | 
parser.add_argument('--optimizer', type=str, 
default='Adam', help='Adam or SGD [default: Adam]') 45 | parser.add_argument('--log_dir', type=str, default=None, help='Log path [default: None]') 46 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]') 47 | parser.add_argument('--npoint', type=int, default=2048, help='Point Number [default: 2048]') 48 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 49 | parser.add_argument('--step_size', type=int, default=20, help='Decay step for lr decay [default: every 20 epochs]') 50 | parser.add_argument('--lr_decay', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]') 51 | 52 | return parser.parse_args() 53 | 54 | def main(args): 55 | def log_string(str): 56 | logger.info(str) 57 | print(str) 58 | 59 | '''HYPER PARAMETER''' 60 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 61 | 62 | '''CREATE DIR''' 63 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 64 | experiment_dir = Path('./log/') 65 | experiment_dir.mkdir(exist_ok=True) 66 | experiment_dir = experiment_dir.joinpath('part_seg') 67 | experiment_dir.mkdir(exist_ok=True) 68 | if args.log_dir is None: 69 | experiment_dir = experiment_dir.joinpath(timestr) 70 | else: 71 | experiment_dir = experiment_dir.joinpath(args.log_dir) 72 | experiment_dir.mkdir(exist_ok=True) 73 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 74 | checkpoints_dir.mkdir(exist_ok=True) 75 | log_dir = experiment_dir.joinpath('logs/') 76 | log_dir.mkdir(exist_ok=True) 77 | 78 | '''LOG''' 79 | args = parse_args() 80 | logger = logging.getLogger("Model") 81 | logger.setLevel(logging.INFO) 82 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 83 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 84 | file_handler.setLevel(logging.INFO) 85 | file_handler.setFormatter(formatter) 86 | logger.addHandler(file_handler) 87 | 
log_string('PARAMETER ...') 88 | log_string(args) 89 | 90 | root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/' 91 | 92 | TRAIN_DATASET = PartNormalDataset(root = root, npoints=args.npoint, split='trainval', normal_channel=args.normal) 93 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size,shuffle=True, num_workers=4) 94 | TEST_DATASET = PartNormalDataset(root = root, npoints=args.npoint, split='test', normal_channel=args.normal) 95 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size,shuffle=False, num_workers=4) 96 | log_string("The number of training data is: %d" % len(TRAIN_DATASET)) 97 | log_string("The number of test data is: %d" % len(TEST_DATASET)) 98 | num_classes = 16 99 | num_part = 50 100 | '''MODEL LOADING''' 101 | MODEL = importlib.import_module(args.model) 102 | shutil.copy('models/%s.py' % args.model, str(experiment_dir)) 103 | shutil.copy('models/pointnet_util_psn.py', str(experiment_dir)) 104 | 105 | classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda() 106 | criterion = MODEL.get_loss().cuda() 107 | 108 | 109 | def weights_init(m): 110 | classname = m.__class__.__name__ 111 | if classname.find('Conv2d') != -1: 112 | torch.nn.init.xavier_normal_(m.weight.data) 113 | torch.nn.init.constant_(m.bias.data, 0.0) 114 | elif classname.find('Linear') != -1: 115 | torch.nn.init.xavier_normal_(m.weight.data) 116 | torch.nn.init.constant_(m.bias.data, 0.0) 117 | 118 | try: 119 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 120 | start_epoch = checkpoint['epoch'] 121 | classifier.load_state_dict(checkpoint['model_state_dict']) 122 | log_string('Use pretrain model') 123 | except: 124 | log_string('No existing model, starting training from scratch...') 125 | start_epoch = 0 126 | classifier = classifier.apply(weights_init) 127 | 128 | if args.optimizer == 'Adam': 129 | optimizer = torch.optim.Adam( 130 | 
classifier.parameters(), 131 | lr=args.learning_rate, 132 | betas=(0.9, 0.999), 133 | eps=1e-08, 134 | weight_decay=args.decay_rate 135 | ) 136 | else: 137 | optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9) 138 | 139 | def bn_momentum_adjust(m, momentum): 140 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 141 | m.momentum = momentum 142 | 143 | LEARNING_RATE_CLIP = 1e-7 144 | MOMENTUM_ORIGINAL = 0.1 145 | MOMENTUM_DECCAY = 0.5 146 | MOMENTUM_DECCAY_STEP = args.step_size 147 | 148 | best_acc = 0 149 | global_epoch = 0 150 | best_class_avg_iou = 0 151 | best_inctance_avg_iou = 0 152 | 153 | for epoch in range(start_epoch,args.epoch): 154 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 155 | '''Adjust learning rate and BN momentum''' 156 | lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP) 157 | log_string('Learning rate:%f' % lr) 158 | for param_group in optimizer.param_groups: 159 | param_group['lr'] = lr 160 | mean_correct = [] 161 | momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)) 162 | if momentum < 0.01: 163 | momentum = 0.01 164 | print('BN momentum updated to: %f' % momentum) 165 | classifier = classifier.apply(lambda x: bn_momentum_adjust(x,momentum)) 166 | 167 | '''learning one epoch''' 168 | for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9): 169 | points, label, target = data 170 | points = points.data.numpy() 171 | points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3]) 172 | points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3]) 173 | points = torch.Tensor(points) 174 | points, label, target = points.float().cuda(),label.long().cuda(), target.long().cuda() 175 | points = points.transpose(2, 1) 176 | optimizer.zero_grad() 177 | classifier = classifier.train() 178 | seg_pred, trans_feat = classifier(points, 
to_categorical(label, num_classes), False) 179 | seg_pred = seg_pred.contiguous().view(-1, num_part) 180 | target = target.view(-1, 1)[:, 0] 181 | pred_choice = seg_pred.data.max(1)[1] 182 | correct = pred_choice.eq(target.data).cpu().sum() 183 | mean_correct.append(correct.item() / float(target.size(0))) 184 | loss = criterion(seg_pred, target, trans_feat) 185 | loss.backward() 186 | optimizer.step() 187 | train_instance_acc = np.mean(mean_correct) 188 | log_string('Train accuracy is: %.5f' % train_instance_acc) 189 | 190 | with torch.no_grad(): 191 | test_metrics = {} 192 | total_correct = 0 193 | total_seen = 0 194 | total_seen_class = [0 for _ in range(num_part)] 195 | total_correct_class = [0 for _ in range(num_part)] 196 | shape_ious = {cat: [] for cat in seg_classes.keys()} 197 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 198 | for cat in seg_classes.keys(): 199 | for label in seg_classes[cat]: 200 | seg_label_to_cat[label] = cat 201 | 202 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 203 | 204 | cur_batch_size, NUM_POINT, _ = points.size() 205 | points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda() 206 | points = points.transpose(2, 1) 207 | classifier = classifier.eval() 208 | seg_pred, _ = classifier(points, to_categorical(label, num_classes), False) 209 | cur_pred_val = seg_pred.cpu().data.numpy() 210 | cur_pred_val_logits = cur_pred_val 211 | cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32) 212 | target = target.cpu().data.numpy() 213 | for i in range(cur_batch_size): 214 | cat = seg_label_to_cat[target[i, 0]] 215 | logits = cur_pred_val_logits[i, :, :] 216 | cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0] 217 | correct = np.sum(cur_pred_val == target) 218 | total_correct += correct 219 | total_seen += (cur_batch_size * NUM_POINT) 220 
| 221 | for l in range(num_part): 222 | total_seen_class[l] += np.sum(target == l) 223 | total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l))) 224 | 225 | for i in range(cur_batch_size): 226 | segp = cur_pred_val[i, :] 227 | segl = target[i, :] 228 | cat = seg_label_to_cat[segl[0]] 229 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))] 230 | for l in seg_classes[cat]: 231 | if (np.sum(segl == l) == 0) and ( 232 | np.sum(segp == l) == 0): # part is not present, no prediction as well 233 | part_ious[l - seg_classes[cat][0]] = 1.0 234 | else: 235 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float( 236 | np.sum((segl == l) | (segp == l))) 237 | shape_ious[cat].append(np.mean(part_ious)) 238 | 239 | all_shape_ious = [] 240 | for cat in shape_ious.keys(): 241 | for iou in shape_ious[cat]: 242 | all_shape_ious.append(iou) 243 | shape_ious[cat] = np.mean(shape_ious[cat]) 244 | mean_shape_ious = np.mean(list(shape_ious.values())) 245 | test_metrics['accuracy'] = total_correct / float(total_seen) 246 | test_metrics['class_avg_accuracy'] = np.mean( 247 | np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64)) 248 | for cat in sorted(shape_ious.keys()): 249 | log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat])) 250 | test_metrics['class_avg_iou'] = mean_shape_ious 251 | test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious) 252 | 253 | 254 | log_string('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % ( 255 | epoch+1, test_metrics['accuracy'],test_metrics['class_avg_iou'],test_metrics['inctance_avg_iou'])) 256 | if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou): 257 | logger.info('Save model...') 258 | savepath = str(checkpoints_dir) + '/best_model.pth' 259 | log_string('Saving at %s'% savepath) 260 | state = { 261 | 'epoch': epoch, 262 | 'train_acc': train_instance_acc, 263 | 'test_acc': 
test_metrics['accuracy'], 264 | 'class_avg_iou': 
test_metrics['class_avg_iou'], 265 | 'inctance_avg_iou': test_metrics['inctance_avg_iou'], 266 | 'model_state_dict': classifier.state_dict(), 267 | 'optimizer_state_dict': optimizer.state_dict(), 268 | } 269 | torch.save(state, savepath) 270 | log_string('Saving model....') 271 | 272 | if test_metrics['accuracy'] > best_acc: 273 | best_acc = test_metrics['accuracy'] 274 | if test_metrics['class_avg_iou'] > best_class_avg_iou: 275 | best_class_avg_iou = test_metrics['class_avg_iou'] 276 | if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: 277 | best_inctance_avg_iou = test_metrics['inctance_avg_iou'] 278 | log_string('Best accuracy is: %.5f'%best_acc) 279 | log_string('Best class avg mIOU is: %.5f'%best_class_avg_iou) 280 | log_string('Best instance avg mIOU is: %.5f'%best_inctance_avg_iou) 281 | global_epoch+=1 282 | 283 | if __name__ == '__main__': 284 | args = parse_args() 285 | main(args) 286 | 287 | -------------------------------------------------------------------------------- /train_semseg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Benny 3 | Date: Nov 2019 4 | """ 5 | import argparse 6 | import os 7 | from data_utils.S3DISDataLoader import S3DISDataset 8 | import torch 9 | import datetime 10 | import logging 11 | from pathlib import Path 12 | import sys 13 | import importlib 14 | import shutil 15 | from tqdm import tqdm 16 | import provider 17 | import numpy as np 18 | import time 19 | 20 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 21 | ROOT_DIR = BASE_DIR 22 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 23 | 24 | 25 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 26 | class2label = {cls: i for i,cls in enumerate(classes)} 27 | seg_classes = class2label 28 | seg_label_to_cat = {} 29 | for i,cat in enumerate(seg_classes.keys()): 30 | seg_label_to_cat[i] = cat 31 | 32 | 33 | def 
parse_args(): 34 | 
parser = argparse.ArgumentParser('Model') 35 | parser.add_argument('--model', type=str, default='pointnet2_sem_seg_psn', help='model name [default: pointnet2_sem_seg_psn]') 36 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]') 37 | parser.add_argument('--epoch', default=500, type=int, help='Epoch to run [default: 500]') 38 | parser.add_argument('--learning_rate', default=0.001, type=float, help='Initial learning rate [default: 0.001]') 39 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]') 40 | parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]') 41 | parser.add_argument('--log_dir', type=str, default=None, help='Log path [default: None]') 42 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]') 43 | parser.add_argument('--npoint', type=int, default=4096, help='Point Number [default: 4096]') 44 | parser.add_argument('--step_size', type=int, default=10, help='Decay step for lr decay [default: every 10 epochs]') 45 | parser.add_argument('--lr_decay', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]') 46 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]') 47 | 48 | return parser.parse_args() 49 | 50 | def main(args): 51 | def log_string(str): 52 | logger.info(str) 53 | print(str) 54 | 55 | '''HYPER PARAMETER''' 56 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 57 | 58 | '''CREATE DIR''' 59 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')) 60 | experiment_dir = Path('./log/') 61 | experiment_dir.mkdir(exist_ok=True) 62 | experiment_dir = experiment_dir.joinpath('sem_seg') 63 | experiment_dir.mkdir(exist_ok=True) 64 | if args.log_dir is None: 65 | experiment_dir = experiment_dir.joinpath(timestr) 66 | else: 67 | experiment_dir = experiment_dir.joinpath(args.log_dir) 68 | 
experiment_dir.mkdir(exist_ok=True) 69 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 70 | checkpoints_dir.mkdir(exist_ok=True) 71 | log_dir = experiment_dir.joinpath('logs/') 72 | log_dir.mkdir(exist_ok=True) 73 | 74 | '''LOG''' 75 | args = parse_args() 76 | logger = logging.getLogger("Model") 77 | logger.setLevel(logging.INFO) 78 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 79 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model)) 80 | file_handler.setLevel(logging.INFO) 81 | file_handler.setFormatter(formatter) 82 | logger.addHandler(file_handler) 83 | log_string('PARAMETER ...') 84 | log_string(args) 85 | 86 | root = 'data/stanford_indoor3d/' 87 | NUM_CLASSES = 13 88 | NUM_POINT = args.npoint 89 | BATCH_SIZE = args.batch_size 90 | 91 | print("start loading training data ...") 92 | TRAIN_DATASET = S3DISDataset(split='train', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None) 93 | print("start loading test data ...") 94 | TEST_DATASET = S3DISDataset(split='test', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None) 95 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True, worker_init_fn = lambda x: np.random.seed(x+int(time.time()))) 96 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) 97 | weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda() 98 | 99 | log_string("The number of training data is: %d" % len(TRAIN_DATASET)) 100 | log_string("The number of test data is: %d" % len(TEST_DATASET)) 101 | 102 | '''MODEL LOADING''' 103 | MODEL = importlib.import_module(args.model) 104 | shutil.copy('models/%s.py' % args.model, str(experiment_dir)) 105 | 
shutil.copy('models/pointnet_util_psn.py', str(experiment_dir)) 106 | 107 | classifier = MODEL.get_model(NUM_CLASSES).cuda() 108 | criterion = MODEL.get_loss().cuda() 109 | 110 | def weights_init(m): 111 | classname = m.__class__.__name__ 112 | if classname.find('Conv2d') != -1: 113 | torch.nn.init.xavier_normal_(m.weight.data) 114 | torch.nn.init.constant_(m.bias.data, 0.0) 115 | elif classname.find('Linear') != -1: 116 | torch.nn.init.xavier_normal_(m.weight.data) 117 | torch.nn.init.constant_(m.bias.data, 0.0) 118 | 119 | try: 120 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 121 | start_epoch = checkpoint['epoch'] 122 | classifier.load_state_dict(checkpoint['model_state_dict']) 123 | log_string('Use pretrain model') 124 | except: 125 | log_string('No existing model, starting training from scratch...') 126 | start_epoch = 0 127 | classifier = classifier.apply(weights_init) 128 | 129 | if args.optimizer == 'Adam': 130 | optimizer = torch.optim.Adam( 131 | classifier.parameters(), 132 | lr=args.learning_rate, 133 | betas=(0.9, 0.999), 134 | eps=1e-08, 135 | weight_decay=args.decay_rate 136 | ) 137 | else: 138 | optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9) 139 | 140 | def bn_momentum_adjust(m, momentum): 141 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d): 142 | m.momentum = momentum 143 | 144 | LEARNING_RATE_CLIP = 1e-5 145 | MOMENTUM_ORIGINAL = 0.1 146 | MOMENTUM_DECCAY = 0.5 147 | MOMENTUM_DECCAY_STEP = args.step_size 148 | 149 | global_epoch = 0 150 | best_iou = 0 151 | 152 | for epoch in range(start_epoch,args.epoch): 153 | '''Train on chopped scenes''' 154 | log_string('**** Epoch %d (%d/%s) ****' % (global_epoch + 1, epoch + 1, args.epoch)) 155 | lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP) 156 | log_string('Learning rate:%f' % lr) 157 | for param_group in optimizer.param_groups: 158 | 
param_group['lr'] = lr 159 | momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)) 160 | if momentum < 0.01: 161 | momentum = 0.01 162 | print('BN momentum updated to: %f' % momentum) 163 | classifier = classifier.apply(lambda x: bn_momentum_adjust(x,momentum)) 164 | num_batches = len(trainDataLoader) 165 | total_correct = 0 166 | total_seen = 0 167 | loss_sum = 0 168 | for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9): 169 | points, target = data 170 | points = points.data.numpy() 171 | points[:,:, :3] = provider.rotate_point_cloud_z(points[:,:, :3]) 172 | points = torch.Tensor(points) 173 | points, target = points.float().cuda(),target.long().cuda() 174 | points = points.transpose(2, 1) 175 | optimizer.zero_grad() 176 | classifier = classifier.train() 177 | seg_pred, trans_feat = classifier(points, False) 178 | seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES) 179 | batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy() 180 | target = target.view(-1, 1)[:, 0] 181 | loss = criterion(seg_pred, target, trans_feat, weights) 182 | loss.backward() 183 | optimizer.step() 184 | pred_choice = seg_pred.cpu().data.max(1)[1].numpy() 185 | correct = np.sum(pred_choice == batch_label) 186 | total_correct += correct 187 | total_seen += (BATCH_SIZE * NUM_POINT) 188 | loss_sum += loss.item() 189 | log_string('Training mean loss: %f' % (loss_sum / num_batches)) 190 | log_string('Training accuracy: %f' % (total_correct / float(total_seen))) 191 | 192 | if epoch % 5 == 0: 193 | logger.info('Save model...') 194 | savepath = str(checkpoints_dir) + '/model.pth' 195 | log_string('Saving at %s' % savepath) 196 | state = { 197 | 'epoch': epoch, 198 | 'model_state_dict': classifier.state_dict(), 199 | 'optimizer_state_dict': optimizer.state_dict(), 200 | } 201 | torch.save(state, savepath) 202 | log_string('Saving model....') 203 | 204 | '''Evaluate on chopped scenes''' 205 | with torch.no_grad(): 206 | 
num_batches = 
len(testDataLoader) 207 | total_correct = 0 208 | total_seen = 0 209 | loss_sum = 0 210 | labelweights = np.zeros(NUM_CLASSES) 211 | total_seen_class = [0 for _ in range(NUM_CLASSES)] 212 | total_correct_class = [0 for _ in range(NUM_CLASSES)] 213 | total_iou_deno_class = [0 for _ in range(NUM_CLASSES)] 214 | log_string('---- EPOCH %03d EVALUATION ----' % (global_epoch + 1)) 215 | for i, data in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9): 216 | points, target = data 217 | points = points.data.numpy() 218 | points = torch.Tensor(points) 219 | points, target = points.float().cuda(), target.long().cuda() 220 | points = points.transpose(2, 1) 221 | classifier = classifier.eval() 222 | seg_pred, trans_feat = classifier(points, False) 223 | pred_val = seg_pred.contiguous().cpu().data.numpy() 224 | seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES) 225 | batch_label = target.cpu().data.numpy() 226 | target = target.view(-1, 1)[:, 0] 227 | loss = criterion(seg_pred, target, trans_feat, weights) 228 | loss_sum += loss.item() 229 | pred_val = np.argmax(pred_val, 2) 230 | correct = np.sum((pred_val == batch_label)) 231 | total_correct += correct 232 | total_seen += (BATCH_SIZE * NUM_POINT) 233 | tmp, _ = np.histogram(batch_label, range(NUM_CLASSES + 1)) 234 | labelweights += tmp 235 | for l in range(NUM_CLASSES): 236 | total_seen_class[l] += np.sum((batch_label == l) ) 237 | total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l) ) 238 | total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)) ) 239 | labelweights = labelweights.astype(np.float32) / np.sum(labelweights.astype(np.float32)) 240 | mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6)) 241 | log_string('eval mean loss: %f' % (loss_sum / float(num_batches))) 242 | log_string('eval point avg class IoU: %f' % (mIoU)) 243 | log_string('eval point accuracy: %f' % (total_correct / 
float(total_seen))) 244 | 
log_string('eval point avg class acc: %f' % ( 245 | np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6)))) 246 | iou_per_class_str = '------- IoU --------\n' 247 | for l in range(NUM_CLASSES): 248 | iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % ( 249 | seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l], 250 | total_correct_class[l] / (float(total_iou_deno_class[l]) + 1e-6)) 251 | 252 | log_string(iou_per_class_str) 253 | log_string('Eval mean loss: %f' % (loss_sum / num_batches)) 254 | log_string('Eval accuracy: %f' % (total_correct / float(total_seen))) 255 | if mIoU >= best_iou: 256 | best_iou = mIoU 257 | logger.info('Save model...') 258 | savepath = str(checkpoints_dir) + '/best_model.pth' 259 | log_string('Saving at %s' % savepath) 260 | state = { 261 | 'epoch': epoch, 262 | 'class_avg_iou': mIoU, 263 | 'model_state_dict': classifier.state_dict(), 264 | 'optimizer_state_dict': optimizer.state_dict(), 265 | } 266 | torch.save(state, savepath) 267 | log_string('Saving model....') 268 | log_string('Best mIoU: %f' % best_iou) 269 | global_epoch += 1 270 | 271 | 272 | if __name__ == '__main__': 273 | args = parse_args() 274 | main(args) 275 | 276 | -------------------------------------------------------------------------------- /visualizer/build.sh: -------------------------------------------------------------------------------- 1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 2 | -------------------------------------------------------------------------------- /visualizer/eulerangles.py: -------------------------------------------------------------------------------- 1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 2 | # vi: set ft=python sts=4 ts=4 sw=4 et: 3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 4 | # 5 | # See COPYING file distributed along 
with the 
NiBabel package for the 6 | # copyright and license terms. 7 | # 8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 9 | ''' Module implementing Euler angle rotations and their conversions 10 | See: 11 | * http://en.wikipedia.org/wiki/Rotation_matrix 12 | * http://en.wikipedia.org/wiki/Euler_angles 13 | * http://mathworld.wolfram.com/EulerAngles.html 14 | See also: *Representing Attitude with Euler Angles and Quaternions: A 15 | Reference* (2006) by James Diebel. A cached PDF link last found here: 16 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134 17 | Euler's rotation theorem tells us that any rotation in 3D can be 18 | described by 3 angles. Let's call the 3 angles the *Euler angle vector* 19 | and call the angles in the vector :math:`alpha`, :math:`beta` and 20 | :math:`gamma`. The vector is [ :math:`alpha`, 21 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the 22 | parameters specifies the order in which the rotations occur (so the 23 | rotation corresponding to :math:`alpha` is applied first). 24 | In order to specify the meaning of an *Euler angle vector* we need to 25 | specify the axes around which each of the rotations corresponding to 26 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur. 27 | There are therefore three axes for the rotations :math:`alpha`, 28 | :math:`beta` and :math:`gamma`; let's call them :math:`i`, :math:`j`, 29 | :math:`k`. 30 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3 31 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3 32 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the 33 | whole rotation expressed by the Euler angle vector [ :math:`alpha`, 34 | :math:`beta`,
:math:`gamma` ], `R` is given by:: 35 | R = np.dot(G, np.dot(B, A)) 36 | See http://mathworld.wolfram.com/EulerAngles.html 37 | The order :math:`G B A` expresses the fact that the rotations are 38 | performed in the order of the vector (:math:`alpha` around axis `i` = 39 | `A` first). 40 | To convert a given Euler angle vector to a meaningful rotation, and a 41 | rotation matrix, we need to define: 42 | * the axes `i`, `j`, `k` 43 | * whether a rotation matrix should be applied on the left of a vector to 44 | be transformed (vectors are column vectors) or on the right (vectors 45 | are row vectors). 46 | * whether the rotations move the axes as they are applied (intrinsic 47 | rotations) - compared to the situation where the axes stay fixed and the 48 | vectors move within the axis frame (extrinsic) 49 | * the handedness of the coordinate system 50 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities 51 | We are using the following conventions: 52 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus 53 | an Euler angle vector [ :math:`alpha`, :math:`beta`, :math:`gamma` ] 54 | in our convention implies an :math:`alpha` radian rotation around the 55 | `z` axis, followed by a :math:`beta` rotation around the `y` axis, 56 | followed by a :math:`gamma` rotation around the `x` axis. 57 | * the rotation matrix applies on the left, to column vectors on the 58 | right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix 59 | with N column vectors, the transformed vector set `vdash` is given by 60 | ``vdash = np.dot(R, v)``. 61 | * extrinsic rotations - the axes are fixed, and do not move with the 62 | rotations. 63 | * a right-handed coordinate system 64 | The convention of rotation around ``z``, followed by rotation around 65 | ``y``, followed by rotation around ``x``, is known (confusingly) as 66 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles.
67 | '''
68 |
69 | import math
70 |
71 | import sys
72 | if sys.version_info >= (3,0):
73 |     from functools import reduce
74 |
75 | import numpy as np
76 |
77 |
78 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0
79 |
80 |
81 | def euler2mat(z=0, y=0, x=0):
82 |     ''' Return matrix for rotations around z, y and x axes
83 |     Uses the z, then y, then x convention above
84 |     Parameters
85 |     ----------
86 |     z : scalar
87 |        Rotation angle in radians around z-axis (performed first)
88 |     y : scalar
89 |        Rotation angle in radians around y-axis
90 |     x : scalar
91 |        Rotation angle in radians around x-axis (performed last)
92 |     Returns
93 |     -------
94 |     M : array shape (3,3)
95 |        Rotation matrix giving same rotation as for given angles
96 |     Examples
97 |     --------
98 |     >>> zrot = 1.3 # radians
99 |     >>> yrot = -0.1
100 |     >>> xrot = 0.2
101 |     >>> M = euler2mat(zrot, yrot, xrot)
102 |     >>> M.shape == (3, 3)
103 |     True
104 |     The output rotation matrix is equal to the composition of the
105 |     individual rotations
106 |     >>> M1 = euler2mat(zrot)
107 |     >>> M2 = euler2mat(0, yrot)
108 |     >>> M3 = euler2mat(0, 0, xrot)
109 |     >>> composed_M = np.dot(M3, np.dot(M2, M1))
110 |     >>> np.allclose(M, composed_M)
111 |     True
112 |     You can specify rotations by named arguments
113 |     >>> np.all(M3 == euler2mat(x=xrot))
114 |     True
115 |     When applying M to a vector, the vector should be a column vector to the
116 |     right of M. If the right hand side is a 2D array rather than a
117 |     vector, then each column of the 2D array represents a vector.
118 |     >>> vec = np.array([1, 0, 0]).reshape((3,1))
119 |     >>> v2 = np.dot(M, vec)
120 |     >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array
121 |     >>> vecs2 = np.dot(M, vecs)
122 |     Rotations are counter-clockwise.
123 |     >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))
124 |     >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])
125 |     True
126 |     >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))
127 |     >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])
128 |     True
129 |     >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))
130 |     >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])
131 |     True
132 |     Notes
133 |     -----
134 |     The direction of rotation is given by the right-hand rule (orient
135 |     the thumb of the right hand along the axis around which the rotation
136 |     occurs, with the end of the thumb at the positive end of the axis;
137 |     curl your fingers; the direction your fingers curl is the direction
138 |     of rotation). Therefore, the rotations are counterclockwise if
139 |     looking along the axis of rotation from positive to negative.
140 |     '''
141 |     Ms = []
142 |     if z:
143 |         cosz = math.cos(z)
144 |         sinz = math.sin(z)
145 |         Ms.append(np.array(
146 |                 [[cosz, -sinz, 0],
147 |                  [sinz, cosz, 0],
148 |                  [0, 0, 1]]))
149 |     if y:
150 |         cosy = math.cos(y)
151 |         siny = math.sin(y)
152 |         Ms.append(np.array(
153 |                 [[cosy, 0, siny],
154 |                  [0, 1, 0],
155 |                  [-siny, 0, cosy]]))
156 |     if x:
157 |         cosx = math.cos(x)
158 |         sinx = math.sin(x)
159 |         Ms.append(np.array(
160 |                 [[1, 0, 0],
161 |                  [0, cosx, -sinx],
162 |                  [0, sinx, cosx]]))
163 |     if Ms:
164 |         return reduce(np.dot, Ms[::-1])
165 |     return np.eye(3)
166 |
167 |
168 | def mat2euler(M, cy_thresh=None):
169 |     ''' Discover Euler angle vector from 3x3 matrix
170 |     Uses the conventions above.
171 |     Parameters
172 |     ----------
173 |     M : array-like, shape (3,3)
174 |     cy_thresh : None or scalar, optional
175 |        threshold below which to give up on straightforward arctan for
176 |        estimating x rotation. If None (default), estimate from
177 |        precision of input.
178 |     Returns
179 |     -------
180 |     z : scalar
181 |     y : scalar
182 |     x : scalar
183 |        Rotations in radians around z, y, x axes, respectively
184 |     Notes
185 |     -----
186 |     If there were no numerical error, the routine could be derived using
187 |     the Sympy expression for the z then y then x rotation matrix, which is::
188 |       [                       cos(y)*cos(z),                       -cos(y)*sin(z),         sin(y)],
189 |       [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)],
190 |       [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z),  cos(x)*cos(y)]
191 |     with the obvious derivations for z, y, and x::
192 |        z = atan2(-r12, r11)
193 |        y = asin(r13)
194 |        x = atan2(-r23, r33)
195 |     Problems arise when cos(y) is close to zero, because both of::
196 |        z = atan2(cos(y)*sin(z), cos(y)*cos(z))
197 |        x = atan2(cos(y)*sin(x), cos(x)*cos(y))
198 |     will be close to atan2(0, 0), and highly unstable.
199 |     The ``cy`` fix for numerical instability below is from: *Graphics
200 |     Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN:
201 |     0123361559. Specifically it comes from EulerAngles.c by Ken
202 |     Shoemake, and deals with the case where cos(y) is close to zero:
203 |     See: http://www.graphicsgems.org/
204 |     The code appears to be licensed (from the website) as "can be used
205 |     without restrictions".
206 |     '''
207 |     M = np.asarray(M)
208 |     if cy_thresh is None:
209 |         try:
210 |             cy_thresh = np.finfo(M.dtype).eps * 4
211 |         except ValueError:
212 |             cy_thresh = _FLOAT_EPS_4
213 |     r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat
214 |     # cy: sqrt((cos(x)*cos(y))**2 + (cos(y)*sin(x))**2) == abs(cos(y))
215 |     cy = math.sqrt(r33*r33 + r23*r23)
216 |     if cy > cy_thresh: # cos(y) not close to zero, standard form
217 |         z = math.atan2(-r12,  r11) # atan2(cos(y)*sin(z), cos(y)*cos(z))
218 |         y = math.atan2(r13,  cy) # atan2(sin(y), cy)
219 |         x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y))
220 |     else: # cos(y) (close to) zero, so x -> 0.0 (see above)
221 |         # so r21 -> sin(z), r22 -> cos(z), and z absorbs any x rotation
222 |         z = math.atan2(r21,  r22)
223 |         y = math.atan2(r13,  cy) # atan2(sin(y), cy)
224 |         x = 0.0
225 |     return z, y, x
226 |
227 |
228 | def euler2quat(z=0, y=0, x=0):
229 |     ''' Return quaternion corresponding to these Euler angles
230 |     Uses the z, then y, then x convention above
231 |     Parameters
232 |     ----------
233 |     z : scalar
234 |        Rotation angle in radians around z-axis (performed first)
235 |     y : scalar
236 |        Rotation angle in radians around y-axis
237 |     x : scalar
238 |        Rotation angle in radians around x-axis (performed last)
239 |     Returns
240 |     -------
241 |     quat : array shape (4,)
242 |        Quaternion in w, x, y, z (real, then vector) format
243 |     Notes
244 |     -----
245 |     We can derive this formula in Sympy using:
246 |     1. Formula giving quaternion corresponding to rotation of theta radians
247 |        about arbitrary axis:
248 |        http://mathworld.wolfram.com/EulerParameters.html
249 |     2. Generated formulae from 1.) for quaternions corresponding to
250 |        theta radians rotations about ``x, y, z`` axes
251 |     3. Apply quaternion multiplication formula -
252 |        http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to
253 |        formulae from 2.) to give formula for combined rotations.
254 |     '''
255 |     z = z/2.0
256 |     y = y/2.0
257 |     x = x/2.0
258 |     cz = math.cos(z)
259 |     sz = math.sin(z)
260 |     cy = math.cos(y)
261 |     sy = math.sin(y)
262 |     cx = math.cos(x)
263 |     sx = math.sin(x)
264 |     return np.array([
265 |              cx*cy*cz - sx*sy*sz,
266 |              cx*sy*sz + cy*cz*sx,
267 |              cx*cz*sy - sx*cy*sz,
268 |              cx*cy*sz + sx*cz*sy])
269 |
270 |
271 | def quat2euler(q):
272 |     ''' Return Euler angles corresponding to quaternion `q`
273 |     Parameters
274 |     ----------
275 |     q : 4 element sequence
276 |        w, x, y, z of quaternion
277 |     Returns
278 |     -------
279 |     z : scalar
280 |        Rotation angle in radians around z-axis (performed first)
281 |     y : scalar
282 |        Rotation angle in radians around y-axis
283 |     x : scalar
284 |        Rotation angle in radians around x-axis (performed last)
285 |     Notes
286 |     -----
287 |     It's possible to reduce the amount of calculation a little, by
288 |     combining parts of the ``quat2mat`` and ``mat2euler`` functions, but
289 |     the reduction in computation is small, and the code repetition is
290 |     large.
291 |     '''
292 |     # delayed import: needs the external nibabel package (not bundled here)
293 |     import nibabel.quaternions as nq
294 |     return mat2euler(nq.quat2mat(q))
295 |
296 |
297 | def euler2angle_axis(z=0, y=0, x=0):
298 |     ''' Return angle, axis corresponding to these Euler angles
299 |     Uses the z, then y, then x convention above
300 |     Parameters
301 |     ----------
302 |     z : scalar
303 |        Rotation angle in radians around z-axis (performed first)
304 |     y : scalar
305 |        Rotation angle in radians around y-axis
306 |     x : scalar
307 |        Rotation angle in radians around x-axis (performed last)
308 |     Returns
309 |     -------
310 |     theta : scalar
311 |        angle of rotation
312 |     vector : array shape (3,)
313 |        axis around which rotation occurs
314 |     Examples
315 |     --------
316 |     >>> theta, vec = euler2angle_axis(0, 1.5, 0)
317 |     >>> print(theta)
318 |     1.5
319 |     >>> np.allclose(vec, [0, 1, 0])
320 |     True
321 |     '''
322 |     # delayed import: needs the external nibabel package (not bundled here)
323 |     import nibabel.quaternions as nq
324 |     return nq.quat2angle_axis(euler2quat(z, y, x))
325 |
326 |
327 | def angle_axis2euler(theta, vector, is_normalized=False):
328 |     ''' Convert angle, axis pair to Euler angles
329 |     Parameters
330 |     ----------
331 |     theta : scalar
332 |        angle of rotation
333 |     vector : 3 element sequence
334 |        vector specifying axis for rotation.
335 |     is_normalized : bool, optional
336 |        True if vector is already normalized (has norm of 1). Default
337 |        False
338 |     Returns
339 |     -------
340 |     z : scalar
341 |     y : scalar
342 |     x : scalar
343 |        Rotations in radians around z, y, x axes, respectively
344 |     Examples
345 |     --------
346 |     >>> z, y, x = angle_axis2euler(0, [1, 0, 0])
347 |     >>> np.allclose((z, y, x), 0)
348 |     True
349 |     Notes
350 |     -----
351 |     It's possible to reduce the amount of calculation a little, by
352 |     combining parts of the ``angle_axis2mat`` and ``mat2euler``
353 |     functions, but the reduction in computation is small, and the code
354 |     repetition is large.
355 |     '''
356 |     # delayed import: needs the external nibabel package (not bundled here)
357 |     import nibabel.quaternions as nq
358 |     M = nq.angle_axis2mat(theta, vector, is_normalized)
359 |     return mat2euler(M)
--------------------------------------------------------------------------------
/visualizer/pc_utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for processing point clouds.
2 | Author: Charles R. Qi, Hao Su
3 | Date: November 2016
4 | """
5 |
6 | import os
7 | import sys
8 |
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | sys.path.append(BASE_DIR)
11 |
12 | # Draw point cloud
13 | from visualizer.eulerangles import euler2mat
14 |
15 | # Point cloud IO
16 | import numpy as np
17 | from visualizer.plyfile import PlyData, PlyElement
18 |
19 | # ----------------------------------------
20 | # Point Cloud/Volume Conversions
21 | # ----------------------------------------
22 |
23 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True):
24 |     """ Input is BxNx3 batch of point cloud
25 |         Output is Bx(vsize^3), or Bxvsizexvsizexvsizex1 if flatten is False
26 |     """
27 |     vol_list = []
28 |     for b in range(point_clouds.shape[0]):
29 |         vol = point_cloud_to_volume(np.squeeze(point_clouds[b, :, :]), vsize, radius)
30 |         if flatten:
31 |             vol_list.append(vol.flatten())
32 |         else:
33 |             vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0))
34 |     if flatten:
35 |         return np.vstack(vol_list)
36 |     else:
37 |         return np.concatenate(vol_list, 0)
38 |
39 |
40 | def point_cloud_to_volume(points, vsize, radius=1.0):
41 |     """ input is Nx3 points.
42 |         output is vsize*vsize*vsize
43 |         assumes points are in range [-radius, radius]; voxel indices are clipped to the grid
44 |     """
45 |     vol = np.zeros((vsize, vsize, vsize))
46 |     voxel = 2 * radius / float(vsize)
47 |     locations = (points + radius) / voxel
48 |     locations = np.clip(locations.astype(int), 0, vsize - 1)  # a point at +radius maps to index vsize otherwise
49 |     vol[locations[:, 0], locations[:, 1], locations[:, 2]] = 1.0
50 |     return vol
51 |
52 |
53 | # a = np.zeros((16,1024,3))
54 | # print(point_cloud_to_volume_batch(a, 12, 1.0, False).shape)
55 |
56 | def volume_to_point_cloud(vol):
57 |     """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize
58 |         return Nx3 numpy array.
59 |     """
60 |     vsize = vol.shape[0]
61 |     assert (vol.shape[1] == vsize and vol.shape[2] == vsize)
62 |     points = []
63 |     for a in range(vsize):
64 |         for b in range(vsize):
65 |             for c in range(vsize):
66 |                 if vol[a, b, c] == 1:
67 |                     points.append(np.array([a, b, c]))
68 |     if len(points) == 0:
69 |         return np.zeros((0, 3))
70 |     points = np.vstack(points)
71 |     return points
72 |
73 |
74 | # ----------------------------------------
75 | # Point cloud IO
76 | # ----------------------------------------
77 |
78 | def read_ply(filename):
79 |     """ read XYZ point cloud from filename PLY file """
80 |     plydata = PlyData.read(filename)
81 |     pc = plydata['vertex'].data
82 |     pc_array = np.array([[x, y, z] for x, y, z in pc])
83 |     return pc_array
84 |
85 |
86 | def write_ply(points, filename, text=True):
87 |     """ input: Nx3, write points to filename as PLY format. """
88 |     points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
89 |     vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
90 |     el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
91 |     PlyData([el], text=text).write(filename)
92 |
93 |
94 | # ----------------------------------------
95 | # Simple Point cloud and Volume Renderers
96 | # ----------------------------------------
97 |
98 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
99 |                      xrot=0, yrot=0, zrot=0, switch_xyz=[0, 1, 2], normalize=True):
100 |     """ Render point cloud to a grayscale image.
101 |         Input:
102 |             input_points: Nx3 numpy array (+y is up direction)
103 |         Output:
104 |             gray image as numpy array of size canvasSize x canvasSize
105 |     """
106 |     image = np.zeros((canvasSize, canvasSize))
107 |     if input_points is None or input_points.shape[0] == 0:
108 |         return image
109 |
110 |     points = input_points[:, switch_xyz]
111 |     M = euler2mat(zrot, yrot, xrot)
112 |     points = (np.dot(M, points.transpose())).transpose()
113 |
114 |     # Normalize the point cloud
115 |     # We normalize scale to fit points in a unit sphere
116 |     if normalize:
117 |         centroid = np.mean(points, axis=0)
118 |         points -= centroid
119 |         furthest_distance = np.max(np.sqrt(np.sum(abs(points) ** 2, axis=-1)))
120 |         points /= furthest_distance
121 |
122 |     # Pre-compute the Gaussian disk
123 |     radius = (diameter - 1) / 2.0
124 |     disk = np.zeros((diameter, diameter))
125 |     for i in range(diameter):
126 |         for j in range(diameter):
127 |             if (i - radius) * (i - radius) + (j - radius) * (j - radius) <= radius * radius:
128 |                 disk[i, j] = np.exp((-(i - radius) ** 2 - (j - radius) ** 2) / (radius ** 2))
129 |     mask = np.argwhere(disk > 0)
130 |     dx = mask[:, 0]
131 |     dy = mask[:, 1]
132 |     dv = disk[disk > 0]
133 |
134 |     # Order points by z-buffer
135 |     zorder = np.argsort(points[:, 2])
136 |     points = points[zorder, :]
137 |     points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2])))
138 |     max_depth = np.max(points[:, 2])
139 |
140 |     for i in range(points.shape[0]):
141 |         j = points.shape[0] - i - 1
142 |         x = points[j, 0]
143 |         y = points[j, 1]
144 |         xc = canvasSize / 2 + (x * space)
145 |         yc = canvasSize / 2 + (y * space)
146 |         xc = int(np.round(xc))
147 |         yc = int(np.round(yc))
148 |
149 |         px = dx + xc
150 |         py = dy + yc
151 |
152 |         image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3
153 |
154 |     image = image / np.max(image)
155 |     return image
156 |
157 |
158 | def point_cloud_three_views(points):
159 |     """ input points Nx3 numpy array (+y is up direction).
160 |         return a numpy array gray image of size 500x1500. """
161 |     # +y is up direction
162 |     # xrot is azimuth
163 |     # yrot is in-plane
164 |     # zrot is elevation
165 |     img1 = draw_point_cloud(points, zrot=110 / 180.0 * np.pi, xrot=45 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
166 |     img2 = draw_point_cloud(points, zrot=70 / 180.0 * np.pi, xrot=135 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
167 |     img3 = draw_point_cloud(points, zrot=180.0 / 180.0 * np.pi, xrot=90 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
168 |     image_large = np.concatenate([img1, img2, img3], 1)
169 |     return image_large
170 |
171 |
172 | from PIL import Image
173 |
174 |
175 | def point_cloud_three_views_demo():
176 |     """ Demo for draw_point_cloud function """
177 |     DATA_PATH = '../data/ShapeNet/'
178 |     train_data, _, _, _, _, _ = load_data(DATA_PATH,classification=False)
179 |     points = train_data[1]
180 |     im_array = point_cloud_three_views(points)
181 |     img = Image.fromarray(np.uint8(im_array * 255.0))
182 |     img.save('example.jpg')
183 |
184 |
185 | if __name__ == "__main__":
186 |     from data_utils.ShapeNetDataLoader import load_data
187 |     point_cloud_three_views_demo()
188 |
189 | import matplotlib.pyplot as plt
190 |
191 |
192 | def pyplot_draw_point_cloud(points, output_filename):
193 |     """ points is a Nx3 numpy array """
194 |     fig = plt.figure()
195 |     ax = fig.add_subplot(111, projection='3d')
196 |     ax.scatter(points[:, 0], points[:, 1], points[:, 2])
197 |     ax.set_xlabel('x')
198 |     ax.set_ylabel('y')
199 |     ax.set_zlabel('z')
200 |     fig.savefig(output_filename)
201 |
202 |
203 | def pyplot_draw_volume(vol, output_filename):
204 |     """ vol is of size vsize*vsize*vsize
205 |         output an image to output_filename
206 |     """
207 |     points = volume_to_point_cloud(vol)
208 |     pyplot_draw_point_cloud(points, output_filename)
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
/visualizer/render_balls_so.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | using namespace std;
6 |
7 | struct PointInfo{
8 |     int x,y,z;
9 |     float r,g,b;
10 | };
11 |
12 | extern "C"{
13 |
14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
15 |     r=max(r,1);
16 |     vector depth(h*w,-2100000000);
17 |     vector pattern;
18 |     for (int dx=-r;dx<=r;dx++)
19 |         for (int dy=-r;dy<=r;dy++)
20 |             if (dx*dx+dy*dy=h || y2<0 || y2>=w) && depth[x2*w+y2] 0:
95 |             show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=0))
96 |             if magnifyBlue >= 2:
97 |                 show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=0))
98 |             show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=1))
99 |             if magnifyBlue >= 2:
100 |                 show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=1))
101 |         if showrot:
102 |             cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)), (30, showsz - 30), 0, 0.5,
103 |                         (0, 0, 255))
104 |             cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)), (30, showsz - 50), 0, 0.5,
105 |                         (0, 0, 255))
106 |             cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0, 0.5, (0, 0, 255))
107 |
108 |     changed = True
109 |     while True:
110 |         if changed:
111 |             render()
112 |             changed = False
113 |         cv2.imshow('show3d', show)
114 |         if waittime == 0:
115 |             cmd = cv2.waitKey(10) % 256
116 |         else:
117 |             cmd = cv2.waitKey(waittime) % 256
118 |         if cmd == ord('q'):
119 |             break
120 |         elif cmd == ord('Q'):
121 |             sys.exit(0)
122 |
123 |         if cmd == ord('t') or cmd == ord('p'):
124 |             if cmd == ord('t'):
125 |                 if c_gt is None:
126 |                     c0 = np.zeros((len(xyz),), dtype='float32') + 255
127 |                     c1 = np.zeros((len(xyz),), dtype='float32') + 255
128 |                     c2 = np.zeros((len(xyz),), dtype='float32') + 255
129 |                 else:
130 |                     c0 = c_gt[:, 0]
131 |                     c1 = c_gt[:, 1]
132 |                     c2 = c_gt[:, 2]
133 |             else:
134 |                 if c_pred is None:
135 |                     c0 = np.zeros((len(xyz),), dtype='float32') + 255
136 |                     c1 = np.zeros((len(xyz),), dtype='float32') + 255
137 |                     c2 = np.zeros((len(xyz),), dtype='float32') + 255
138 |                 else:
139 |                     c0 = c_pred[:, 0]
140 |                     c1 = c_pred[:, 1]
141 |                     c2 = c_pred[:, 2]
142 |             if normalizecolor:
143 |                 c0 /= (c0.max() + 1e-14) / 255.0
144 |                 c1 /= (c1.max() + 1e-14) / 255.0
145 |                 c2 /= (c2.max() + 1e-14) / 255.0
146 |             c0 = np.require(c0, 'float32', 'C')
147 |             c1 = np.require(c1, 'float32', 'C')
148 |             c2 = np.require(c2, 'float32', 'C')
149 |             changed = True
150 |
151 |         if cmd == ord('n'):
152 |             zoom *= 1.1
153 |             changed = True
154 |         elif cmd == ord('m'):
155 |             zoom /= 1.1
156 |             changed = True
157 |         elif cmd == ord('r'):
158 |             zoom = 1.0
159 |             changed = True
160 |         elif cmd == ord('s'):
161 |             cv2.imwrite('show3d.png', show)
162 |         if waittime != 0:
163 |             break
164 |     return cmd
165 |
166 |
167 | if __name__ == '__main__':
168 |     import os
169 |     import numpy as np
170 |     import argparse
171 |
172 |     parser = argparse.ArgumentParser()
173 |     parser.add_argument('--dataset', type=str, default='../data/shapenet', help='dataset path')
174 |     parser.add_argument('--category', type=str, default='Airplane', help='select category')
175 |     parser.add_argument('--npoints', type=int, default=2500, help='resample points number')
176 |     parser.add_argument('--ballradius', type=int, default=10, help='ballradius')
177 |     opt = parser.parse_args()
178 |     '''
179 |     Airplane 02691156
180 |     Bag 02773838
181 |     Cap 02954340
182 |     Car 02958343
183 |     Chair 03001627
184 |     Earphone 03261776
185 |     Guitar 03467517
186 |     Knife 03624134
187 |     Lamp 03636649
188 |     Laptop 03642806
189 |     Motorbike 03790512
190 |     Mug 03797390
191 |     Pistol 03948459
192 |     Rocket 04099429
193 |     Skateboard 04225987
194 |     Table 04379243'''
195 |
196 |     cmap = np.array([[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
197 |                      [3.12493437e-02, 1.00000000e+00, 1.31250131e-06],
198 |                      [0.00000000e+00, 6.25019688e-02, 1.00000000e+00],
199 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
200 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
201 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
202 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
203 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
204 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
205 |                      [1.00000000e+00, 0.00000000e+00, 9.37500000e-02]])
206 |     BASE_DIR = os.path.dirname(os.path.abspath(__file__))
207 |     ROOT_DIR = os.path.dirname(BASE_DIR)
208 |     sys.path.append(BASE_DIR)
209 |     sys.path.append(os.path.join(ROOT_DIR, 'data_utils'))
210 |
211 |     from ShapeNetDataLoader import PartNormalDataset
212 |     root = '../data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'
213 |     dataset = PartNormalDataset(root = root, npoints=2048, split='test', normal_channel=False)
214 |     idx = np.random.randint(0, len(dataset))
215 |     data = dataset[idx]
216 |     point_set, _, seg = data
217 |     choice = np.random.choice(point_set.shape[0], opt.npoints, replace=True)
218 |     point_set, seg = point_set[choice, :], seg[choice]
219 |     seg = seg - seg.min()
220 |     gt = cmap[seg, :]
221 |     pred = cmap[seg, :]
222 |     showpoints(point_set, gt, c_pred=pred, waittime=0, showrot=False, magnifyBlue=0, freezerot=False,
223 |                background=(255, 255, 255), normalizecolor=True, ballradius=opt.ballradius)
224 |
225 |
226 |
--------------------------------------------------------------------------------
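The Euler-angle conventions documented in `visualizer/eulerangles.py` can be checked with a small NumPy-only sketch. The convention is to rotate about z first, then y, then x, about fixed (extrinsic) axes, with matrices applied on the left of column vectors. The helper names `rot_z`, `rot_y`, `rot_x`, `zyx_to_matrix` and `matrix_to_zyx` are hypothetical and are not part of the repository. They are assumed to mirror what `euler2mat` and `mat2euler` compute, including the fallback that sets x to 0 when cos(y) is close to zero.

```python
import math

import numpy as np


# Hypothetical helpers (not part of visualizer/eulerangles.py) that follow the
# documented convention: rotate about z first, then y, then x, with fixed
# (extrinsic) axes and matrices applied on the left of column vectors.
def rot_z(a):
    c, s = math.cos(a), math.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rot_y(a):
    c, s = math.cos(a), math.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_x(a):
    c, s = math.cos(a), math.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def zyx_to_matrix(z, y, x):
    # z is applied first and x last, so x's matrix ends up leftmost: R = X @ Y @ Z.
    return rot_x(x) @ rot_y(y) @ rot_z(z)


def matrix_to_zyx(R, eps=4 * np.finfo(float).eps):
    (r11, r12, r13), (r21, r22, r23), (r31, r32, r33) = R
    cy = math.hypot(r33, r23)  # sqrt((cos(x)cos(y))**2 + (sin(x)cos(y))**2) = |cos(y)|
    y = math.atan2(r13, cy)    # r13 = sin(y)
    if cy > eps:
        z = math.atan2(-r12, r11)  # atan2(cos(y)sin(z), cos(y)cos(z))
        x = math.atan2(-r23, r33)  # atan2(cos(y)sin(x), cos(x)cos(y))
    else:
        # Gimbal lock: only a combination of z and x is determined,
        # so set x = 0 and let z carry the whole in-plane rotation.
        z = math.atan2(r21, r22)
        x = 0.0
    return z, y, x


if __name__ == "__main__":
    angles = (1.3, -0.1, 0.2)
    R = zyx_to_matrix(*angles)
    print(np.allclose(matrix_to_zyx(R), angles))  # True
```

Away from y = ±π/2 the round trip returns the original angles. At gimbal lock the recovered angles differ from the inputs, but they rebuild the same matrix. This is the same trade-off the `cy_thresh` branch of `mat2euler` makes.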