├── .gitattributes
├── .gitignore
├── .vscode
│   └── settings.json
├── README.md
├── data_utils
│   ├── ModelNetDataLoader.py
│   ├── S3DISDataLoader.py
│   ├── ShapeNetDataLoader.py
│   ├── __init__.py
│   ├── collect_indoor3d_data.py
│   ├── indoor3d_util.py
│   └── meta
│       ├── anno_paths.txt
│       └── class_names.txt
├── image
│   ├── plane1.png
│   ├── plane2.png
│   └── psn.png
├── models
│   ├── PointSamplingNet.py
│   ├── __init__.py
│   ├── pointnet2_cls_ssg_psn.py
│   ├── pointnet2_part_seg_ssg_psn.py
│   ├── pointnet2_sem_seg_psn.py
│   └── pointnet_util_psn.py
├── provider.py
├── test_cls.py
├── test_partseg.py
├── test_semseg.py
├── train_cls.py
├── train_partseg.py
├── train_semseg.py
└── visualizer
    ├── build.sh
    ├── eulerangles.py
    ├── pc_utils.py
    ├── plyfile.py
    ├── render_balls_so.cpp
    └── show3d_balls.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 | .dmypy.json
111 | dmypy.json
112 |
113 | # Pyre type checker
114 | .pyre/
115 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Point Sampling Net: Fast Subsampling and Local Grouping for Deep Learning on Point Cloud
2 | ### (Currently Anonymous)
3 |
4 | This repository is the implementation for our paper :
5 | *Point Sampling Net: Fast Subsampling and Local Grouping for Deep Learning on Point Cloud*
6 |
7 | ## Introduction
8 | 
9 | **Point Sampling Net** is a differentiable, fast grouping and sampling method for deep learning on point clouds, and it can be applied to mainstream point cloud deep learning models. Point Sampling Net performs grouping and sampling at the same time. Because it does not use the relationships between points as a grouping reference, its inference speed is independent of the number of points and it is friendly to parallel implementation, which effectively reduces the time spent on sampling and grouping.
10 | Point Sampling Net has been tested on PointNet++ [[1](#Reference)], PointConv [[2](#Reference)], RS-CNN [[3](#Reference)], and GAC [[4](#Reference)]. It shows no obvious adverse effect on these deep learning models for classification, part segmentation, and scene segmentation tasks, while the speed of training and inference is significantly improved.
11 |
12 | # Usage
13 | The [**CORE FILE**](https://github.com/psn-anonymous/PointSamplingNet/blob/master/models/PointSamplingNet.py) of Point Sampling Net: [models/PointSamplingNet.py](https://github.com/psn-anonymous/PointSamplingNet/blob/master/models/PointSamplingNet.py)
14 |
15 | ## Software Dependencies
16 | Python 3.7 or newer
17 | PyTorch 1.5 or newer
18 | NVIDIA® CUDA® Toolkit 9.2 or newer
19 | NVIDIA® CUDA® Deep Neural Network library (cuDNN) 7.2 or newer
20 |
21 | You can install the software dependencies easily through **conda**:
22 | ```shell
23 | conda install pytorch cudatoolkit cudnn -c pytorch
24 | ```
25 |
26 | ## Import Point Sampling Net PyTorch Module
27 | You may import the PSN PyTorch module with:
28 | ```python
29 | import PointSamplingNet as psn
30 | ```
31 | ## Native PSN
32 | ### Defining
33 | ```python
34 | psn_layer = psn.PointSamplingNet(num_to_sample = 512, max_local_num = 32, mlp = [32, 256])
35 | ```
36 | The attribute *mlp* specifies only the intermediate channels of PSN, because the input channel of the first layer must be 3 and the output channel of the last layer must be the sampling number.
37 | ### Forward Propagation
38 | ```python
39 | sampled_points, grouped_points, sampled_feature, grouped_feature = psn_layer(coordinate = {coordinates of point cloud}, feature = {feature of point cloud})
40 | ```
41 | *sampled_points* are the sampled points and *grouped_points* are the grouped points.
42 | *sampled_feature* is the sampled feature and *grouped_feature* is the grouped feature.
43 | *{coordinates of point cloud}* is a torch.Tensor object of shape [*batch size*, *number of points*, *3*].
44 | *{feature of point cloud}* is a torch.Tensor object of shape [*batch size*, *number of points*, *D*].
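A minimal end-to-end sketch of the call above (the argument names and return values follow the description; the batch size, point count, and feature dimension are arbitrary example values, and `models/PointSamplingNet.py` is assumed to be on the Python path):
```python
import torch
import PointSamplingNet as psn  # assumes models/ is on the Python path

# Example shapes only: batch size 8, 1024 points, 64-dimensional features.
xyz = torch.rand(8, 1024, 3)    # [batch size, number of points, 3]
feat = torch.rand(8, 1024, 64)  # [batch size, number of points, D]

psn_layer = psn.PointSamplingNet(num_to_sample=512, max_local_num=32, mlp=[32, 256])
sampled_points, grouped_points, sampled_feature, grouped_feature = psn_layer(coordinate=xyz, feature=feat)
print(sampled_points.shape)  # 512 sampled points per batch item
```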
45 |
46 | ## PSN with Multi-Scale Grouping
47 | ### Defining
48 | ```python
49 | psn_msg_layer = psn.PointSamplingNetMSG(num_to_sample = 512, msg_n = [32, 64], mlp = [32, 256])
50 | ```
51 | The attribute *msg_n* is the list of multi-scale grouping sizes *n*, one per scale.
52 | ### Forward Propagation
53 | ```python
54 | sampled_points, grouped_points_msg, sampled_feature, grouped_feature_msg = psn_msg_layer(coordinate = {coordinates of point cloud}, feature = {feature of point cloud})
55 | ```
56 | *sampled_points* are the sampled points and *grouped_points_msg* is the list of multi-scale grouped points.
57 | *sampled_feature* is the sampled feature and *grouped_feature_msg* is the list of multi-scale grouped features.
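A corresponding sketch for the multi-scale variant, under the same assumptions as above (each entry of the returned lists corresponds to one scale in *msg_n*):
```python
import torch
import PointSamplingNet as psn  # assumes models/ is on the Python path

xyz = torch.rand(8, 1024, 3)    # [batch size, number of points, 3]
feat = torch.rand(8, 1024, 64)  # [batch size, number of points, D]

psn_msg_layer = psn.PointSamplingNetMSG(num_to_sample=512, msg_n=[32, 64], mlp=[32, 256])
sampled_points, grouped_points_msg, sampled_feature, grouped_feature_msg = psn_msg_layer(coordinate=xyz, feature=feat)
# grouped_points_msg and grouped_feature_msg are Python lists with one entry per scale in msg_n
```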
58 |
59 |
60 | # Visualize Effect
61 | ### Sampling
62 | 
63 | ### Grouping
64 | 
65 |
66 | # The Experiment on Deep Learning Networks
67 | Here is an experiment on PointNet++.
68 | ## Environments
69 | This experiment has been tested on the following environments:
70 | ### Software
71 | Canonical Ubuntu 20.04.1 LTS / Microsoft Windows 10 Pro
72 | Python 3.8.5
73 | PyTorch 1.7.0
74 | NVIDIA® CUDA® Toolkit 10.2.89
75 | NVIDIA® CUDA® Deep Neural Network library (cuDNN) 7.6.5
76 |
77 | ### Hardware
78 | Intel® Core™ i9-9900K Processor (16M Cache, up to 5.00 GHz)
79 | 64GB DDR4 RAM
80 | NVIDIA® TITAN RTX™
81 |
82 | ## Classification
83 | ### Data Preparation
84 | Download the aligned **ModelNet** dataset [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip) and save it in `data/modelnet40_normal_resampled/`.
85 |
86 | ### Run
87 | ```shell
88 | python train_cls.py --log_dir [your log dir]
89 | ```
90 |
91 | ## Part Segmentation
92 | ### Data Preparation
93 | Download the aligned **ShapeNet** dataset [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip) and save it in `data/shapenetcore_partanno_segmentation_benchmark_v0_normal/`.
94 | ### Run
95 | ```shell
96 | python train_partseg.py --normal --log_dir [your log dir]
97 | ```
98 |
99 | ## Semantic Segmentation
100 | ### Data Preparation
101 | Download the 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save it in `data/Stanford3dDataset_v1.2_Aligned_Version/`.
102 | ```shell
103 | cd data_utils
104 | python collect_indoor3d_data.py
105 | ```
106 | The processed data will be saved in `data/stanford_indoor3d/`.
107 | ### Run
108 | ```shell
109 | python train_semseg.py --log_dir [your log dir]
110 | python test_semseg.py --log_dir [your log dir] --test_area 5 --visual
111 | ```
112 |
113 | ## Experiment Reference
114 | This implementation of experiment is heavily reference to [yanx27/Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)
115 | Thanks very much !
116 |
117 | # Reference
118 | [1] Qi, Charles Ruizhongtai, et al. “[PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-se).” *Advances in Neural Information Processing Systems*, 2017, pp. 5099–5108. [[PDF](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf)]
119 | [2] Wu, Wenxuan, et al. “[PointConv: Deep Convolutional Networks on 3D Point Clouds](http://openaccess.thecvf.com/content_CVPR_2019/html/Wu_PointConv_Deep_Convolutional_Networks_on_3D_Point_Clouds_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 9621–9630. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_PointConv_Deep_Convolutional_Networks_on_3D_Point_Clouds_CVPR_2019_paper.pdf)]
120 | [3] Liu, Yongcheng, et al. “[Relation-Shape Convolutional Neural Network for Point Cloud Analysis](http://openaccess.thecvf.com/content_CVPR_2019/html/Liu_Relation-Shape_Convolutional_Neural_Network_for_Point_Cloud_Analysis_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 8895–8904. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Relation-Shape_Convolutional_Neural_Network_for_Point_Cloud_Analysis_CVPR_2019_paper.pdf)]
121 | [4] Wang, Lei, et al. “[Graph Attention Convolution for Point Cloud Semantic Segmentation](https://openaccess.thecvf.com/content_CVPR_2019/html/Wang_Graph_Attention_Convolution_for_Point_Cloud_Semantic_Segmentation_CVPR_2019_paper.html).” *2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)*, 2019, pp. 10296–10305. [[PDF](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Graph_Attention_Convolution_for_Point_Cloud_Semantic_Segmentation_CVPR_2019_paper.pdf)]
--------------------------------------------------------------------------------
/data_utils/ModelNetDataLoader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import warnings
3 | import os
4 | from torch.utils.data import Dataset
5 | warnings.filterwarnings('ignore')
6 |
7 |
8 |
9 | def pc_normalize(pc):
10 | centroid = np.mean(pc, axis=0)
11 | pc = pc - centroid
12 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
13 | pc = pc / m
14 | return pc
15 |
16 | def farthest_point_sample(point, npoint):
17 | """
18 | Input:
19 | point: pointcloud data, [N, D]
20 | npoint: number of samples
21 | Return:
22 | point: sampled pointcloud data, [npoint, D]
23 | """
24 | N, D = point.shape
25 | xyz = point[:,:3]
26 | centroids = np.zeros((npoint,))
27 | distance = np.ones((N,)) * 1e10
28 | farthest = np.random.randint(0, N)
29 | for i in range(npoint):
30 | centroids[i] = farthest
31 | centroid = xyz[farthest, :]
32 | dist = np.sum((xyz - centroid) ** 2, -1)
33 | mask = dist < distance
34 | distance[mask] = dist[mask]
35 | farthest = np.argmax(distance, -1)
36 | point = point[centroids.astype(np.int32)]
37 | return point
38 |
39 | class ModelNetDataLoader(Dataset):
40 | def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
41 | self.root = root
42 | self.npoints = npoint
43 | self.uniform = uniform
44 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
45 |
46 | self.cat = [line.rstrip() for line in open(self.catfile)]
47 | self.classes = dict(zip(self.cat, range(len(self.cat))))
48 | self.normal_channel = normal_channel
49 |
50 | shape_ids = {}
51 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
52 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
53 |
54 | assert (split == 'train' or split == 'test')
55 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
56 | # list of (shape_name, shape_txt_file_path) tuple
57 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
58 | in range(len(shape_ids[split]))]
59 | print('The size of %s data is %d'%(split,len(self.datapath)))
60 |
61 | self.cache_size = cache_size # how many data points to cache in memory
62 | self.cache = {} # from index to (point_set, cls) tuple
63 |
64 | def __len__(self):
65 | return len(self.datapath)
66 |
67 | def _get_item(self, index):
68 | if index in self.cache:
69 | point_set, cls = self.cache[index]
70 | else:
71 | fn = self.datapath[index]
72 | cls = self.classes[self.datapath[index][0]]
73 | cls = np.array([cls]).astype(np.int32)
74 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
75 | if self.uniform:
76 | point_set = farthest_point_sample(point_set, self.npoints)
77 | else:
78 | point_set = point_set[0:self.npoints,:]
79 |
80 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
81 |
82 | if not self.normal_channel:
83 | point_set = point_set[:, 0:3]
84 |
85 | if len(self.cache) < self.cache_size:
86 | self.cache[index] = (point_set, cls)
87 |
88 | return point_set, cls
89 |
90 | def __getitem__(self, index):
91 | return self._get_item(index)
92 |
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 | import torch
98 |
99 | data = ModelNetDataLoader('/data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,)
100 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
101 | for point,label in DataLoader:
102 | print(point.shape)
103 | print(label.shape)
--------------------------------------------------------------------------------
/data_utils/S3DISDataLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class S3DISDataset(Dataset):
7 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None):
8 | super().__init__()
9 | self.num_point = num_point
10 | self.block_size = block_size
11 | self.transform = transform
12 | rooms = sorted(os.listdir(data_root))
13 | rooms = [room for room in rooms if 'Area_' in room]
14 | if split == 'train':
15 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room]
16 | else:
17 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room]
18 | self.room_points, self.room_labels = [], []
19 | self.room_coord_min, self.room_coord_max = [], []
20 | num_point_all = []
21 | labelweights = np.zeros(13)
22 | for room_name in rooms_split:
23 | room_path = os.path.join(data_root, room_name)
24 | room_data = np.load(room_path) # xyzrgbl, N*7
25 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N
26 | tmp, _ = np.histogram(labels, range(14))
27 | labelweights += tmp
28 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
29 | self.room_points.append(points), self.room_labels.append(labels)
30 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
31 | num_point_all.append(labels.size)
32 | labelweights = labelweights.astype(np.float32)
33 | labelweights = labelweights / np.sum(labelweights)
34 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
35 | print(self.labelweights)
36 | sample_prob = num_point_all / np.sum(num_point_all)
37 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point)
38 | room_idxs = []
39 | for index in range(len(rooms_split)):
40 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter)))
41 | self.room_idxs = np.array(room_idxs)
42 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split))
43 |
44 | def __getitem__(self, idx):
45 | room_idx = self.room_idxs[idx]
46 | points = self.room_points[room_idx] # N * 6
47 | labels = self.room_labels[room_idx] # N
48 | N_points = points.shape[0]
49 |
50 | while (True):
51 | center = points[np.random.choice(N_points)][:3]
52 | block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
53 | block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
54 | point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0]
55 | if point_idxs.size > 1024:
56 | break
57 |
58 | if point_idxs.size >= self.num_point:
59 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False)
60 | else:
61 | selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True)
62 |
63 | # normalize
64 | selected_points = points[selected_point_idxs, :] # num_point * 6
65 | current_points = np.zeros((self.num_point, 9)) # num_point * 9
66 | current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0]
67 | current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1]
68 | current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2]
69 | selected_points[:, 0] = selected_points[:, 0] - center[0]
70 | selected_points[:, 1] = selected_points[:, 1] - center[1]
71 | selected_points[:, 3:6] /= 255.0
72 | current_points[:, 0:6] = selected_points
73 | current_labels = labels[selected_point_idxs]
74 | if self.transform is not None:
75 | current_points, current_labels = self.transform(current_points, current_labels)
76 | return current_points, current_labels
77 |
78 | def __len__(self):
79 | return len(self.room_idxs)
80 |
81 | class ScannetDatasetWholeScene():
82 | # prepare to give prediction on each points
83 | def __init__(self, root, block_points=4096, split='test', test_area=5, stride=0.5, block_size=1.0, padding=0.001):
84 | self.block_points = block_points
85 | self.block_size = block_size
86 | self.padding = padding
87 | self.root = root
88 | self.split = split
89 | self.stride = stride
90 | self.scene_points_num = []
91 | assert split in ['train', 'test']
92 | if self.split == 'train':
93 | self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) == -1]
94 | else:
95 | self.file_list = [d for d in os.listdir(root) if d.find('Area_%d' % test_area) != -1]
96 | self.scene_points_list = []
97 | self.semantic_labels_list = []
98 | self.room_coord_min, self.room_coord_max = [], []
99 | for file in self.file_list:
100 | data = np.load(root + file)
101 | points = data[:, :3]
102 | self.scene_points_list.append(data[:, :6])
103 | self.semantic_labels_list.append(data[:, 6])
104 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
105 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max)
106 | assert len(self.scene_points_list) == len(self.semantic_labels_list)
107 |
108 | labelweights = np.zeros(13)
109 | for seg in self.semantic_labels_list:
110 | tmp, _ = np.histogram(seg, range(14))
111 | self.scene_points_num.append(seg.shape[0])
112 | labelweights += tmp
113 | labelweights = labelweights.astype(np.float32)
114 | labelweights = labelweights / np.sum(labelweights)
115 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0)
116 |
117 | def __getitem__(self, index):
118 | point_set_ini = self.scene_points_list[index]
119 | points = point_set_ini[:,:6]
120 | labels = self.semantic_labels_list[index]
121 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3]
122 | grid_x = int(np.ceil(float(coord_max[0] - coord_min[0] - self.block_size) / self.stride) + 1)
123 | grid_y = int(np.ceil(float(coord_max[1] - coord_min[1] - self.block_size) / self.stride) + 1)
124 | data_room, label_room, sample_weight, index_room = np.array([]), np.array([]), np.array([]), np.array([])
125 | for index_y in range(0, grid_y):
126 | for index_x in range(0, grid_x):
127 | s_x = coord_min[0] + index_x * self.stride
128 | e_x = min(s_x + self.block_size, coord_max[0])
129 | s_x = e_x - self.block_size
130 | s_y = coord_min[1] + index_y * self.stride
131 | e_y = min(s_y + self.block_size, coord_max[1])
132 | s_y = e_y - self.block_size
133 | point_idxs = np.where(
134 | (points[:, 0] >= s_x - self.padding) & (points[:, 0] <= e_x + self.padding) & (points[:, 1] >= s_y - self.padding) & (
135 | points[:, 1] <= e_y + self.padding))[0]
136 | if point_idxs.size == 0:
137 | continue
138 | num_batch = int(np.ceil(point_idxs.size / self.block_points))
139 | point_size = int(num_batch * self.block_points)
140 | replace = False if (point_size - point_idxs.size <= point_idxs.size) else True
141 | point_idxs_repeat = np.random.choice(point_idxs, point_size - point_idxs.size, replace=replace)
142 | point_idxs = np.concatenate((point_idxs, point_idxs_repeat))
143 | np.random.shuffle(point_idxs)
144 | data_batch = points[point_idxs, :]
145 | normlized_xyz = np.zeros((point_size, 3))
146 | normlized_xyz[:, 0] = data_batch[:, 0] / coord_max[0]
147 | normlized_xyz[:, 1] = data_batch[:, 1] / coord_max[1]
148 | normlized_xyz[:, 2] = data_batch[:, 2] / coord_max[2]
149 | data_batch[:, 0] = data_batch[:, 0] - (s_x + self.block_size / 2.0)
150 | data_batch[:, 1] = data_batch[:, 1] - (s_y + self.block_size / 2.0)
151 | data_batch[:, 3:6] /= 255.0
152 | data_batch = np.concatenate((data_batch, normlized_xyz), axis=1)
153 | label_batch = labels[point_idxs].astype(int)
154 | batch_weight = self.labelweights[label_batch]
155 |
156 | data_room = np.vstack([data_room, data_batch]) if data_room.size else data_batch
157 | label_room = np.hstack([label_room, label_batch]) if label_room.size else label_batch
158 | sample_weight = np.hstack([sample_weight, batch_weight]) if sample_weight.size else batch_weight
159 | index_room = np.hstack([index_room, point_idxs]) if index_room.size else point_idxs
160 | data_room = data_room.reshape((-1, self.block_points, data_room.shape[1]))
161 | label_room = label_room.reshape((-1, self.block_points))
162 | sample_weight = sample_weight.reshape((-1, self.block_points))
163 | index_room = index_room.reshape((-1, self.block_points))
164 | return data_room, label_room, sample_weight, index_room
165 |
166 | def __len__(self):
167 | return len(self.scene_points_list)
168 |
169 | if __name__ == '__main__':
170 | data_root = '/data/yxu/PointNonLocal/data/stanford_indoor3d/'
171 | num_point, test_area, block_size, sample_rate = 4096, 5, 1.0, 0.01
172 |
173 | point_data = S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area, block_size=block_size, sample_rate=sample_rate, transform=None)
174 | print('point data size:', point_data.__len__())
175 | print('point data 0 shape:', point_data.__getitem__(0)[0].shape)
176 | print('point label 0 shape:', point_data.__getitem__(0)[1].shape)
177 | import torch, time, random
178 | manual_seed = 123
179 | random.seed(manual_seed)
180 | np.random.seed(manual_seed)
181 | torch.manual_seed(manual_seed)
182 | torch.cuda.manual_seed_all(manual_seed)
183 | def worker_init_fn(worker_id):
184 | random.seed(manual_seed + worker_id)
185 | train_loader = torch.utils.data.DataLoader(point_data, batch_size=16, shuffle=True, num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn)
186 | for idx in range(4):
187 | end = time.time()
188 | for i, (input, target) in enumerate(train_loader):
189 | print('time: {}/{}--{}'.format(i+1, len(train_loader), time.time() - end))
190 | end = time.time()
--------------------------------------------------------------------------------
/data_utils/ShapeNetDataLoader.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import json
4 | import warnings
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 | warnings.filterwarnings('ignore')
8 |
9 | def pc_normalize(pc):
10 | centroid = np.mean(pc, axis=0)
11 | pc = pc - centroid
12 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
13 | pc = pc / m
14 | return pc
15 |
16 | class PartNormalDataset(Dataset):
17 | def __init__(self,root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):
18 | self.npoints = npoints
19 | self.root = root
20 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
21 | self.cat = {}
22 | self.normal_channel = normal_channel
23 |
24 |
25 | with open(self.catfile, 'r') as f:
26 | for line in f:
27 | ls = line.strip().split()
28 | self.cat[ls[0]] = ls[1]
29 | self.cat = {k: v for k, v in self.cat.items()}
30 | self.classes_original = dict(zip(self.cat, range(len(self.cat))))
31 |
32 | if not class_choice is None:
33 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
34 | # print(self.cat)
35 |
36 | self.meta = {}
37 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
38 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
39 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
40 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
41 | with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
42 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
43 | for item in self.cat:
44 | # print('category', item)
45 | self.meta[item] = []
46 | dir_point = os.path.join(self.root, self.cat[item])
47 | fns = sorted(os.listdir(dir_point))
48 | # print(fns[0][0:-4])
49 | if split == 'trainval':
50 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
51 | elif split == 'train':
52 | fns = [fn for fn in fns if fn[0:-4] in train_ids]
53 | elif split == 'val':
54 | fns = [fn for fn in fns if fn[0:-4] in val_ids]
55 | elif split == 'test':
56 | fns = [fn for fn in fns if fn[0:-4] in test_ids]
57 | else:
58 | print('Unknown split: %s. Exiting..' % (split))
59 | exit(-1)
60 |
61 | # print(os.path.basename(fns))
62 | for fn in fns:
63 | token = (os.path.splitext(os.path.basename(fn))[0])
64 | self.meta[item].append(os.path.join(dir_point, token + '.txt'))
65 |
66 | self.datapath = []
67 | for item in self.cat:
68 | for fn in self.meta[item]:
69 | self.datapath.append((item, fn))
70 |
71 | self.classes = {}
72 | for i in self.cat.keys():
73 | self.classes[i] = self.classes_original[i]
74 |
75 | # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
76 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
77 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
78 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
79 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
80 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
81 |
82 | # for cat in sorted(self.seg_classes.keys()):
83 | # print(cat, self.seg_classes[cat])
84 |
85 | self.cache = {} # from index to (point_set, cls, seg) tuple
86 | self.cache_size = 20000
87 |
88 |
89 | def __getitem__(self, index):
90 | if index in self.cache:
91 | point_set, cls, seg = self.cache[index]
92 | else:
93 | fn = self.datapath[index]
94 | cat = self.datapath[index][0]
95 | cls = self.classes[cat]
96 | cls = np.array([cls]).astype(np.int32)
97 | data = np.loadtxt(fn[1]).astype(np.float32)
98 | if not self.normal_channel:
99 | point_set = data[:, 0:3]
100 | else:
101 | point_set = data[:, 0:6]
102 | seg = data[:, -1].astype(np.int32)
103 | if len(self.cache) < self.cache_size:
104 | self.cache[index] = (point_set, cls, seg)
105 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
106 |
107 | choice = np.random.choice(len(seg), self.npoints, replace=True)
108 | # resample
109 | point_set = point_set[choice, :]
110 | seg = seg[choice]
111 |
112 | return point_set, cls, seg
113 |
114 | def __len__(self):
115 | return len(self.datapath)
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/data_utils/__init__.py
--------------------------------------------------------------------------------
/data_utils/collect_indoor3d_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('./data_utils')
4 | from indoor3d_util import DATA_PATH, collect_point_label
5 |
6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7 | ROOT_DIR = os.path.dirname(BASE_DIR)
8 | sys.path.append(BASE_DIR)
9 |
10 | anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))]
11 | anno_paths = [os.path.join(DATA_PATH, p) for p in anno_paths]
12 |
13 | output_folder = os.path.join(ROOT_DIR, 'data/stanford_indoor3d')
14 | if not os.path.exists(output_folder):
15 | os.mkdir(output_folder)
16 |
17 | # Note: there is an extra character in the v1.2 data in Area_5/hallway_6. It's fixed manually.
18 | for anno_path in anno_paths:
19 | print(anno_path)
20 | try:
21 | elements = anno_path.split('/')
22 | out_filename = elements[-3]+'_'+elements[-2]+'.npy' # Area_1_hallway_1.npy
23 | collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy')
24 | except:
25 | print(anno_path, 'ERROR!!')
26 |
--------------------------------------------------------------------------------
/data_utils/indoor3d_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 | import sys
5 |
6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
7 | ROOT_DIR = os.path.dirname(BASE_DIR)
8 | sys.path.append(BASE_DIR)
9 |
10 | DATA_PATH = os.path.join(ROOT_DIR, 'data', 'Stanford3dDataset_v1.2_Aligned_Version')
11 | g_classes = [x.rstrip() for x in open(os.path.join(BASE_DIR, 'meta/class_names.txt'))]
12 | g_class2label = {cls: i for i,cls in enumerate(g_classes)}
13 | g_class2color = {'ceiling': [0,255,0],
14 | 'floor': [0,0,255],
15 | 'wall': [0,255,255],
16 | 'beam': [255,255,0],
17 | 'column': [255,0,255],
18 | 'window': [100,100,255],
19 | 'door': [200,200,100],
20 | 'table': [170,120,200],
21 | 'chair': [255,0,0],
22 | 'sofa': [200,100,100],
23 | 'bookcase': [10,200,100],
24 | 'board': [200,200,200],
25 | 'clutter': [50,50,50]}
26 | g_easy_view_labels = [7,8,9,10,11,1]
27 | g_label2color = {g_classes.index(cls): g_class2color[cls] for cls in g_classes}
28 |
29 |
30 | # -----------------------------------------------------------------------------
31 | # CONVERT ORIGINAL DATA TO OUR DATA_LABEL FILES
32 | # -----------------------------------------------------------------------------
33 |
34 | def collect_point_label(anno_path, out_filename, file_format='txt'):
35 | """ Convert original dataset files to data_label file (each line is XYZRGBL).
36 | We aggregated all the points from each instance in the room.
37 |
38 | Args:
39 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
40 | out_filename: path to save collected points and labels (each line is XYZRGBL)
41 | file_format: txt or numpy, determines what file format to save.
42 | Returns:
43 | None
44 | Note:
45 | the points are shifted before save, the most negative point is now at origin.
46 | """
47 | points_list = []
48 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
49 | cls = os.path.basename(f).split('_')[0]
50 | print(f)
51 | if cls not in g_classes: # note: in some room there is 'staris' class..
52 | cls = 'clutter'
53 |
54 | points = np.loadtxt(f)
55 | labels = np.ones((points.shape[0],1)) * g_class2label[cls]
56 | points_list.append(np.concatenate([points, labels], 1)) # Nx7
57 |
58 | data_label = np.concatenate(points_list, 0)
59 | xyz_min = np.amin(data_label, axis=0)[0:3]
60 | data_label[:, 0:3] -= xyz_min
61 |
62 | if file_format=='txt':
63 | fout = open(out_filename, 'w')
64 | for i in range(data_label.shape[0]):
65 | fout.write('%f %f %f %d %d %d %d\n' % \
66 | (data_label[i,0], data_label[i,1], data_label[i,2],
67 | data_label[i,3], data_label[i,4], data_label[i,5],
68 | data_label[i,6]))
69 | fout.close()
70 | elif file_format=='numpy':
71 | np.save(out_filename, data_label)
72 | else:
73 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \
74 | (file_format))
75 | exit()
76 |
77 | def data_to_obj(data,name='example.obj',no_wall=True):
78 | fout = open(name, 'w')
79 | label = data[:, -1].astype(int)
80 | for i in range(data.shape[0]):
81 | if no_wall and ((label[i] == 2) or (label[i]==0)):
82 | continue
83 | fout.write('v %f %f %f %d %d %d\n' % \
84 | (data[i, 0], data[i, 1], data[i, 2], data[i, 3], data[i, 4], data[i, 5]))
85 | fout.close()
86 |
87 | def point_label_to_obj(input_filename, out_filename, label_color=True, easy_view=False, no_wall=False):
88 | """ For visualization of a room from data_label file,
89 | input_filename: each line is X Y Z R G B L
90 | out_filename: OBJ filename,
91 | visualize input file by coloring point with label color
92 | easy_view: only visualize furnitures and floor
93 | """
94 | data_label = np.loadtxt(input_filename)
95 | data = data_label[:, 0:6]
96 | label = data_label[:, -1].astype(int)
97 | fout = open(out_filename, 'w')
98 | for i in range(data.shape[0]):
99 | color = g_label2color[label[i]]
100 | if easy_view and (label[i] not in g_easy_view_labels):
101 | continue
102 | if no_wall and ((label[i] == 2) or (label[i]==0)):
103 | continue
104 | if label_color:
105 | fout.write('v %f %f %f %d %d %d\n' % \
106 | (data[i,0], data[i,1], data[i,2], color[0], color[1], color[2]))
107 | else:
108 | fout.write('v %f %f %f %d %d %d\n' % \
109 | (data[i,0], data[i,1], data[i,2], data[i,3], data[i,4], data[i,5]))
110 | fout.close()
111 |
112 |
113 |
114 | # -----------------------------------------------------------------------------
115 | # PREPARE BLOCK DATA FOR DEEPNETS TRAINING/TESTING
116 | # -----------------------------------------------------------------------------
117 |
118 | def sample_data(data, num_sample):
119 | """ data is in N x ...
120 | we want to keep num_samplexC of them.
121 | if N > num_sample, we will randomly keep num_sample of them.
122 | if N < num_sample, we will randomly duplicate samples.
123 | """
124 | N = data.shape[0]
125 | if (N == num_sample):
126 | return data, range(N)
127 | elif (N > num_sample):
128 | sample = np.random.choice(N, num_sample)
129 | return data[sample, ...], sample
130 | else:
131 | sample = np.random.choice(N, num_sample-N)
132 | dup_data = data[sample, ...]
133 | return np.concatenate([data, dup_data], 0), list(range(N))+list(sample)
134 |
135 | def sample_data_label(data, label, num_sample):
136 | new_data, sample_indices = sample_data(data, num_sample)
137 | new_label = label[sample_indices]
138 | return new_data, new_label
139 |
140 | def room2blocks(data, label, num_point, block_size=1.0, stride=1.0,
141 | random_sample=False, sample_num=None, sample_aug=1):
142 | """ Prepare block training data.
143 | Args:
144 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1]
145 | assumes the data is shifted (min point is origin) and aligned
146 | (aligned with XYZ axis)
147 | label: N size uint8 numpy array from 0-12
148 | num_point: int, how many points to sample in each block
149 | block_size: float, physical size of the block in meters
150 | stride: float, stride for block sweeping
151 | random_sample: bool, if True, we will randomly sample blocks in the room
152 | sample_num: int, if random sample, how many blocks to sample
153 | [default: room area]
154 | sample_aug: if random sample, how much aug
155 | Returns:
156 | block_datas: K x num_point x 6 np array of XYZRGB, RGB is in [0,1]
157 | block_labels: K x num_point x 1 np array of uint8 labels
158 |
159 | TODO: for this version, blocking is in fixed, non-overlapping pattern.
160 | """
161 | assert(stride<=block_size)
162 |
163 | limit = np.amax(data, 0)[0:3]
164 |
165 | # Get the corner location for our sampling blocks
166 | xbeg_list = []
167 | ybeg_list = []
168 | if not random_sample:
169 | num_block_x = int(np.ceil((limit[0] - block_size) / stride)) + 1
170 | num_block_y = int(np.ceil((limit[1] - block_size) / stride)) + 1
171 | for i in range(num_block_x):
172 | for j in range(num_block_y):
173 | xbeg_list.append(i*stride)
174 | ybeg_list.append(j*stride)
175 | else:
176 | num_block_x = int(np.ceil(limit[0] / block_size))
177 | num_block_y = int(np.ceil(limit[1] / block_size))
178 | if sample_num is None:
179 | sample_num = num_block_x * num_block_y * sample_aug
180 | for _ in range(sample_num):
181 | xbeg = np.random.uniform(-block_size, limit[0])
182 | ybeg = np.random.uniform(-block_size, limit[1])
183 | xbeg_list.append(xbeg)
184 | ybeg_list.append(ybeg)
185 |
186 | # Collect blocks
187 | block_data_list = []
188 | block_label_list = []
189 | idx = 0
190 | for idx in range(len(xbeg_list)):
191 | xbeg = xbeg_list[idx]
192 | ybeg = ybeg_list[idx]
193 | xcond = (data[:,0]<=xbeg+block_size) & (data[:,0]>=xbeg)
194 | ycond = (data[:,1]<=ybeg+block_size) & (data[:,1]>=ybeg)
195 | cond = xcond & ycond
196 | if np.sum(cond) < 100: # discard block if there are less than 100 pts.
197 | continue
198 |
199 | block_data = data[cond, :]
200 | block_label = label[cond]
201 |
202 | # randomly subsample data
203 | block_data_sampled, block_label_sampled = \
204 | sample_data_label(block_data, block_label, num_point)
205 | block_data_list.append(np.expand_dims(block_data_sampled, 0))
206 | block_label_list.append(np.expand_dims(block_label_sampled, 0))
207 |
208 | return np.concatenate(block_data_list, 0), \
209 | np.concatenate(block_label_list, 0)
210 |
211 |
212 | def room2blocks_plus(data_label, num_point, block_size, stride,
213 | random_sample, sample_num, sample_aug):
214 | """ room2block with input filename and RGB preprocessing.
215 | """
216 | data = data_label[:,0:6]
217 | data[:,3:6] /= 255.0
218 | label = data_label[:,-1].astype(np.uint8)
219 |
220 | return room2blocks(data, label, num_point, block_size, stride,
221 | random_sample, sample_num, sample_aug)
222 |
223 | def room2blocks_wrapper(data_label_filename, num_point, block_size=1.0, stride=1.0,
224 | random_sample=False, sample_num=None, sample_aug=1):
225 | if data_label_filename[-3:] == 'txt':
226 | data_label = np.loadtxt(data_label_filename)
227 | elif data_label_filename[-3:] == 'npy':
228 | data_label = np.load(data_label_filename)
229 | else:
230 | print('Unknown file type! exiting.')
231 | exit()
232 | return room2blocks_plus(data_label, num_point, block_size, stride,
233 | random_sample, sample_num, sample_aug)
234 |
235 | def room2blocks_plus_normalized(data_label, num_point, block_size, stride,
236 | random_sample, sample_num, sample_aug):
237 | """ room2block, with input filename and RGB preprocessing.
238 | for each block centralize XYZ, add normalized XYZ as 678 channels
239 | """
240 | data = data_label[:,0:6]
241 | data[:,3:6] /= 255.0
242 | label = data_label[:,-1].astype(np.uint8)
243 | max_room_x = max(data[:,0])
244 | max_room_y = max(data[:,1])
245 | max_room_z = max(data[:,2])
246 |
247 | data_batch, label_batch = room2blocks(data, label, num_point, block_size, stride,
248 | random_sample, sample_num, sample_aug)
249 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9))
250 | for b in range(data_batch.shape[0]):
251 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x
252 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y
253 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z
254 | minx = min(data_batch[b, :, 0])
255 | miny = min(data_batch[b, :, 1])
256 | data_batch[b, :, 0] -= (minx+block_size/2)
257 | data_batch[b, :, 1] -= (miny+block_size/2)
258 | new_data_batch[:, :, 0:6] = data_batch
259 | return new_data_batch, label_batch
260 |
261 |
262 | def room2blocks_wrapper_normalized(data_label_filename, num_point, block_size=1.0, stride=1.0,
263 | random_sample=False, sample_num=None, sample_aug=1):
264 | if data_label_filename[-3:] == 'txt':
265 | data_label = np.loadtxt(data_label_filename)
266 | elif data_label_filename[-3:] == 'npy':
267 | data_label = np.load(data_label_filename)
268 | else:
269 | print('Unknown file type! exiting.')
270 | exit()
271 | return room2blocks_plus_normalized(data_label, num_point, block_size, stride,
272 | random_sample, sample_num, sample_aug)
273 |
274 | def room2samples(data, label, sample_num_point):
275 | """ Prepare whole room samples.
276 |
277 | Args:
278 | data: N x 6 numpy array, 012 are XYZ in meters, 345 are RGB in [0,1]
279 | assumes the data is shifted (min point is origin) and
280 | aligned (aligned with XYZ axis)
281 | label: N size uint8 numpy array from 0-12
282 | sample_num_point: int, how many points to sample in each sample
283 | Returns:
284 | sample_datas: K x sample_num_point x 9
285 | numpy array of XYZRGBX'Y'Z', RGB is in [0,1]
286 | sample_labels: K x sample_num_point x 1 np array of uint8 labels
287 | """
288 | N = data.shape[0]
289 | order = np.arange(N)
290 | np.random.shuffle(order)
291 | data = data[order, :]
292 | label = label[order]
293 |
294 | batch_num = int(np.ceil(N / float(sample_num_point)))
295 | sample_datas = np.zeros((batch_num, sample_num_point, 6))
296 | sample_labels = np.zeros((batch_num, sample_num_point, 1))
297 |
298 | for i in range(batch_num):
299 | beg_idx = i*sample_num_point
300 | end_idx = min((i+1)*sample_num_point, N)
301 | num = end_idx - beg_idx
302 | sample_datas[i,0:num,:] = data[beg_idx:end_idx, :]
303 | sample_labels[i,0:num,0] = label[beg_idx:end_idx]
304 | if num < sample_num_point:
305 | makeup_indices = np.random.choice(N, sample_num_point - num)
306 | sample_datas[i,num:,:] = data[makeup_indices, :]
307 | sample_labels[i,num:,0] = label[makeup_indices]
308 | return sample_datas, sample_labels
309 |
310 | def room2samples_plus_normalized(data_label, num_point):
311 | """ room2sample, with input filename and RGB preprocessing.
312 | for each block centralize XYZ, add normalized XYZ as 678 channels
313 | """
314 | data = data_label[:,0:6]
315 | data[:,3:6] /= 255.0
316 | label = data_label[:,-1].astype(np.uint8)
317 | max_room_x = max(data[:,0])
318 | max_room_y = max(data[:,1])
319 | max_room_z = max(data[:,2])
320 | #print(max_room_x, max_room_y, max_room_z)
321 |
322 | data_batch, label_batch = room2samples(data, label, num_point)
323 | new_data_batch = np.zeros((data_batch.shape[0], num_point, 9))
324 | for b in range(data_batch.shape[0]):
325 | new_data_batch[b, :, 6] = data_batch[b, :, 0]/max_room_x
326 | new_data_batch[b, :, 7] = data_batch[b, :, 1]/max_room_y
327 | new_data_batch[b, :, 8] = data_batch[b, :, 2]/max_room_z
328 | #minx = min(data_batch[b, :, 0])
329 | #miny = min(data_batch[b, :, 1])
330 | #data_batch[b, :, 0] -= (minx+block_size/2)
331 | #data_batch[b, :, 1] -= (miny+block_size/2)
332 | new_data_batch[:, :, 0:6] = data_batch
333 | return new_data_batch, label_batch
334 |
335 |
336 | def room2samples_wrapper_normalized(data_label_filename, num_point):
337 | if data_label_filename[-3:] == 'txt':
338 | data_label = np.loadtxt(data_label_filename)
339 | elif data_label_filename[-3:] == 'npy':
340 | data_label = np.load(data_label_filename)
341 | else:
342 | print('Unknown file type! exiting.')
343 | exit()
344 | return room2samples_plus_normalized(data_label, num_point)
345 |
346 |
347 | # -----------------------------------------------------------------------------
348 | # EXTRACT INSTANCE BBOX FROM ORIGINAL DATA (for detection evaluation)
349 | # -----------------------------------------------------------------------------
350 |
351 | def collect_bounding_box(anno_path, out_filename):
352 | """ Compute bounding boxes from each instance in original dataset files on
353 | one room. **We assume the bbox is aligned with XYZ coordinate.**
354 |
355 | Args:
356 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
357 | out_filename: path to save instance bounding boxes for that room.
358 | each line is x1 y1 z1 x2 y2 z2 label,
359 | where (x1,y1,z1) is the point on the diagonal closer to origin
360 | Returns:
361 | None
362 | Note:
363 | room points are shifted, the most negative point is now at origin.
364 | """
365 | bbox_label_list = []
366 |
367 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
368 | cls = os.path.basename(f).split('_')[0]
369 | if cls not in g_classes: # note: in some room there is 'staris' class..
370 | cls = 'clutter'
371 | points = np.loadtxt(f)
372 | label = g_class2label[cls]
373 | # Compute tightest axis aligned bounding box
374 | xyz_min = np.amin(points[:, 0:3], axis=0)
375 | xyz_max = np.amax(points[:, 0:3], axis=0)
376 | ins_bbox_label = np.expand_dims(
377 | np.concatenate([xyz_min, xyz_max, np.array([label])], 0), 0)
378 | bbox_label_list.append(ins_bbox_label)
379 |
380 | bbox_label = np.concatenate(bbox_label_list, 0)
381 | room_xyz_min = np.amin(bbox_label[:, 0:3], axis=0)
382 | bbox_label[:, 0:3] -= room_xyz_min
383 | bbox_label[:, 3:6] -= room_xyz_min
384 |
385 | fout = open(out_filename, 'w')
386 | for i in range(bbox_label.shape[0]):
387 | fout.write('%f %f %f %f %f %f %d\n' % \
388 | (bbox_label[i,0], bbox_label[i,1], bbox_label[i,2],
389 | bbox_label[i,3], bbox_label[i,4], bbox_label[i,5],
390 | bbox_label[i,6]))
391 | fout.close()
392 |
393 | def bbox_label_to_obj(input_filename, out_filename_prefix, easy_view=False):
394 | """ Visualization of bounding boxes.
395 |
396 | Args:
397 | input_filename: each line is x1 y1 z1 x2 y2 z2 label
398 | out_filename_prefix: OBJ filename prefix,
399 | visualize object by g_label2color
400 | easy_view: if True, only visualize furniture and floor
401 | Returns:
402 | output a list of OBJ file and MTL files with the same prefix
403 | """
404 | bbox_label = np.loadtxt(input_filename)
405 | bbox = bbox_label[:, 0:6]
406 | label = bbox_label[:, -1].astype(int)
407 | v_cnt = 0 # count vertex
408 | ins_cnt = 0 # count instance
409 | for i in range(bbox.shape[0]):
410 | if easy_view and (label[i] not in g_easy_view_labels):
411 | continue
412 | obj_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.obj'
413 | mtl_filename = out_filename_prefix+'_'+g_classes[label[i]]+'_'+str(ins_cnt)+'.mtl'
414 | fout_obj = open(obj_filename, 'w')
415 | fout_mtl = open(mtl_filename, 'w')
416 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename)))
417 |
418 | length = bbox[i, 3:6] - bbox[i, 0:3]
419 | a = length[0]
420 | b = length[1]
421 | c = length[2]
422 | x = bbox[i, 0]
423 | y = bbox[i, 1]
424 | z = bbox[i, 2]
425 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0
426 |
427 | material = 'material%d' % (ins_cnt)
428 | fout_obj.write('usemtl %s\n' % (material))
429 | fout_obj.write('v %f %f %f\n' % (x,y,z+c))
430 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c))
431 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c))
432 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c))
433 | fout_obj.write('v %f %f %f\n' % (x,y,z))
434 | fout_obj.write('v %f %f %f\n' % (x,y+b,z))
435 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z))
436 | fout_obj.write('v %f %f %f\n' % (x+a,y,z))
437 | fout_obj.write('g default\n')
438 | v_cnt = 0 # for individual box
439 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt))
440 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt))
441 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt))
442 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt))
443 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt))
444 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt))
445 | fout_obj.write('\n')
446 |
447 | fout_mtl.write('newmtl %s\n' % (material))
448 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2]))
449 | fout_mtl.write('\n')
450 | fout_obj.close()
451 | fout_mtl.close()
452 |
453 | v_cnt += 8
454 | ins_cnt += 1
455 |
456 | def bbox_label_to_obj_room(input_filename, out_filename_prefix, easy_view=False, permute=None, center=False, exclude_table=False):
457 | """ Visualization of bounding boxes.
458 |
459 | Args:
460 | input_filename: each line is x1 y1 z1 x2 y2 z2 label
461 | out_filename_prefix: OBJ filename prefix,
462 | visualize object by g_label2color
463 | easy_view: if True, only visualize furniture and floor
464 | permute: if not None, permute XYZ for rendering, e.g. [0 2 1]
465 | center: if True, move obj to have zero origin
466 | Returns:
467 | output a list of OBJ file and MTL files with the same prefix
468 | """
469 | bbox_label = np.loadtxt(input_filename)
470 | bbox = bbox_label[:, 0:6]
471 | if permute is not None:
472 | assert(len(permute)==3)
473 | permute = np.array(permute)
474 | bbox[:,0:3] = bbox[:,permute]
475 | bbox[:,3:6] = bbox[:,permute+3]
476 | if center:
477 | xyz_max = np.amax(bbox[:,3:6], 0)
478 | bbox[:,0:3] -= (xyz_max/2.0)
479 | bbox[:,3:6] -= (xyz_max/2.0)
480 | bbox /= np.max(xyz_max/2.0)
481 | label = bbox_label[:, -1].astype(int)
482 | obj_filename = out_filename_prefix+'.obj'
483 | mtl_filename = out_filename_prefix+'.mtl'
484 |
485 | fout_obj = open(obj_filename, 'w')
486 | fout_mtl = open(mtl_filename, 'w')
487 | fout_obj.write('mtllib %s\n' % (os.path.basename(mtl_filename)))
488 | v_cnt = 0 # count vertex
489 | ins_cnt = 0 # count instance
490 | for i in range(bbox.shape[0]):
491 | if easy_view and (label[i] not in g_easy_view_labels):
492 | continue
493 | if exclude_table and label[i] == g_classes.index('table'):
494 | continue
495 |
496 | length = bbox[i, 3:6] - bbox[i, 0:3]
497 | a = length[0]
498 | b = length[1]
499 | c = length[2]
500 | x = bbox[i, 0]
501 | y = bbox[i, 1]
502 | z = bbox[i, 2]
503 | color = np.array(g_label2color[label[i]], dtype=float) / 255.0
504 |
505 | material = 'material%d' % (ins_cnt)
506 | fout_obj.write('usemtl %s\n' % (material))
507 | fout_obj.write('v %f %f %f\n' % (x,y,z+c))
508 | fout_obj.write('v %f %f %f\n' % (x,y+b,z+c))
509 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z+c))
510 | fout_obj.write('v %f %f %f\n' % (x+a,y,z+c))
511 | fout_obj.write('v %f %f %f\n' % (x,y,z))
512 | fout_obj.write('v %f %f %f\n' % (x,y+b,z))
513 | fout_obj.write('v %f %f %f\n' % (x+a,y+b,z))
514 | fout_obj.write('v %f %f %f\n' % (x+a,y,z))
515 | fout_obj.write('g default\n')
516 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 3+v_cnt, 2+v_cnt, 1+v_cnt))
517 | fout_obj.write('f %d %d %d %d\n' % (1+v_cnt, 2+v_cnt, 6+v_cnt, 5+v_cnt))
518 | fout_obj.write('f %d %d %d %d\n' % (7+v_cnt, 6+v_cnt, 2+v_cnt, 3+v_cnt))
519 | fout_obj.write('f %d %d %d %d\n' % (4+v_cnt, 8+v_cnt, 7+v_cnt, 3+v_cnt))
520 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 8+v_cnt, 4+v_cnt, 1+v_cnt))
521 | fout_obj.write('f %d %d %d %d\n' % (5+v_cnt, 6+v_cnt, 7+v_cnt, 8+v_cnt))
522 | fout_obj.write('\n')
523 |
524 | fout_mtl.write('newmtl %s\n' % (material))
525 | fout_mtl.write('Kd %f %f %f\n' % (color[0], color[1], color[2]))
526 | fout_mtl.write('\n')
527 |
528 | v_cnt += 8
529 | ins_cnt += 1
530 |
531 | fout_obj.close()
532 | fout_mtl.close()
533 |
534 |
535 | def collect_point_bounding_box(anno_path, out_filename, file_format):
536 | """ Compute bounding boxes from each instance in original dataset files on
537 | one room. **We assume the bbox is aligned with XYZ coordinate.**
538 | Save both the point XYZRGB and the bounding box for the point's
539 | parent element.
540 |
541 | Args:
542 | anno_path: path to annotations. e.g. Area_1/office_2/Annotations/
543 | out_filename: path to save instance bounding boxes for each point,
544 | plus the point's XYZRGBL
545 | each line is XYZRGBL offsetX offsetY offsetZ a b c,
546 | where cx = X+offsetX, cy=X+offsetY, cz=Z+offsetZ
547 | where (cx,cy,cz) is center of the box, a,b,c are distances from center
548 | to the surfaces of the box, i.e. x1 = cx-a, x2 = cx+a, y1=cy-b etc.
549 | file_format: output file format, txt or numpy
550 | Returns:
551 | None
552 |
553 | Note:
554 | room points are shifted, the most negative point is now at origin.
555 | """
556 | point_bbox_list = []
557 |
558 | for f in glob.glob(os.path.join(anno_path, '*.txt')):
559 | cls = os.path.basename(f).split('_')[0]
560 | if cls not in g_classes: # note: in some room there is 'staris' class..
561 | cls = 'clutter'
562 | points = np.loadtxt(f) # Nx6
563 | label = g_class2label[cls] # N,
564 | # Compute tightest axis aligned bounding box
565 | xyz_min = np.amin(points[:, 0:3], axis=0) # 3,
566 | xyz_max = np.amax(points[:, 0:3], axis=0) # 3,
567 | xyz_center = (xyz_min + xyz_max) / 2
568 | dimension = (xyz_max - xyz_min) / 2
569 |
570 | xyz_offsets = xyz_center - points[:,0:3] # Nx3
571 | dimensions = np.ones((points.shape[0],3)) * dimension # Nx3
572 | labels = np.ones((points.shape[0],1)) * label # N
573 | point_bbox_list.append(np.concatenate([points, labels,
574 | xyz_offsets, dimensions], 1)) # Nx13
575 |
576 | point_bbox = np.concatenate(point_bbox_list, 0) # KxNx13
577 | room_xyz_min = np.amin(point_bbox[:, 0:3], axis=0)
578 | point_bbox[:, 0:3] -= room_xyz_min
579 |
580 | if file_format == 'txt':
581 | fout = open(out_filename, 'w')
582 | for i in range(point_bbox.shape[0]):
583 | fout.write('%f %f %f %d %d %d %d %f %f %f %f %f %f\n' % \
584 | (point_bbox[i,0], point_bbox[i,1], point_bbox[i,2],
585 | point_bbox[i,3], point_bbox[i,4], point_bbox[i,5],
586 | point_bbox[i,6],
587 | point_bbox[i,7], point_bbox[i,8], point_bbox[i,9],
588 | point_bbox[i,10], point_bbox[i,11], point_bbox[i,12]))
589 |
590 | fout.close()
591 | elif file_format == 'numpy':
592 | np.save(out_filename, point_bbox)
593 | else:
594 | print('ERROR!! Unknown file format: %s, please use txt or numpy.' % \
595 | (file_format))
596 | exit()
597 |
598 |
599 |
--------------------------------------------------------------------------------
/data_utils/meta/anno_paths.txt:
--------------------------------------------------------------------------------
1 | Area_1/conferenceRoom_1/Annotations
2 | Area_1/conferenceRoom_2/Annotations
3 | Area_1/copyRoom_1/Annotations
4 | Area_1/hallway_1/Annotations
5 | Area_1/hallway_2/Annotations
6 | Area_1/hallway_3/Annotations
7 | Area_1/hallway_4/Annotations
8 | Area_1/hallway_5/Annotations
9 | Area_1/hallway_6/Annotations
10 | Area_1/hallway_7/Annotations
11 | Area_1/hallway_8/Annotations
12 | Area_1/office_10/Annotations
13 | Area_1/office_11/Annotations
14 | Area_1/office_12/Annotations
15 | Area_1/office_13/Annotations
16 | Area_1/office_14/Annotations
17 | Area_1/office_15/Annotations
18 | Area_1/office_16/Annotations
19 | Area_1/office_17/Annotations
20 | Area_1/office_18/Annotations
21 | Area_1/office_19/Annotations
22 | Area_1/office_1/Annotations
23 | Area_1/office_20/Annotations
24 | Area_1/office_21/Annotations
25 | Area_1/office_22/Annotations
26 | Area_1/office_23/Annotations
27 | Area_1/office_24/Annotations
28 | Area_1/office_25/Annotations
29 | Area_1/office_26/Annotations
30 | Area_1/office_27/Annotations
31 | Area_1/office_28/Annotations
32 | Area_1/office_29/Annotations
33 | Area_1/office_2/Annotations
34 | Area_1/office_30/Annotations
35 | Area_1/office_31/Annotations
36 | Area_1/office_3/Annotations
37 | Area_1/office_4/Annotations
38 | Area_1/office_5/Annotations
39 | Area_1/office_6/Annotations
40 | Area_1/office_7/Annotations
41 | Area_1/office_8/Annotations
42 | Area_1/office_9/Annotations
43 | Area_1/pantry_1/Annotations
44 | Area_1/WC_1/Annotations
45 | Area_2/auditorium_1/Annotations
46 | Area_2/auditorium_2/Annotations
47 | Area_2/conferenceRoom_1/Annotations
48 | Area_2/hallway_10/Annotations
49 | Area_2/hallway_11/Annotations
50 | Area_2/hallway_12/Annotations
51 | Area_2/hallway_1/Annotations
52 | Area_2/hallway_2/Annotations
53 | Area_2/hallway_3/Annotations
54 | Area_2/hallway_4/Annotations
55 | Area_2/hallway_5/Annotations
56 | Area_2/hallway_6/Annotations
57 | Area_2/hallway_7/Annotations
58 | Area_2/hallway_8/Annotations
59 | Area_2/hallway_9/Annotations
60 | Area_2/office_10/Annotations
61 | Area_2/office_11/Annotations
62 | Area_2/office_12/Annotations
63 | Area_2/office_13/Annotations
64 | Area_2/office_14/Annotations
65 | Area_2/office_1/Annotations
66 | Area_2/office_2/Annotations
67 | Area_2/office_3/Annotations
68 | Area_2/office_4/Annotations
69 | Area_2/office_5/Annotations
70 | Area_2/office_6/Annotations
71 | Area_2/office_7/Annotations
72 | Area_2/office_8/Annotations
73 | Area_2/office_9/Annotations
74 | Area_2/storage_1/Annotations
75 | Area_2/storage_2/Annotations
76 | Area_2/storage_3/Annotations
77 | Area_2/storage_4/Annotations
78 | Area_2/storage_5/Annotations
79 | Area_2/storage_6/Annotations
80 | Area_2/storage_7/Annotations
81 | Area_2/storage_8/Annotations
82 | Area_2/storage_9/Annotations
83 | Area_2/WC_1/Annotations
84 | Area_2/WC_2/Annotations
85 | Area_3/conferenceRoom_1/Annotations
86 | Area_3/hallway_1/Annotations
87 | Area_3/hallway_2/Annotations
88 | Area_3/hallway_3/Annotations
89 | Area_3/hallway_4/Annotations
90 | Area_3/hallway_5/Annotations
91 | Area_3/hallway_6/Annotations
92 | Area_3/lounge_1/Annotations
93 | Area_3/lounge_2/Annotations
94 | Area_3/office_10/Annotations
95 | Area_3/office_1/Annotations
96 | Area_3/office_2/Annotations
97 | Area_3/office_3/Annotations
98 | Area_3/office_4/Annotations
99 | Area_3/office_5/Annotations
100 | Area_3/office_6/Annotations
101 | Area_3/office_7/Annotations
102 | Area_3/office_8/Annotations
103 | Area_3/office_9/Annotations
104 | Area_3/storage_1/Annotations
105 | Area_3/storage_2/Annotations
106 | Area_3/WC_1/Annotations
107 | Area_3/WC_2/Annotations
108 | Area_4/conferenceRoom_1/Annotations
109 | Area_4/conferenceRoom_2/Annotations
110 | Area_4/conferenceRoom_3/Annotations
111 | Area_4/hallway_10/Annotations
112 | Area_4/hallway_11/Annotations
113 | Area_4/hallway_12/Annotations
114 | Area_4/hallway_13/Annotations
115 | Area_4/hallway_14/Annotations
116 | Area_4/hallway_1/Annotations
117 | Area_4/hallway_2/Annotations
118 | Area_4/hallway_3/Annotations
119 | Area_4/hallway_4/Annotations
120 | Area_4/hallway_5/Annotations
121 | Area_4/hallway_6/Annotations
122 | Area_4/hallway_7/Annotations
123 | Area_4/hallway_8/Annotations
124 | Area_4/hallway_9/Annotations
125 | Area_4/lobby_1/Annotations
126 | Area_4/lobby_2/Annotations
127 | Area_4/office_10/Annotations
128 | Area_4/office_11/Annotations
129 | Area_4/office_12/Annotations
130 | Area_4/office_13/Annotations
131 | Area_4/office_14/Annotations
132 | Area_4/office_15/Annotations
133 | Area_4/office_16/Annotations
134 | Area_4/office_17/Annotations
135 | Area_4/office_18/Annotations
136 | Area_4/office_19/Annotations
137 | Area_4/office_1/Annotations
138 | Area_4/office_20/Annotations
139 | Area_4/office_21/Annotations
140 | Area_4/office_22/Annotations
141 | Area_4/office_2/Annotations
142 | Area_4/office_3/Annotations
143 | Area_4/office_4/Annotations
144 | Area_4/office_5/Annotations
145 | Area_4/office_6/Annotations
146 | Area_4/office_7/Annotations
147 | Area_4/office_8/Annotations
148 | Area_4/office_9/Annotations
149 | Area_4/storage_1/Annotations
150 | Area_4/storage_2/Annotations
151 | Area_4/storage_3/Annotations
152 | Area_4/storage_4/Annotations
153 | Area_4/WC_1/Annotations
154 | Area_4/WC_2/Annotations
155 | Area_4/WC_3/Annotations
156 | Area_4/WC_4/Annotations
157 | Area_5/conferenceRoom_1/Annotations
158 | Area_5/conferenceRoom_2/Annotations
159 | Area_5/conferenceRoom_3/Annotations
160 | Area_5/hallway_10/Annotations
161 | Area_5/hallway_11/Annotations
162 | Area_5/hallway_12/Annotations
163 | Area_5/hallway_13/Annotations
164 | Area_5/hallway_14/Annotations
165 | Area_5/hallway_15/Annotations
166 | Area_5/hallway_1/Annotations
167 | Area_5/hallway_2/Annotations
168 | Area_5/hallway_3/Annotations
169 | Area_5/hallway_4/Annotations
170 | Area_5/hallway_5/Annotations
171 | Area_5/hallway_6/Annotations
172 | Area_5/hallway_7/Annotations
173 | Area_5/hallway_8/Annotations
174 | Area_5/hallway_9/Annotations
175 | Area_5/lobby_1/Annotations
176 | Area_5/office_10/Annotations
177 | Area_5/office_11/Annotations
178 | Area_5/office_12/Annotations
179 | Area_5/office_13/Annotations
180 | Area_5/office_14/Annotations
181 | Area_5/office_15/Annotations
182 | Area_5/office_16/Annotations
183 | Area_5/office_17/Annotations
184 | Area_5/office_18/Annotations
185 | Area_5/office_19/Annotations
186 | Area_5/office_1/Annotations
187 | Area_5/office_20/Annotations
188 | Area_5/office_21/Annotations
189 | Area_5/office_22/Annotations
190 | Area_5/office_23/Annotations
191 | Area_5/office_24/Annotations
192 | Area_5/office_25/Annotations
193 | Area_5/office_26/Annotations
194 | Area_5/office_27/Annotations
195 | Area_5/office_28/Annotations
196 | Area_5/office_29/Annotations
197 | Area_5/office_2/Annotations
198 | Area_5/office_30/Annotations
199 | Area_5/office_31/Annotations
200 | Area_5/office_32/Annotations
201 | Area_5/office_33/Annotations
202 | Area_5/office_34/Annotations
203 | Area_5/office_35/Annotations
204 | Area_5/office_36/Annotations
205 | Area_5/office_37/Annotations
206 | Area_5/office_38/Annotations
207 | Area_5/office_39/Annotations
208 | Area_5/office_3/Annotations
209 | Area_5/office_40/Annotations
210 | Area_5/office_41/Annotations
211 | Area_5/office_42/Annotations
212 | Area_5/office_4/Annotations
213 | Area_5/office_5/Annotations
214 | Area_5/office_6/Annotations
215 | Area_5/office_7/Annotations
216 | Area_5/office_8/Annotations
217 | Area_5/office_9/Annotations
218 | Area_5/pantry_1/Annotations
219 | Area_5/storage_1/Annotations
220 | Area_5/storage_2/Annotations
221 | Area_5/storage_3/Annotations
222 | Area_5/storage_4/Annotations
223 | Area_5/WC_1/Annotations
224 | Area_5/WC_2/Annotations
225 | Area_6/conferenceRoom_1/Annotations
226 | Area_6/copyRoom_1/Annotations
227 | Area_6/hallway_1/Annotations
228 | Area_6/hallway_2/Annotations
229 | Area_6/hallway_3/Annotations
230 | Area_6/hallway_4/Annotations
231 | Area_6/hallway_5/Annotations
232 | Area_6/hallway_6/Annotations
233 | Area_6/lounge_1/Annotations
234 | Area_6/office_10/Annotations
235 | Area_6/office_11/Annotations
236 | Area_6/office_12/Annotations
237 | Area_6/office_13/Annotations
238 | Area_6/office_14/Annotations
239 | Area_6/office_15/Annotations
240 | Area_6/office_16/Annotations
241 | Area_6/office_17/Annotations
242 | Area_6/office_18/Annotations
243 | Area_6/office_19/Annotations
244 | Area_6/office_1/Annotations
245 | Area_6/office_20/Annotations
246 | Area_6/office_21/Annotations
247 | Area_6/office_22/Annotations
248 | Area_6/office_23/Annotations
249 | Area_6/office_24/Annotations
250 | Area_6/office_25/Annotations
251 | Area_6/office_26/Annotations
252 | Area_6/office_27/Annotations
253 | Area_6/office_28/Annotations
254 | Area_6/office_29/Annotations
255 | Area_6/office_2/Annotations
256 | Area_6/office_30/Annotations
257 | Area_6/office_31/Annotations
258 | Area_6/office_32/Annotations
259 | Area_6/office_33/Annotations
260 | Area_6/office_34/Annotations
261 | Area_6/office_35/Annotations
262 | Area_6/office_36/Annotations
263 | Area_6/office_37/Annotations
264 | Area_6/office_3/Annotations
265 | Area_6/office_4/Annotations
266 | Area_6/office_5/Annotations
267 | Area_6/office_6/Annotations
268 | Area_6/office_7/Annotations
269 | Area_6/office_8/Annotations
270 | Area_6/office_9/Annotations
271 | Area_6/openspace_1/Annotations
272 | Area_6/pantry_1/Annotations
273 |
--------------------------------------------------------------------------------
/data_utils/meta/class_names.txt:
--------------------------------------------------------------------------------
1 | ceiling
2 | floor
3 | wall
4 | beam
5 | column
6 | window
7 | door
8 | table
9 | chair
10 | sofa
11 | bookcase
12 | board
13 | clutter
14 |
--------------------------------------------------------------------------------
/image/plane1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/plane1.png
--------------------------------------------------------------------------------
/image/plane2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/plane2.png
--------------------------------------------------------------------------------
/image/psn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/image/psn.png
--------------------------------------------------------------------------------
/models/PointSamplingNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch import Tensor
6 | from typing import List, Tuple, Optional
7 |
8 |
9 | class PointSamplingNet(nn.Module):
10 | """
11 | Point Sampling Net PyTorch Module.
12 |
13 | Attributes:
14 | num_to_sample: the number to sample, int
15 | max_local_num: the max number of local area, int
16 | mlp: the channels of feature transform function, List[int]
17 |         global_feature: whether to enable the global feature, bool
18 | """
19 |
20 | def __init__(self, num_to_sample: int = 512, max_local_num: int = 32, mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
21 | """
22 | Initialization of Point Sampling Net.
23 | """
24 | super(PointSamplingNet, self).__init__()
25 |
26 | self.mlp_convs = nn.ModuleList()
27 | self.mlp_bns = nn.ModuleList()
28 |
29 |         assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"
30 |
31 | self.mlp_convs.append(
32 | nn.Conv1d(in_channels=6, out_channels=mlp[0], kernel_size=1))
33 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))
34 |
35 | for i in range(len(mlp)-1):
36 | self.mlp_convs.append(
37 | nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i+1], kernel_size=1))
38 |
39 | for i in range(len(mlp)-1):
40 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i+1]))
41 |
42 | self.global_feature = global_feature
43 |
44 | if self.global_feature:
45 | self.mlp_convs.append(
46 | nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
47 | else:
48 | self.mlp_convs.append(
49 | nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))
50 |
51 | self.s = num_to_sample
52 | self.n = max_local_num
53 |
54 |     def forward(self, coordinate: Tensor, feature: Tensor, train: bool = False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
55 |         """
56 |         Forward propagation of Point Sampling Net
57 | 
58 |         Args:
59 |             coordinate: input points position data, [B, m, 3]
60 |             feature: input points feature, [B, m, d]
61 |         Returns:
62 |             sampled_points [B, s, 3] and grouped_points [B, s, n, 3]
63 |             sampled_feature [B, s, d] and grouped_feature [B, s, n, d] (None if feature is None)
64 |         """
65 | _, m, _ = coordinate.size()
66 |
67 |         assert self.s < m, "The number to sample must be less than the number of input points!"
68 |
69 |         r = torch.sqrt(torch.pow(coordinate[:,:,0],2)+torch.pow(coordinate[:,:,1],2)+torch.pow(coordinate[:,:,2],2))
70 |         th = torch.acos(coordinate[:,:,2] / r)
71 |         fi = torch.atan2(coordinate[:,:,1], coordinate[:,:,0])
72 | 
73 |         # Append spherical coordinates (r, theta, phi) to match the 6-channel first Conv1d; the original 3-D coordinate is kept for grouping below
74 |         spherical = torch.cat([coordinate, r.unsqueeze(2), th.unsqueeze(2), fi.unsqueeze(2)], -1)  # [B, m, 6]
75 |         x = spherical.transpose(2, 1) # Channel First
76 |
77 | for i in range(len(self.mlp_convs) - 1):
78 | x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))
79 |
80 | if self.global_feature:
81 | max_feature = torch.max(x, 2, keepdim=True)[0]
82 | max_feature = max_feature.repeat(1, 1, m) # [B, mlp[-1], m]
83 | x = torch.cat([x, max_feature], 1) # [B, mlp[-1] * 2, m]
84 |
85 | x = self.mlp_convs[-1](x) # [B,s,m]
86 |
87 | Q = torch.sigmoid(x) # [B, s, m]
88 |
89 | _, indices = torch.sort(input=Q, dim=2, descending=True) # [B, s, m]
90 | grouped_indices = indices[:,:,:self.n]
91 | grouped_points = index_points(coordinate, grouped_indices)[:,:,:self.n,:] #[B,s,n,3]
92 | if feature is not None:
93 | grouped_feature = index_points(feature, grouped_indices)[:,:,:self.n,:] #[B,s,n,d]
94 | if not train:
95 | sampled_points = grouped_points[:,:,0,:] # [B,s,3]
96 | sampled_feature = grouped_feature[:,:,0,:] #[B,s,d]
97 | else:
98 | Q = gumbel_softmax_sample(Q) # [B, s, m]
99 | sampled_points = torch.matmul(Q, coordinate) # [B,s,3]
100 | sampled_feature = torch.matmul(Q, feature) # [B,s,d]
101 | grouped_feature[:,:,0,:] = sampled_feature
102 | else:
103 | if not train:
104 | sampled_points = grouped_points[:,:,0,:] # [B,s,3]
105 | sampled_feature = None #[B,s,d]
106 | else:
107 | Q = gumbel_softmax_sample(Q) # [B, s, m]
108 | sampled_points = torch.matmul(Q, coordinate) # [B,s,3]
109 | sampled_feature = None
110 | grouped_feature = None
111 |
112 | return sampled_points, grouped_points, sampled_feature, grouped_feature
113 |
114 | class PointSamplingNetRadius(nn.Module):
115 | """
116 | Point Sampling Net with heuristic condition PyTorch Module.
117 | This example is radius query
118 | You may replace function C(x) by your own function
119 |
120 | Attributes:
121 | num_to_sample: the number to sample, int
122 | radius: radius to query, float
123 | max_local_num: the max number of local area, int
124 | mlp: the channels of feature transform function, List[int]
125 |         global_feature: whether to enable the global feature, bool
126 | """
127 |
128 | def __init__(self, num_to_sample: int = 512, radius: float = 1.0, max_local_num: int = 32, mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
129 | """
130 | Initialization of Point Sampling Net.
131 | """
132 | super(PointSamplingNetRadius, self).__init__()
133 |
134 | self.mlp_convs = nn.ModuleList()
135 | self.mlp_bns = nn.ModuleList()
136 | self.radius = radius
137 |
138 |         assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"
139 |
140 | self.mlp_convs.append(
141 | nn.Conv1d(in_channels=3, out_channels=mlp[0], kernel_size=1))
142 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))
143 |
144 | for i in range(len(mlp)-1):
145 | self.mlp_convs.append(
146 | nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i+1], kernel_size=1))
147 |
148 | for i in range(len(mlp)-1):
149 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i+1]))
150 |
151 | self.global_feature = global_feature
152 |
153 | if self.global_feature:
154 | self.mlp_convs.append(
155 | nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
156 | else:
157 | self.mlp_convs.append(
158 | nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))
159 |
160 | self.softmax = nn.Softmax(dim=1)
161 |
162 | self.s = num_to_sample
163 | self.n = max_local_num
164 |
165 | def forward(self, coordinate: Tensor) -> Tuple[Tensor, Tensor]:
166 | """
167 | Forward propagation of Point Sampling Net
168 |
169 | Args:
170 | coordinate: input points position data, [B, m, 3]
171 | Returns:
172 | sampled indices: the indices of sampled points, [B, s]
173 | grouped_indices: the indices of grouped points, [B, s, n]
174 | """
175 | _, m, _ = coordinate.size()
176 |
177 |         assert self.s < m, "The number to sample must be less than the number of input points!"
178 |
179 | x = coordinate.transpose(2, 1) # Channel First
180 |
181 | for i in range(len(self.mlp_convs) - 1):
182 | x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))
183 |
184 | if self.global_feature:
185 | max_feature = torch.max(x, 2, keepdim=True)[0]
186 | max_feature = max_feature.repeat(1, 1, m) # [B, mlp[-1], m]
187 | x = torch.cat([x, max_feature], 1) # [B, mlp[-1] * 2, m]
188 |
189 | x = self.mlp_convs[-1](x) # [B,s,m]
190 |
191 | Q = self.softmax(x) # [B, s, m]
192 |
193 | _, indices = torch.sort(input=Q, dim=2, descending=True) # [B, s, m]
194 |
195 | grouped_indices = indices[:, :, 0:self.n] # [B, s, n]
196 |
197 | sampled_indices = indices[:, :, 0] # [B, s]
198 |
199 | # function C(x)
200 | # you may replace C(x) by your heuristic condition
201 | sampled_coordinate = torch.unsqueeze(index_points(coordinate, sampled_indices), dim=2) # [B, s, 1, 3]
202 |         grouped_coordinate = index_points(coordinate, grouped_indices) # [B, s, n, 3]
203 |
204 | diff = grouped_coordinate - sampled_coordinate
205 | diff = diff ** 2
206 |         diff = torch.sum(diff, dim=3) #[B, s, n]
207 | mask = diff > self.radius ** 2
208 |
209 | sampled_indices_expand = torch.unsqueeze(sampled_indices, dim=2).repeat(1, 1, self.n) #[B, s, n]
210 | grouped_indices[mask] = sampled_indices_expand[mask]
211 | # function C(x) end
212 |
213 | return sampled_indices, grouped_indices
214 |
215 | class PointSamplingNetMSG(nn.Module):
216 | """
217 | Point Sampling Net with Multi-scale Grouping PyTorch Module.
218 |
219 | Attributes:
220 | num_to_sample: the number to sample, int
221 | msg_n: the list of mutil-scale grouping n values, List[int]
222 | mlp: the channels of feature transform function, List[int]
223 |         global_feature: whether to enable the global feature, bool
224 | """
225 |
226 | def __init__(self, num_to_sample: int = 512, msg_n: List[int] = [32, 64], mlp: List[int] = [32, 64, 256], global_feature: bool = False) -> None:
227 | """
228 | Initialization of Point Sampling Net.
229 | """
230 | super(PointSamplingNetMSG, self).__init__()
231 |
232 | self.mlp_convs = nn.ModuleList()
233 | self.mlp_bns = nn.ModuleList()
234 |
235 |         assert len(mlp) > 1, "The number of MLP layers must be greater than 1!"
236 |
237 | self.mlp_convs.append(
238 | nn.Conv1d(in_channels=3, out_channels=mlp[0], kernel_size=1))
239 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[0]))
240 |
241 | for i in range(len(mlp)-1):
242 | self.mlp_convs.append(
243 | nn.Conv1d(in_channels=mlp[i], out_channels=mlp[i+1], kernel_size=1))
244 |
245 | for i in range(len(mlp)-1):
246 | self.mlp_bns.append(nn.BatchNorm1d(num_features=mlp[i+1]))
247 |
248 | self.global_feature = global_feature
249 |
250 | if self.global_feature:
251 | self.mlp_convs.append(
252 | nn.Conv1d(in_channels=mlp[-1] * 2, out_channels=num_to_sample, kernel_size=1))
253 | else:
254 | self.mlp_convs.append(
255 | nn.Conv1d(in_channels=mlp[-1], out_channels=num_to_sample, kernel_size=1))
256 |
257 | self.softmax = nn.Softmax(dim=1)
258 |
259 | self.s = num_to_sample
260 | self.msg_n = msg_n
261 |
262 |     def forward(self, coordinate: Tensor, feature: Tensor, train: bool = False) -> Tuple[Tensor, List[Tensor], Optional[Tensor], Optional[List[Tensor]]]:
263 |         """
264 |         Forward propagation of Point Sampling Net
265 | 
266 |         Args:
267 |             coordinate: input points position data, [B, m, 3]; feature: input points feature, [B, m, d]
268 |         Returns:
269 |             sampled_points [B, s, 3] and grouped_points_msg: multi-scale grouped points, List of [B, s, n_i, 3]
270 |             sampled_feature [B, s, d] and grouped_feature_msg: multi-scale grouped feature, List of [B, s, n_i, d]
271 |         """
272 | _, m, _ = coordinate.size()
273 |
274 |         assert self.s < m, "The number to sample must be less than the number of input points!"
275 |
276 | x = coordinate.transpose(2, 1) # Channel First
277 |
278 | for i in range(len(self.mlp_convs) - 1):
279 | x = F.relu(self.mlp_bns[i](self.mlp_convs[i](x)))
280 |
281 | if self.global_feature:
282 | max_feature = torch.max(x, 2, keepdim=True)[0]
283 | max_feature = max_feature.repeat(1, 1, m) # [B, mlp[-1], m]
284 | x = torch.cat([x, max_feature], 1) # [B, mlp[-1] * 2, m]
285 |
286 | x = self.mlp_convs[-1](x) # [B,s,m]
287 |
288 |
289 | Q = torch.sigmoid(x) # [B, s, m]
290 |
291 | _, indices = torch.sort(input=Q, dim=2, descending=True) # [B, s, m]
292 |         grouped_indices = indices[:,:,:max(self.msg_n)]
293 | grouped_points_msg = []
294 | for n in self.msg_n:
295 | grouped_points_msg.append(index_points(coordinate, grouped_indices)[:,:,:n,:])
296 | if feature is not None:
297 | grouped_feature_msg = []
298 | for n in self.msg_n:
299 | grouped_feature_msg.append(index_points(feature, grouped_indices)[:,:,:n,:])
300 | if not train:
301 | sampled_points = grouped_points_msg[0][:,:,0,:] # [B,s,3]
302 | sampled_feature = grouped_feature_msg[-1][:,:,0,:] #[B,s,d]
303 | else:
304 | Q = gumbel_softmax_sample(Q) # [B, s, m]
305 | sampled_points = torch.matmul(Q, coordinate) # [B,s,3]
306 | sampled_feature = torch.matmul(Q, feature) # [B,s,d]
307 |                 for i in range(len(self.msg_n)):
308 |                     grouped_feature_msg[i][:,:,0,:] = sampled_feature
309 | else:
310 | if not train:
311 | sampled_points = grouped_points_msg[0][:,:,0,:] # [B,s,3]
312 | sampled_feature = None #[B,s,d]
313 | else:
314 | Q = gumbel_softmax_sample(Q) # [B, s, m]
315 | sampled_points = torch.matmul(Q, coordinate) # [B,s,3]
316 | sampled_feature = None
317 | grouped_feature_msg = None
318 |
319 | return sampled_points, grouped_points_msg, sampled_feature, grouped_feature_msg
320 |
321 |
322 | def index_points(points, idx):
323 | """
324 |
325 | Input:
326 | points: input points data, [B, N, C]
327 | idx: sample index data, [B, S]
328 | Return:
329 | new_points:, indexed points data, [B, S, C]
330 | """
331 | device = points.device
332 | B = points.shape[0]
333 | view_shape = list(idx.shape)
334 | view_shape[1:] = [1] * (len(view_shape) - 1)
335 | repeat_shape = list(idx.shape)
336 | repeat_shape[0] = 1
337 | batch_indices = torch.arange(B, dtype=torch.long).to(
338 | device).view(view_shape).repeat(repeat_shape)
339 | new_points = points[batch_indices, idx, :]
340 | return new_points
341 |
342 | def sample_gumbel(shape, eps=1e-20):
343 | U = torch.rand(shape)
344 |     U = U.cuda()  # note: assumes a CUDA device is available
345 | return -torch.log(-torch.log(U + eps) + eps)
346 |
347 |
348 | def gumbel_softmax_sample(logits, dim=-1, temperature=0.001):
349 | y = logits + sample_gumbel(logits.size())
350 | return F.softmax(y / temperature, dim=dim)
351 |
352 | def gumbel_softmax(logits, temperature=1.0, hard=False):
353 | """Sample from the Gumbel-Softmax distribution and optionally discretize.
354 | Args:
355 | logits: [batch_size, n_class] unnormalized log-probs
356 | temperature: non-negative scalar
357 | hard: if True, take argmax, but differentiate w.r.t. soft sample y
358 | Returns:
359 | [batch_size, n_class] sample from the Gumbel-Softmax distribution.
360 | If hard=True, then the returned sample will be one-hot, otherwise it will
361 | be a probabilitiy distribution that sums to 1 across classes
362 | """
363 |     y = gumbel_softmax_sample(logits, temperature=temperature)
364 | if hard:
365 | y_hard = onehot_from_logits(y)
366 | #print(y_hard[0], "random")
367 | y = (y_hard - y).detach() + y
368 | return y
369 |
370 | def onehot_from_logits(logits, eps=0.0):
371 | """
372 | Given batch of logits, return one-hot sample using epsilon greedy strategy
373 | (based on given epsilon)
374 | """
375 | # get best (according to current policy) actions in one-hot form
376 | argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
377 | #print(logits[0],"a")
378 | #print(len(argmax_acs),argmax_acs[0])
379 | if eps == 0.0:
380 | return argmax_acs
381 |
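
A minimal usage sketch of the sampler above (tensor sizes and the CUDA device are illustrative assumptions; the repository root is assumed to be on PYTHONPATH). PSN scores every input point for each of the s output positions, sorts the scores, keeps the top-n points of each position as its local group, and, during training, replaces the hard top-1 pick with a Gumbel-Softmax mixture so the selection stays differentiable.

    import torch
    from models.PointSamplingNet import PointSamplingNet

    psn = PointSamplingNet(num_to_sample=512, max_local_num=32,
                           mlp=[64, 128, 256], global_feature=True).cuda()

    xyz = torch.rand(8, 1024, 3).cuda()    # B=8 clouds, m=1024 points each
    feat = torch.rand(8, 1024, 6).cuda()   # d=6 per-point features

    # Inference path: hard top-n grouping, the group's first point is the sampled point.
    s_pts, g_pts, s_feat, g_feat = psn(xyz, feat, train=False)
    # s_pts: [8, 512, 3], g_pts: [8, 512, 32, 3], s_feat: [8, 512, 6], g_feat: [8, 512, 32, 6]

    # Training path: sampled points become a soft Gumbel-Softmax combination of all inputs.
    s_pts, g_pts, s_feat, g_feat = psn(xyz, feat, train=True)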
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/psn-anonymous/PointSamplingNet/42188d8128662aa7a03dcf590743d5e6f2eb0457/models/__init__.py
--------------------------------------------------------------------------------
/models/pointnet2_cls_ssg_psn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from models.pointnet_util_psn import PointNetSetAbstraction
4 |
5 |
6 | class get_model(nn.Module):
7 | def __init__(self,num_class,normal_channel=True):
8 | super(get_model, self).__init__()
9 | in_channel = 6 if normal_channel else 3
10 | self.normal_channel = normal_channel
11 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
12 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
13 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
14 | self.fc1 = nn.Linear(1024, 512)
15 | self.bn1 = nn.BatchNorm1d(512)
16 | self.drop1 = nn.Dropout(0.4)
17 | self.fc2 = nn.Linear(512, 256)
18 | self.bn2 = nn.BatchNorm1d(256)
19 | self.drop2 = nn.Dropout(0.4)
20 | self.fc3 = nn.Linear(256, num_class)
21 |
22 | def forward(self, xyz, train):
23 | B, _, _ = xyz.shape
24 | if self.normal_channel:
25 | norm = xyz[:, 3:, :]
26 | xyz = xyz[:, :3, :]
27 | else:
28 | norm = None
29 | l1_xyz, l1_points = self.sa1(xyz, norm, train)
30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
32 | x = l3_points.view(B, 1024)
33 | x = self.drop1(F.relu(self.bn1(self.fc1(x))))
34 | x = self.drop2(F.relu(self.bn2(self.fc2(x))))
35 | x = self.fc3(x)
36 | x = F.log_softmax(x, -1)
37 |
38 |
39 | return x, l3_points
40 |
41 |
42 |
43 | class get_loss(nn.Module):
44 | def __init__(self):
45 | super(get_loss, self).__init__()
46 |
47 | def forward(self, pred, target, trans_feat):
48 | total_loss = F.nll_loss(pred, target)
49 |
50 | return total_loss
51 |
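
A short sketch of driving the classifier above (random tensors, CUDA assumed): the input is channel-first, xyz plus normals when normal_channel=True, and the extra boolean is forwarded to PSN as its train flag.

    import torch
    from models.pointnet2_cls_ssg_psn import get_model, get_loss

    model = get_model(num_class=40, normal_channel=True).cuda()
    points = torch.rand(8, 6, 1024).cuda()       # [B, 3+3, N]: xyz and normals, channel first
    target = torch.randint(0, 40, (8,)).cuda()   # made-up class labels

    log_probs, _ = model(points, True)           # True -> Gumbel-Softmax sampling inside PSN
    loss = get_loss()(log_probs, target, None)   # trans_feat argument is unused by this loss
    loss.backward()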
--------------------------------------------------------------------------------
/models/pointnet2_part_seg_ssg_psn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from models.pointnet_util_psn import PointNetSetAbstraction,PointNetFeaturePropagation
5 |
6 |
7 | class get_model(nn.Module):
8 | def __init__(self, num_classes, normal_channel=False):
9 | super(get_model, self).__init__()
10 | if normal_channel:
11 | additional_channel = 3
12 | else:
13 | additional_channel = 0
14 | self.normal_channel = normal_channel
15 |         self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=6+additional_channel, mlp=[64, 64, 128], group_all=False) # modified here: nsample was originally 32
16 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
17 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
18 | self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
19 | self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
20 | self.fp1 = PointNetFeaturePropagation(in_channel=128+16+6+additional_channel, mlp=[128, 128, 128])
21 | self.conv1 = nn.Conv1d(128, 128, 1)
22 | self.bn1 = nn.BatchNorm1d(128)
23 | self.drop1 = nn.Dropout(0.5)
24 | self.conv2 = nn.Conv1d(128, num_classes, 1)
25 |
26 | def forward(self, xyz, cls_label, train):
27 | # Set Abstraction layers
28 | B,C,N = xyz.shape
29 | if self.normal_channel:
30 | l0_points = xyz
31 | l0_xyz = xyz[:,:3,:]
32 | else:
33 | l0_points = xyz
34 | l0_xyz = xyz
35 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points, train)
36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
37 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
38 | # Feature Propagation layers
39 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
40 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
41 | cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)
42 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
43 | # FC layers
44 | feat = F.relu(self.bn1(self.conv1(l0_points)))
45 | x = self.drop1(feat)
46 | x = self.conv2(x)
47 | x = F.log_softmax(x, dim=1)
48 | x = x.permute(0, 2, 1)
49 | # diff = diff1 + diff2 + diff3
50 | return x, l3_points
51 |
52 |
53 | class get_loss(nn.Module):
54 | def __init__(self):
55 | super(get_loss, self).__init__()
56 |
57 | def forward(self, pred, target, trans_feat):
58 | total_loss = F.nll_loss(pred, target)
59 |
60 | return total_loss
--------------------------------------------------------------------------------
/models/pointnet2_sem_seg_psn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from models.pointnet_util_psn import PointNetSetAbstraction,PointNetFeaturePropagation
4 |
5 |
6 | class get_model(nn.Module):
7 | def __init__(self, num_classes):
8 | super(get_model, self).__init__()
9 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
10 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
11 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
12 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
13 | self.fp4 = PointNetFeaturePropagation(768, [256, 256])
14 | self.fp3 = PointNetFeaturePropagation(384, [256, 256])
15 | self.fp2 = PointNetFeaturePropagation(320, [256, 128])
16 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
17 | self.conv1 = nn.Conv1d(128, 128, 1)
18 | self.bn1 = nn.BatchNorm1d(128)
19 | self.drop1 = nn.Dropout(0.5)
20 | self.conv2 = nn.Conv1d(128, num_classes, 1)
21 |
22 | def forward(self, xyz, train):
23 | l0_points = xyz
24 | l0_xyz = xyz[:,:3,:]
25 |
26 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points, train)
27 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points, train)
28 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points, train)
29 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points, train)
30 |
31 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
32 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
33 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
34 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
35 |
36 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
37 | x = self.conv2(x)
38 | x = F.log_softmax(x, dim=1)
39 | x = x.permute(0, 2, 1)
40 | return x, l4_points
41 |
42 |
43 | class get_loss(nn.Module):
44 | def __init__(self):
45 | super(get_loss, self).__init__()
46 | def forward(self, pred, target, trans_feat, weight):
47 | total_loss = F.nll_loss(pred, target, weight=weight)
48 |
49 | return total_loss
50 |
51 | if __name__ == '__main__':
52 | import torch
53 | model = get_model(13)
54 | xyz = torch.rand(6, 9, 2048)
55 |     model(xyz, False)
--------------------------------------------------------------------------------
/models/pointnet_util_psn.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | import torch
3 | from torch.functional import chain_matmul
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from time import time
7 | import numpy as np
8 | from models import PointSamplingNet as psn
9 |
10 |
11 | def timeit(tag, t):
12 | print("{}: {}s".format(tag, time() - t))
13 | return time()
14 |
15 |
16 | def pc_normalize(pc):
17 | l = pc.shape[0]
18 | centroid = np.mean(pc, axis=0)
19 | pc = pc - centroid
20 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
21 | pc = pc / m
22 | return pc
23 |
24 |
25 | def square_distance(src, dst):
26 | """
27 | Calculate Euclid distance between each two points.
28 |
29 | src^T * dst = xn * xm + yn * ym + zn * zm;
30 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
31 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
32 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
33 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
34 |
35 | Input:
36 | src: source points, [B, N, C]
37 | dst: target points, [B, M, C]
38 | Output:
39 | dist: per-point square distance, [B, N, M]
40 | """
41 | B, N, _ = src.shape
42 | _, M, _ = dst.shape
43 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
44 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
45 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
46 | return dist
47 |
48 |
49 | def index_points(points, idx):
50 | """
51 |
52 | Input:
53 | points: input points data, [B, N, C]
54 | idx: sample index data, [B, S]
55 | Return:
56 | new_points:, indexed points data, [B, S, C]
57 | """
58 | device = points.device
59 | B = points.shape[0]
60 | view_shape = list(idx.shape)
61 | view_shape[1:] = [1] * (len(view_shape) - 1)
62 | repeat_shape = list(idx.shape)
63 | repeat_shape[0] = 1
64 | batch_indices = torch.arange(B, dtype=torch.long).to(
65 | device).view(view_shape).repeat(repeat_shape)
66 | new_points = points[batch_indices, idx, :]
67 | return new_points
68 |
69 |
70 | def sample_and_group_psn(npoint, sampled_points, grouped_points, sampled_feature, grouped_feature, nsample):
71 | """
72 | Sampling and grouping point cloud with PSN.
73 |
74 |     Input:
75 |         npoint: the number of sampled points, int
76 |         sampled_points: sampled points by PSN, [B, s, 3]
77 |         grouped_points: grouped points by PSN, [B, s, n, 3]
78 |         sampled_feature: sampled feature, [B, s, d]
79 |         grouped_feature: grouped feature, [B, s, n, d]
80 |         nsample: the max number of points in a local area, int
81 |     Output:
82 |         new_xyz: sampled points coordinate, [B, s, 3]
83 |         new_points: grouped points feature, [B, s, n, d+3]
84 |     """
85 | 
86 | B, _, _ = sampled_points.shape
87 | S = npoint
88 | grouped_xyz_norm = grouped_points - sampled_points.view(B, S, 1, 3).repeat([1, 1, nsample, 1]) # [B,s,n,3]
89 |
90 | torch.cuda.empty_cache()
91 | if grouped_feature is not None:
92 | new_points = torch.cat([grouped_xyz_norm, grouped_feature], dim=-1)
93 | else:
94 | new_points = grouped_xyz_norm
95 | return sampled_points, new_points
96 |
97 |
98 | def query_ball_point(radius, nsample, xyz, new_xyz):
99 | """
100 | Input:
101 | radius: local region radius
102 | nsample: max sample number in local region
103 | xyz: all points, [B, N, 3]
104 | new_xyz: query points, [B, S, 3]
105 | Return:
106 | group_idx: grouped points index, [B, S, nsample]
107 | """
108 | device = xyz.device
109 | B, N, C = xyz.shape
110 | _, S, _ = new_xyz.shape
111 | group_idx = torch.arange(N, dtype=torch.long).to(
112 | device).view(1, 1, N).repeat([B, S, 1])
113 | sqrdists = square_distance(new_xyz, xyz)
114 | group_idx[sqrdists > radius ** 2] = N
115 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
116 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
117 | mask = group_idx == N
118 | group_idx[mask] = group_first[mask]
119 | return group_idx
120 |
121 |
122 |
123 | def sample_and_group_all(xyz, points):
124 | """
125 | Input:
126 | xyz: input points position data, [B, N, 3]
127 | points: input points data, [B, N, D]
128 | Return:
129 | new_xyz: sampled points position data, [B, 1, 3]
130 | new_points: sampled points data, [B, 1, N, 3+D]
131 | """
132 | device = xyz.device
133 | B, N, C = xyz.shape
134 | new_xyz = torch.zeros(B, 1, C).to(device)
135 | grouped_xyz = xyz.view(B, 1, N, C)
136 | if points is not None:
137 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
138 | else:
139 | new_points = grouped_xyz
140 | return new_xyz, new_points
141 |
142 |
143 | class PointNetSetAbstraction(nn.Module):
144 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
145 | super(PointNetSetAbstraction, self).__init__()
146 | self.npoint = npoint
147 | self.radius = radius
148 | self.nsample = nsample
149 | self.mlp_convs = nn.ModuleList()
150 | self.mlp_bns = nn.ModuleList()
151 | last_channel = in_channel
152 | for out_channel in mlp:
153 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
154 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
155 | last_channel = out_channel
156 | self.group_all = group_all
157 | if not group_all:
158 | # self.sampling = pointsampling.get_model(npoint)
159 | self.sampling = psn.PointSamplingNet(npoint, nsample, [64,128,256], global_feature=True)
160 |
161 | def forward(self, xyz, points, train):
162 | """
163 | Input:
164 | xyz: input points position data, [B, C, N]
165 | points: input points data, [B, D, N]
166 | Return:
167 | new_xyz: sampled points position data, [B, C, S]
168 | new_points_concat: sample points feature data, [B, D', S]
169 | """
170 | xyz = xyz.permute(0, 2, 1)
171 | if points is not None:
172 | points = points.permute(0, 2, 1)
173 |
174 | if self.group_all:
175 | new_xyz, new_points = sample_and_group_all(xyz, points)
176 | else:
177 | sampled_points, grouped_points, sampled_feature, grouped_feature = self.sampling(xyz,points,train)
178 |             new_xyz, new_points = sample_and_group_psn(
179 |                 self.npoint, sampled_points, grouped_points, sampled_feature, grouped_feature, self.nsample)
180 | # new_xyz: sampled points position data, [B, npoint, C]
181 | # new_points: sampled points data, [B, npoint, nsample, C+D]
182 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
183 | for i, conv in enumerate(self.mlp_convs):
184 | bn = self.mlp_bns[i]
185 | new_points = F.relu(bn(conv(new_points)))
186 |
187 | new_points = torch.max(new_points, 2)[0]
188 | new_xyz = new_xyz.permute(0, 2, 1)
189 | return new_xyz, new_points
190 |
191 |
192 | class PointNetFeaturePropagation(nn.Module):
193 | def __init__(self, in_channel, mlp):
194 | super(PointNetFeaturePropagation, self).__init__()
195 | self.mlp_convs = nn.ModuleList()
196 | self.mlp_bns = nn.ModuleList()
197 | last_channel = in_channel
198 | for out_channel in mlp:
199 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
200 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
201 | last_channel = out_channel
202 |
203 | def forward(self, xyz1, xyz2, points1, points2):
204 | """
205 | Input:
206 | xyz1: input points position data, [B, C, N]
207 | xyz2: sampled input points position data, [B, C, S]
208 | points1: input points data, [B, D, N]
209 | points2: input points data, [B, D, S]
210 | Return:
211 | new_points: upsampled points data, [B, D', N]
212 | """
213 | xyz1 = xyz1.permute(0, 2, 1)
214 | xyz2 = xyz2.permute(0, 2, 1)
215 |
216 | points2 = points2.permute(0, 2, 1)
217 | B, N, C = xyz1.shape
218 | _, S, _ = xyz2.shape
219 |
220 | if S == 1:
221 | interpolated_points = points2.repeat(1, N, 1)
222 | else:
223 | dists = square_distance(xyz1, xyz2)
224 | dists, idx = dists.sort(dim=-1)
225 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
226 |
227 | dist_recip = 1.0 / (dists + 1e-8)
228 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
229 | weight = dist_recip / norm
230 | interpolated_points = torch.sum(index_points(
231 | points2, idx) * weight.view(B, N, 3, 1), dim=2)
232 |
233 | if points1 is not None:
234 | points1 = points1.permute(0, 2, 1)
235 | new_points = torch.cat([points1, interpolated_points], dim=-1)
236 | else:
237 | new_points = interpolated_points
238 |
239 | new_points = new_points.permute(0, 2, 1)
240 | for i, conv in enumerate(self.mlp_convs):
241 | bn = self.mlp_bns[i]
242 | new_points = F.relu(bn(conv(new_points)))
243 | return new_points
244 |
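
A small sketch of a single PSN-backed set abstraction layer (illustrative shapes, CUDA assumed): it takes channel-first coordinates and features, lets PSN pick npoint centers with nsample neighbours each, and max-pools one feature vector per center; the radius argument is kept for interface compatibility but is not used by the learned sampler.

    import torch
    from models.pointnet_util_psn import PointNetSetAbstraction

    sa = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
                                in_channel=3 + 3, mlp=[64, 64, 128], group_all=False).cuda()

    xyz   = torch.rand(8, 3, 1024).cuda()   # [B, 3, N] coordinates
    feats = torch.rand(8, 3, 1024).cuda()   # [B, D, N] per-point features (e.g. normals)

    new_xyz, new_feats = sa(xyz, feats, False)
    # new_xyz: [8, 3, 512], new_feats: [8, 128, 512]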
--------------------------------------------------------------------------------
/provider.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def normalize_data(batch_data):
4 | """ Normalize the batch data, use coordinates of the block centered at origin,
5 | Input:
6 | BxNxC array
7 | Output:
8 | BxNxC array
9 | """
10 | B, N, C = batch_data.shape
11 | normal_data = np.zeros((B, N, C))
12 | for b in range(B):
13 | pc = batch_data[b]
14 | centroid = np.mean(pc, axis=0)
15 | pc = pc - centroid
16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
17 | pc = pc / m
18 | normal_data[b] = pc
19 | return normal_data
20 |
21 |
22 | def shuffle_data(data, labels):
23 | """ Shuffle data and labels.
24 | Input:
25 | data: B,N,... numpy array
26 | label: B,... numpy array
27 | Return:
28 | shuffled data, label and shuffle indices
29 | """
30 | idx = np.arange(len(labels))
31 | np.random.shuffle(idx)
32 | return data[idx, ...], labels[idx], idx
33 |
34 | def shuffle_points(batch_data):
35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior.
36 | Use the same shuffling idx for the entire batch.
37 | Input:
38 | BxNxC array
39 | Output:
40 | BxNxC array
41 | """
42 | idx = np.arange(batch_data.shape[1])
43 | np.random.shuffle(idx)
44 | return batch_data[:,idx,:]
45 |
46 | def rotate_point_cloud(batch_data):
47 |     """ Randomly rotate the point clouds to augment the dataset
48 | rotation is per shape based along up direction
49 | Input:
50 | BxNx3 array, original batch of point clouds
51 | Return:
52 | BxNx3 array, rotated batch of point clouds
53 | """
54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
55 | for k in range(batch_data.shape[0]):
56 | rotation_angle = np.random.uniform() * 2 * np.pi
57 | cosval = np.cos(rotation_angle)
58 | sinval = np.sin(rotation_angle)
59 | rotation_matrix = np.array([[cosval, 0, sinval],
60 | [0, 1, 0],
61 | [-sinval, 0, cosval]])
62 | shape_pc = batch_data[k, ...]
63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
64 | return rotated_data
65 |
66 | def rotate_point_cloud_z(batch_data):
67 |     """ Randomly rotate the point clouds to augment the dataset
68 | rotation is per shape based along up direction
69 | Input:
70 | BxNx3 array, original batch of point clouds
71 | Return:
72 | BxNx3 array, rotated batch of point clouds
73 | """
74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
75 | for k in range(batch_data.shape[0]):
76 | rotation_angle = np.random.uniform() * 2 * np.pi
77 | cosval = np.cos(rotation_angle)
78 | sinval = np.sin(rotation_angle)
79 | rotation_matrix = np.array([[cosval, sinval, 0],
80 | [-sinval, cosval, 0],
81 | [0, 0, 1]])
82 | shape_pc = batch_data[k, ...]
83 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
84 | return rotated_data
85 |
86 | def rotate_point_cloud_with_normal(batch_xyz_normal):
87 | ''' Randomly rotate XYZ, normal point cloud.
88 | Input:
89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
90 | Output:
91 | B,N,6, rotated XYZ, normal point cloud
92 | '''
93 | for k in range(batch_xyz_normal.shape[0]):
94 | rotation_angle = np.random.uniform() * 2 * np.pi
95 | cosval = np.cos(rotation_angle)
96 | sinval = np.sin(rotation_angle)
97 | rotation_matrix = np.array([[cosval, 0, sinval],
98 | [0, 1, 0],
99 | [-sinval, 0, cosval]])
100 | shape_pc = batch_xyz_normal[k,:,0:3]
101 | shape_normal = batch_xyz_normal[k,:,3:6]
102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
104 | return batch_xyz_normal
105 |
106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
107 | """ Randomly perturb the point clouds by small rotations
108 | Input:
109 | BxNx6 array, original batch of point clouds and point normals
110 | Return:
111 | BxNx3 array, rotated batch of point clouds
112 | """
113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
114 | for k in range(batch_data.shape[0]):
115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
116 | Rx = np.array([[1,0,0],
117 | [0,np.cos(angles[0]),-np.sin(angles[0])],
118 | [0,np.sin(angles[0]),np.cos(angles[0])]])
119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
120 | [0,1,0],
121 | [-np.sin(angles[1]),0,np.cos(angles[1])]])
122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
123 | [np.sin(angles[2]),np.cos(angles[2]),0],
124 | [0,0,1]])
125 | R = np.dot(Rz, np.dot(Ry,Rx))
126 | shape_pc = batch_data[k,:,0:3]
127 | shape_normal = batch_data[k,:,3:6]
128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
129 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
130 | return rotated_data
131 |
132 |
133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle):
134 | """ Rotate the point cloud along up direction with certain angle.
135 | Input:
136 | BxNx3 array, original batch of point clouds
137 | Return:
138 | BxNx3 array, rotated batch of point clouds
139 | """
140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
141 | for k in range(batch_data.shape[0]):
142 | #rotation_angle = np.random.uniform() * 2 * np.pi
143 | cosval = np.cos(rotation_angle)
144 | sinval = np.sin(rotation_angle)
145 | rotation_matrix = np.array([[cosval, 0, sinval],
146 | [0, 1, 0],
147 | [-sinval, 0, cosval]])
148 | shape_pc = batch_data[k,:,0:3]
149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
150 | return rotated_data
151 |
152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
153 | """ Rotate the point cloud along up direction with certain angle.
154 | Input:
155 | BxNx6 array, original batch of point clouds with normal
156 | scalar, angle of rotation
157 | Return:
158 |         BxNx6 array, rotated batch of point clouds with normal
159 | """
160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
161 | for k in range(batch_data.shape[0]):
162 | #rotation_angle = np.random.uniform() * 2 * np.pi
163 | cosval = np.cos(rotation_angle)
164 | sinval = np.sin(rotation_angle)
165 | rotation_matrix = np.array([[cosval, 0, sinval],
166 | [0, 1, 0],
167 | [-sinval, 0, cosval]])
168 | shape_pc = batch_data[k,:,0:3]
169 | shape_normal = batch_data[k,:,3:6]
170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
172 | return rotated_data
173 |
174 |
175 |
176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
177 | """ Randomly perturb the point clouds by small rotations
178 | Input:
179 | BxNx3 array, original batch of point clouds
180 | Return:
181 | BxNx3 array, rotated batch of point clouds
182 | """
183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
184 | for k in range(batch_data.shape[0]):
185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
186 | Rx = np.array([[1,0,0],
187 | [0,np.cos(angles[0]),-np.sin(angles[0])],
188 | [0,np.sin(angles[0]),np.cos(angles[0])]])
189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
190 | [0,1,0],
191 | [-np.sin(angles[1]),0,np.cos(angles[1])]])
192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
193 | [np.sin(angles[2]),np.cos(angles[2]),0],
194 | [0,0,1]])
195 | R = np.dot(Rz, np.dot(Ry,Rx))
196 | shape_pc = batch_data[k, ...]
197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
198 | return rotated_data
199 |
200 |
201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
202 | """ Randomly jitter points. jittering is per point.
203 | Input:
204 | BxNx3 array, original batch of point clouds
205 | Return:
206 | BxNx3 array, jittered batch of point clouds
207 | """
208 | B, N, C = batch_data.shape
209 | assert(clip > 0)
210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
211 | jittered_data += batch_data
212 | return jittered_data
213 |
214 | def shift_point_cloud(batch_data, shift_range=0.1):
215 | """ Randomly shift point cloud. Shift is per point cloud.
216 | Input:
217 | BxNx3 array, original batch of point clouds
218 | Return:
219 | BxNx3 array, shifted batch of point clouds
220 | """
221 | B, N, C = batch_data.shape
222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3))
223 | for batch_index in range(B):
224 | batch_data[batch_index,:,:] += shifts[batch_index,:]
225 | return batch_data
226 |
227 |
228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
229 | """ Randomly scale the point cloud. Scale is per point cloud.
230 | Input:
231 | BxNx3 array, original batch of point clouds
232 | Return:
233 | BxNx3 array, scaled batch of point clouds
234 | """
235 | B, N, C = batch_data.shape
236 | scales = np.random.uniform(scale_low, scale_high, B)
237 | for batch_index in range(B):
238 | batch_data[batch_index,:,:] *= scales[batch_index]
239 | return batch_data
240 |
241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
242 | ''' batch_pc: BxNx3 '''
243 | for b in range(batch_pc.shape[0]):
244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
246 | if len(drop_idx)>0:
247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
248 | return batch_pc
249 |
250 |
251 |
252 |
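
These augmentations are meant to be chained on a NumPy batch before it is converted to a tensor; a typical (hypothetical) training-time composition:

    import numpy as np
    import provider

    batch = np.random.rand(16, 1024, 3).astype(np.float32)   # [B, N, 3] made-up batch

    batch = provider.random_point_dropout(batch)      # collapse a random subset onto the first point
    batch = provider.random_scale_point_cloud(batch)  # per-cloud scale in [0.8, 1.25]
    batch = provider.shift_point_cloud(batch)         # per-cloud shift in [-0.1, 0.1]
    batch = provider.rotate_point_cloud_z(batch)      # random rotation about the z axis
    batch = provider.jitter_point_cloud(batch)        # clipped Gaussian per-point jitter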
--------------------------------------------------------------------------------
/test_cls.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
6 | import argparse
7 | import numpy as np
8 | import os
9 | import torch
10 | import logging
11 | from tqdm import tqdm
12 | import sys
13 | import importlib
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | ROOT_DIR = BASE_DIR
17 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
18 |
19 |
20 | def parse_args():
21 | '''PARAMETERS'''
22 | parser = argparse.ArgumentParser('PointNet')
23 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
24 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
25 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]')
26 | parser.add_argument('--log_dir', type=str, default='pointnet2_ssg_normal', help='Experiment root')
27 |     parser.add_argument('--normal', action='store_true', default=True, help='Whether to use normal information [default: True]')
28 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting [default: 3]')
29 | return parser.parse_args()
30 |
31 | def test(model, loader, num_class=40, vote_num=1):
32 | mean_correct = []
33 | class_acc = np.zeros((num_class,3))
34 | for j, data in tqdm(enumerate(loader), total=len(loader)):
35 | points, target = data
36 | target = target[:, 0]
37 | points = points.transpose(2, 1)
38 | points, target = points.cuda(), target.cuda()
39 | classifier = model.eval()
40 | vote_pool = torch.zeros(target.size()[0],num_class).cuda()
41 | for _ in range(vote_num):
42 | pred, _ = classifier(points, False)
43 | vote_pool += pred
44 | pred = vote_pool/vote_num
45 | pred_choice = pred.data.max(1)[1]
46 | for cat in np.unique(target.cpu()):
47 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum()
48 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0])
49 | class_acc[cat,1]+=1
50 | correct = pred_choice.eq(target.long().data).cpu().sum()
51 | mean_correct.append(correct.item()/float(points.size()[0]))
52 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1]
53 | class_acc = np.mean(class_acc[:,2])
54 | instance_acc = np.mean(mean_correct)
55 | return instance_acc, class_acc
56 |
57 |
58 | def main(args):
59 | def log_string(str):
60 | logger.info(str)
61 | print(str)
62 |
63 | '''HYPER PARAMETER'''
64 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
65 |
66 | '''CREATE DIR'''
67 | experiment_dir = 'log/classification/' + args.log_dir
68 |
69 | '''LOG'''
70 | args = parse_args()
71 | logger = logging.getLogger("Model")
72 | logger.setLevel(logging.INFO)
73 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
74 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
75 | file_handler.setLevel(logging.INFO)
76 | file_handler.setFormatter(formatter)
77 | logger.addHandler(file_handler)
78 | log_string('PARAMETER ...')
79 | log_string(args)
80 |
81 | '''DATA LOADING'''
82 | log_string('Load dataset ...')
83 | DATA_PATH = 'data/modelnet40_normal_resampled/'
84 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal)
85 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
86 |
87 | '''MODEL LOADING'''
88 | num_class = 40
89 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0]
90 | MODEL = importlib.import_module(model_name)
91 |
92 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda()
93 |
94 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
95 | classifier.load_state_dict(checkpoint['model_state_dict'])
96 |
97 | with torch.no_grad():
98 | instance_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes)
99 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
100 |
101 |
102 |
103 | if __name__ == '__main__':
104 | args = parse_args()
105 | main(args)
106 |
--------------------------------------------------------------------------------
/test_partseg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | import argparse
6 | import os
7 | from data_utils.ShapeNetDataLoader import PartNormalDataset
8 | import torch
9 | import logging
10 | import sys
11 | import importlib
12 | from tqdm import tqdm
13 | import numpy as np
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | ROOT_DIR = BASE_DIR
17 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
18 |
19 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
20 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
21 | for cat in seg_classes.keys():
22 | for label in seg_classes[cat]:
23 | seg_label_to_cat[label] = cat
24 |
25 | def to_categorical(y, num_classes):
26 | """ 1-hot encodes a tensor """
27 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
28 | if (y.is_cuda):
29 | return new_y.cuda()
30 | return new_y
31 |
32 |
33 | def parse_args():
34 | '''PARAMETERS'''
35 | parser = argparse.ArgumentParser('PointNet')
36 | parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing [default: 24]')
37 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device [default: 0]')
38 | parser.add_argument('--num_point', type=int, default=2048, help='Point Number [default: 2048]')
39 | parser.add_argument('--log_dir', type=str, default='pointnet2_part_seg_ssg', help='Experiment root')
40 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
41 | parser.add_argument('--num_votes', type=int, default=3, help='Aggregate segmentation scores with voting [default: 3]')
42 | return parser.parse_args()
43 |
44 | def main(args):
45 | def log_string(str):
46 | logger.info(str)
47 | print(str)
48 |
49 | '''HYPER PARAMETER'''
50 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
51 | experiment_dir = 'log/part_seg/' + args.log_dir
52 |
53 | '''LOG'''
54 | args = parse_args()
55 | logger = logging.getLogger("Model")
56 | logger.setLevel(logging.INFO)
57 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
58 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
59 | file_handler.setLevel(logging.INFO)
60 | file_handler.setFormatter(formatter)
61 | logger.addHandler(file_handler)
62 | log_string('PARAMETER ...')
63 | log_string(args)
64 |
65 | root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'
66 |
67 | TEST_DATASET = PartNormalDataset(root = root, npoints=args.num_point, split='test', normal_channel=args.normal)
68 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size,shuffle=False, num_workers=4)
69 | log_string("The number of test data is: %d" % len(TEST_DATASET))
70 | num_classes = 16
71 | num_part = 50
72 |
73 | '''MODEL LOADING'''
74 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0]
75 | MODEL = importlib.import_module(model_name)
76 | classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda()
77 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
78 | classifier.load_state_dict(checkpoint['model_state_dict'])
79 |
80 |
81 | with torch.no_grad():
82 | test_metrics = {}
83 | total_correct = 0
84 | total_seen = 0
85 | total_seen_class = [0 for _ in range(num_part)]
86 | total_correct_class = [0 for _ in range(num_part)]
87 | shape_ious = {cat: [] for cat in seg_classes.keys()}
88 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
89 | for cat in seg_classes.keys():
90 | for label in seg_classes[cat]:
91 | seg_label_to_cat[label] = cat
92 |
93 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
94 | batchsize, num_point, _ = points.size()
95 | cur_batch_size, NUM_POINT, _ = points.size()
96 | points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
97 | points = points.transpose(2, 1)
98 | classifier = classifier.eval()
99 | vote_pool = torch.zeros(target.size()[0], target.size()[1], num_part).cuda()
100 | for _ in range(args.num_votes):
101 | seg_pred, _ = classifier(points, to_categorical(label, num_classes), False)
102 | vote_pool += seg_pred
103 | seg_pred = vote_pool / args.num_votes
104 | cur_pred_val = seg_pred.cpu().data.numpy()
105 | cur_pred_val_logits = cur_pred_val
106 | cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
107 | target = target.cpu().data.numpy()
108 | for i in range(cur_batch_size):
109 | cat = seg_label_to_cat[target[i, 0]]
110 | logits = cur_pred_val_logits[i, :, :]
111 | cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
112 | correct = np.sum(cur_pred_val == target)
113 | total_correct += correct
114 | total_seen += (cur_batch_size * NUM_POINT)
115 |
116 | for l in range(num_part):
117 | total_seen_class[l] += np.sum(target == l)
118 | total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
119 |
120 | for i in range(cur_batch_size):
121 | segp = cur_pred_val[i, :]
122 | segl = target[i, :]
123 | cat = seg_label_to_cat[segl[0]]
124 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
125 | for l in seg_classes[cat]:
126 | if (np.sum(segl == l) == 0) and (
127 | np.sum(segp == l) == 0): # part is not present, no prediction as well
128 | part_ious[l - seg_classes[cat][0]] = 1.0
129 | else:
130 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
131 | np.sum((segl == l) | (segp == l)))
132 | shape_ious[cat].append(np.mean(part_ious))
133 |
134 | all_shape_ious = []
135 | for cat in shape_ious.keys():
136 | for iou in shape_ious[cat]:
137 | all_shape_ious.append(iou)
138 | shape_ious[cat] = np.mean(shape_ious[cat])
139 | mean_shape_ious = np.mean(list(shape_ious.values()))
140 | test_metrics['accuracy'] = total_correct / float(total_seen)
141 | test_metrics['class_avg_accuracy'] = np.mean(
142 |             np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
143 | for cat in sorted(shape_ious.keys()):
144 | log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
145 | test_metrics['class_avg_iou'] = mean_shape_ious
146 |         test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
147 |
148 |
149 | log_string('Accuracy is: %.5f'%test_metrics['accuracy'])
150 | log_string('Class avg accuracy is: %.5f'%test_metrics['class_avg_accuracy'])
151 | log_string('Class avg mIOU is: %.5f'%test_metrics['class_avg_iou'])
152 |     log_string('Instance avg mIOU is: %.5f'%test_metrics['instance_avg_iou'])
153 |
154 | if __name__ == '__main__':
155 | args = parse_args()
156 | main(args)
157 |
158 |
--------------------------------------------------------------------------------
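
Note on the evaluation loop in test_partseg.py above: per-point predictions are restricted to the part labels of each shape's object category before taking the argmax, which works because every category's part ids are contiguous in the global label space. Below is a minimal NumPy-only sketch of that step with a hypothetical two-category label map; it is illustrative only and not part of the repository.

import numpy as np

# Hypothetical toy label map (the real one lives in test_partseg.py / train_partseg.py)
seg_classes = {'Airplane': [0, 1, 2, 3], 'Bag': [4, 5]}

def constrained_argmax(logits, cat):
    """Per-point part prediction restricted to the parts of one object category.

    logits: (num_points, num_parts) scores over all global part labels.
    cat: the shape's object category, e.g. 'Bag'.
    """
    part_ids = seg_classes[cat]
    # argmax over the category's own columns, then shift back to global part ids
    return np.argmax(logits[:, part_ids], axis=1) + part_ids[0]

logits = np.random.randn(5, 6)              # 5 points, 6 global part labels
print(constrained_argmax(logits, 'Bag'))    # every prediction is 4 or 5
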
/test_semseg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | import argparse
6 | import os
7 | from data_utils.S3DISDataLoader import ScannetDatasetWholeScene
8 | from data_utils.indoor3d_util import g_label2color
9 | import torch
10 | import logging
11 | from pathlib import Path
12 | import sys
13 | import importlib
14 | from tqdm import tqdm
15 | import provider
16 | import numpy as np
17 |
18 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19 | ROOT_DIR = BASE_DIR
20 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
21 |
22 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter']
23 | class2label = {cls: i for i,cls in enumerate(classes)}
24 | seg_classes = class2label
25 | seg_label_to_cat = {}
26 | for i,cat in enumerate(seg_classes.keys()):
27 | seg_label_to_cat[i] = cat
28 |
29 | def parse_args():
30 | '''PARAMETERS'''
31 | parser = argparse.ArgumentParser('Model')
32 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]')
33 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
34 | parser.add_argument('--num_point', type=int, default=4096, help='Point Number [default: 4096]')
35 | parser.add_argument('--log_dir', type=str, default='pointnet2_sem_seg', help='Experiment root')
36 | parser.add_argument('--visual', action='store_true', default=False, help='Whether visualize result [default: False]')
37 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
38 | parser.add_argument('--num_votes', type=int, default=5, help='Aggregate segmentation scores with voting [default: 5]')
39 | return parser.parse_args()
40 |
41 | def add_vote(vote_label_pool, point_idx, pred_label, weight):
42 | B = pred_label.shape[0]
43 | N = pred_label.shape[1]
44 | for b in range(B):
45 | for n in range(N):
46 | if weight[b,n]:
47 | vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1
48 | return vote_label_pool
49 |
50 | def main(args):
51 | def log_string(str):
52 | logger.info(str)
53 | print(str)
54 |
55 | '''HYPER PARAMETER'''
56 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
57 | experiment_dir = 'log/sem_seg/' + args.log_dir
58 | visual_dir = experiment_dir + '/visual/'
59 | visual_dir = Path(visual_dir)
60 | visual_dir.mkdir(exist_ok=True)
61 |
62 | '''LOG'''
63 | args = parse_args()
64 | logger = logging.getLogger("Model")
65 | logger.setLevel(logging.INFO)
66 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
67 | file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
68 | file_handler.setLevel(logging.INFO)
69 | file_handler.setFormatter(formatter)
70 | logger.addHandler(file_handler)
71 | log_string('PARAMETER ...')
72 | log_string(args)
73 |
74 | NUM_CLASSES = 13
75 | BATCH_SIZE = args.batch_size
76 | NUM_POINT = args.num_point
77 |
78 | root = 'data/stanford_indoor3d/'
79 |
80 | TEST_DATASET_WHOLE_SCENE = ScannetDatasetWholeScene(root, split='test', test_area=args.test_area, block_points=NUM_POINT)
81 | log_string("The number of test data is: %d" % len(TEST_DATASET_WHOLE_SCENE))
82 |
83 | '''MODEL LOADING'''
84 | model_name = os.listdir(experiment_dir+'/logs')[0].split('.')[0]
85 | MODEL = importlib.import_module(model_name)
86 | classifier = MODEL.get_model(NUM_CLASSES).cuda()
87 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
88 | classifier.load_state_dict(checkpoint['model_state_dict'])
89 |
90 | with torch.no_grad():
91 | scene_id = TEST_DATASET_WHOLE_SCENE.file_list
92 | scene_id = [x[:-4] for x in scene_id]
93 | num_batches = len(TEST_DATASET_WHOLE_SCENE)
94 |
95 | total_seen_class = [0 for _ in range(NUM_CLASSES)]
96 | total_correct_class = [0 for _ in range(NUM_CLASSES)]
97 | total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]
98 |
99 | log_string('---- EVALUATION WHOLE SCENE----')
100 |
101 | for batch_idx in range(num_batches):
102 | print("visualize [%d/%d] %s ..." % (batch_idx+1, num_batches, scene_id[batch_idx]))
103 | total_seen_class_tmp = [0 for _ in range(NUM_CLASSES)]
104 | total_correct_class_tmp = [0 for _ in range(NUM_CLASSES)]
105 | total_iou_deno_class_tmp = [0 for _ in range(NUM_CLASSES)]
106 | if args.visual:
107 | fout = open(os.path.join(visual_dir, scene_id[batch_idx] + '_pred.obj'), 'w')
108 | fout_gt = open(os.path.join(visual_dir, scene_id[batch_idx] + '_gt.obj'), 'w')
109 |
110 | whole_scene_data = TEST_DATASET_WHOLE_SCENE.scene_points_list[batch_idx]
111 | whole_scene_label = TEST_DATASET_WHOLE_SCENE.semantic_labels_list[batch_idx]
112 | vote_label_pool = np.zeros((whole_scene_label.shape[0], NUM_CLASSES))
113 | for _ in tqdm(range(args.num_votes), total=args.num_votes):
114 | scene_data, scene_label, scene_smpw, scene_point_index = TEST_DATASET_WHOLE_SCENE[batch_idx]
115 | num_blocks = scene_data.shape[0]
116 | s_batch_num = (num_blocks + BATCH_SIZE - 1) // BATCH_SIZE
117 | batch_data = np.zeros((BATCH_SIZE, NUM_POINT, 9))
118 |
119 | batch_label = np.zeros((BATCH_SIZE, NUM_POINT))
120 | batch_point_index = np.zeros((BATCH_SIZE, NUM_POINT))
121 | batch_smpw = np.zeros((BATCH_SIZE, NUM_POINT))
122 | for sbatch in range(s_batch_num):
123 | start_idx = sbatch * BATCH_SIZE
124 | end_idx = min((sbatch + 1) * BATCH_SIZE, num_blocks)
125 | real_batch_size = end_idx - start_idx
126 | batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...]
127 | batch_label[0:real_batch_size, ...] = scene_label[start_idx:end_idx, ...]
128 | batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...]
129 | batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...]
130 | batch_data[:, :, 3:6] /= 1.0
131 |
132 | torch_data = torch.Tensor(batch_data)
133 | torch_data= torch_data.float().cuda()
134 | torch_data = torch_data.transpose(2, 1)
135 | seg_pred, _ = classifier(torch_data, False)
136 | batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy()
137 |
138 | vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...],
139 | batch_pred_label[0:real_batch_size, ...],
140 | batch_smpw[0:real_batch_size, ...])
141 |
142 | pred_label = np.argmax(vote_label_pool, 1)
143 |
144 | for l in range(NUM_CLASSES):
145 | total_seen_class_tmp[l] += np.sum((whole_scene_label == l))
146 | total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l))
147 | total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l)))
148 | total_seen_class[l] += total_seen_class_tmp[l]
149 | total_correct_class[l] += total_correct_class_tmp[l]
150 | total_iou_deno_class[l] += total_iou_deno_class_tmp[l]
151 |
152 |             iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=np.float64) + 1e-6)
153 | print(iou_map)
154 | arr = np.array(total_seen_class_tmp)
155 | tmp_iou = np.mean(iou_map[arr != 0])
156 | log_string('Mean IoU of %s: %.4f' % (scene_id[batch_idx], tmp_iou))
157 | print('----------------------------')
158 |
159 | filename = os.path.join(visual_dir, scene_id[batch_idx] + '.txt')
160 | with open(filename, 'w') as pl_save:
161 | for i in pred_label:
162 | pl_save.write(str(int(i)) + '\n')
163 | pl_save.close()
164 | for i in range(whole_scene_label.shape[0]):
165 | color = g_label2color[pred_label[i]]
166 | color_gt = g_label2color[whole_scene_label[i]]
167 | if args.visual:
168 | fout.write('v %f %f %f %d %d %d\n' % (
169 | whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color[0], color[1],
170 | color[2]))
171 | fout_gt.write(
172 | 'v %f %f %f %d %d %d\n' % (
173 | whole_scene_data[i, 0], whole_scene_data[i, 1], whole_scene_data[i, 2], color_gt[0],
174 | color_gt[1], color_gt[2]))
175 | if args.visual:
176 | fout.close()
177 | fout_gt.close()
178 |
179 |         IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6)
180 | iou_per_class_str = '------- IoU --------\n'
181 | for l in range(NUM_CLASSES):
182 | iou_per_class_str += 'class %s, IoU: %.3f \n' % (
183 | seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])),
184 | total_correct_class[l] / float(total_iou_deno_class[l]))
185 | log_string(iou_per_class_str)
186 | log_string('eval point avg class IoU: %f' % np.mean(IoU))
187 | log_string('eval whole scene point avg class acc: %f' % (
188 |             np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
189 | log_string('eval whole scene point accuracy: %f' % (
190 | np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6)))
191 |
192 | print("Done!")
193 |
194 | if __name__ == '__main__':
195 | args = parse_args()
196 | main(args)
197 |
--------------------------------------------------------------------------------
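
Note on test_semseg.py above: the add_vote helper accumulates block-wise predictions into a per-point class histogram, weighted by the sample mask, and each scene point's final label is the argmax over that histogram. A self-contained sketch with toy shapes (all sizes and arrays here are illustrative only):

import numpy as np

num_scene_points, num_classes = 8, 3
vote_pool = np.zeros((num_scene_points, num_classes))

def add_vote(vote_pool, point_idx, pred_label, weight):
    # point_idx / pred_label / weight: (B, N) arrays for one batch of blocks
    B, N = pred_label.shape
    for b in range(B):
        for n in range(N):
            if weight[b, n]:                      # skip padded / zero-weight points
                vote_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1
    return vote_pool

# one fake batch: 2 overlapping blocks of 4 points each, indices into the scene
point_idx  = np.array([[0, 1, 2, 3], [2, 3, 4, 5]])
pred_label = np.array([[0, 0, 1, 1], [2, 1, 1, 0]])
weight     = np.ones_like(pred_label)
vote_pool  = add_vote(vote_pool, point_idx, pred_label, weight)
final_label = np.argmax(vote_pool, axis=1)        # per-point majority vote
print(final_label)
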
/train_cls.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | from data_utils.ModelNetDataLoader import ModelNetDataLoader
6 | import argparse
7 | import numpy as np
8 | import os
9 | import torch
10 | import datetime
11 | import logging
12 | from pathlib import Path
13 | from tqdm import tqdm
14 | import sys
15 | import provider
16 | import importlib
17 | import shutil
18 |
19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
20 | ROOT_DIR = BASE_DIR
21 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
22 |
23 |
24 | def parse_args():
25 | '''PARAMETERS'''
26 | parser = argparse.ArgumentParser('PointNet')
27 | parser.add_argument('--batch_size', type=int, default=32, help='batch size in training [default: 32]')
28 | parser.add_argument('--model', default='pointnet2_cls_ssg_psn', help='model name [default: pointnet2_cls_ssg_psn]')
29 | parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training [default: 300]')
30 | parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training [default: 0.001]')
31 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device [default: 0]')
32 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]')
33 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training [default: Adam]')
34 | parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
35 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate [default: 1e-4]')
36 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
37 | return parser.parse_args()
38 |
39 | def test(model, loader, num_class=40):
40 | mean_correct = []
41 | class_acc = np.zeros((num_class,3))
42 | for j, data in tqdm(enumerate(loader), total=len(loader)):
43 | points, target = data
44 | target = target[:, 0]
45 | points = points.transpose(2, 1)
46 | points, target = points.cuda(), target.cuda()
47 | classifier = model.eval()
48 | pred, _ = classifier(points, False)
49 | pred_choice = pred.data.max(1)[1]
50 | for cat in np.unique(target.cpu()):
51 | classacc = pred_choice[target==cat].eq(target[target==cat].long().data).cpu().sum()
52 | class_acc[cat,0]+= classacc.item()/float(points[target==cat].size()[0])
53 | class_acc[cat,1]+=1
54 | correct = pred_choice.eq(target.long().data).cpu().sum()
55 | mean_correct.append(correct.item()/float(points.size()[0]))
56 | class_acc[:,2] = class_acc[:,0]/ class_acc[:,1]
57 | class_acc = np.mean(class_acc[:,2])
58 | instance_acc = np.mean(mean_correct)
59 | return instance_acc, class_acc
60 |
61 |
62 | def main(args):
63 | def log_string(str):
64 | logger.info(str)
65 | print(str)
66 |
67 | '''HYPER PARAMETER'''
68 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
69 |
70 | '''CREATE DIR'''
71 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
72 | experiment_dir = Path('./log/')
73 | experiment_dir.mkdir(exist_ok=True)
74 | experiment_dir = experiment_dir.joinpath('classification')
75 | experiment_dir.mkdir(exist_ok=True)
76 | if args.log_dir is None:
77 | experiment_dir = experiment_dir.joinpath(timestr)
78 | else:
79 | experiment_dir = experiment_dir.joinpath(args.log_dir)
80 | experiment_dir.mkdir(exist_ok=True)
81 | checkpoints_dir = experiment_dir.joinpath('checkpoints/')
82 | checkpoints_dir.mkdir(exist_ok=True)
83 | log_dir = experiment_dir.joinpath('logs/')
84 | log_dir.mkdir(exist_ok=True)
85 |
86 | '''LOG'''
87 | args = parse_args()
88 | logger = logging.getLogger("Model")
89 | logger.setLevel(logging.INFO)
90 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
91 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
92 | file_handler.setLevel(logging.INFO)
93 | file_handler.setFormatter(formatter)
94 | logger.addHandler(file_handler)
95 | log_string('PARAMETER ...')
96 | log_string(args)
97 |
98 | '''DATA LOADING'''
99 | log_string('Load dataset ...')
100 | DATA_PATH = 'data/modelnet40_normal_resampled/'
101 |
102 | TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train',
103 | normal_channel=args.normal)
104 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test',
105 | normal_channel=args.normal)
106 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=4)
107 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
108 |
109 | '''MODEL LOADING'''
110 | num_class = 40
111 | MODEL = importlib.import_module(args.model)
112 | shutil.copy('./models/%s.py' % args.model, str(experiment_dir))
113 | shutil.copy('./models/pointnet_util_psn.py', str(experiment_dir))
114 |
115 | classifier = MODEL.get_model(num_class,normal_channel=args.normal).cuda()
116 | criterion = MODEL.get_loss().cuda()
117 |
118 | try:
119 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
120 | start_epoch = checkpoint['epoch']
121 | classifier.load_state_dict(checkpoint['model_state_dict'])
122 |         log_string('Using pretrained model')
123 | except:
124 | log_string('No existing model, starting training from scratch...')
125 | start_epoch = 0
126 |
127 |
128 | if args.optimizer == 'Adam':
129 | optimizer = torch.optim.Adam(
130 | classifier.parameters(),
131 | lr=args.learning_rate,
132 | betas=(0.9, 0.999),
133 | eps=1e-08,
134 | weight_decay=args.decay_rate
135 | )
136 | else:
137 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
138 |
139 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
140 | global_epoch = 0
141 | global_step = 0
142 | best_instance_acc = 0.0
143 | best_class_acc = 0.0
144 | mean_correct = []
145 | best_epoch = 0
146 |
147 |     '''TRAINING'''
148 | logger.info('Start training...')
149 | for epoch in range(start_epoch,args.epoch):
150 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
151 |
152 | scheduler.step()
153 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
154 | points, target = data
155 | points = points.data.numpy()
156 | points = provider.random_point_dropout(points)
157 | points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3])
158 | points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3])
159 | points = torch.Tensor(points)
160 | target = target[:, 0]
161 |
162 | points = points.transpose(2, 1)
163 | points, target = points.cuda(), target.cuda()
164 | optimizer.zero_grad()
165 |
166 | classifier = classifier.train()
167 | pred, trans_feat = classifier(points, False)
168 | loss = criterion(pred, target.long(), trans_feat)
169 | pred_choice = pred.data.max(1)[1]
170 | correct = pred_choice.eq(target.long().data).cpu().sum()
171 | mean_correct.append(correct.item() / float(points.size()[0]))
172 | loss.backward()
173 | optimizer.step()
174 | global_step += 1
175 |
176 | train_instance_acc = np.mean(mean_correct)
177 | log_string('Train Instance Accuracy: %f' % train_instance_acc)
178 |
179 |
180 | with torch.no_grad():
181 | instance_acc, class_acc = test(classifier.eval(), testDataLoader)
182 |
183 | if (instance_acc >= best_instance_acc):
184 | best_instance_acc = instance_acc
185 | best_epoch = epoch + 1
186 |
187 | if (class_acc >= best_class_acc):
188 | best_class_acc = class_acc
189 | log_string('Test Instance Accuracy: %f, Class Accuracy: %f'% (instance_acc, class_acc))
190 | log_string('Best Instance Accuracy: %f, Class Accuracy: %f'% (best_instance_acc, best_class_acc))
191 |
192 | if (instance_acc >= best_instance_acc):
193 | logger.info('Save model...')
194 | savepath = str(checkpoints_dir) + '/best_model.pth'
195 | log_string('Saving at %s'% savepath)
196 | state = {
197 | 'epoch': best_epoch,
198 | 'instance_acc': instance_acc,
199 | 'class_acc': class_acc,
200 | 'model_state_dict': classifier.state_dict(),
201 | 'optimizer_state_dict': optimizer.state_dict(),
202 | }
203 | torch.save(state, savepath)
204 | global_epoch += 1
205 |
206 | logger.info('End of training...')
207 |
208 | if __name__ == '__main__':
209 | args = parse_args()
210 | main(args)
211 |
--------------------------------------------------------------------------------
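
Note on the test() helper in train_cls.py above: class_acc is a (num_class, 3) array in which column 0 sums each class's per-batch accuracy, column 1 counts the batches in which the class appears, and column 2 stores their ratio; instance accuracy is the mean over per-batch accuracies. A NumPy-only sketch with toy predictions (values are purely illustrative):

import numpy as np

num_class = 3
class_acc = np.zeros((num_class, 3))   # [summed per-batch acc, batch count, mean]
mean_correct = []

# two toy "batches" of (predictions, targets)
batches = [(np.array([0, 0, 1, 2]), np.array([0, 1, 1, 2])),
           (np.array([2, 2, 0, 0]), np.array([2, 1, 0, 0]))]

for pred, target in batches:
    for cat in np.unique(target):
        mask = target == cat
        class_acc[cat, 0] += np.mean(pred[mask] == target[mask])  # accuracy of this class in this batch
        class_acc[cat, 1] += 1                                    # class appeared in one more batch
    mean_correct.append(np.mean(pred == target))                  # overall batch accuracy

class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]               # mean per-class accuracy over batches
print('class avg acc:', np.mean(class_acc[:, 2]))
print('instance acc :', np.mean(mean_correct))
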
/train_partseg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | import argparse
6 | import os
7 | from data_utils.ShapeNetDataLoader import PartNormalDataset
8 | import torch
9 | import datetime
10 | import logging
11 | from pathlib import Path
12 | import sys
13 | import importlib
14 | import shutil
15 | from tqdm import tqdm
16 | import provider
17 | import numpy as np
18 |
19 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
20 | ROOT_DIR = BASE_DIR
21 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
22 | torch.autograd.set_detect_anomaly(True)  # enable autograd anomaly detection (useful for debugging, slows training)
23 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
24 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
25 | for cat in seg_classes.keys():
26 | for label in seg_classes[cat]:
27 | seg_label_to_cat[label] = cat
28 |
29 | def to_categorical(y, num_classes):
30 | """ 1-hot encodes a tensor """
31 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
32 | if (y.is_cuda):
33 | return new_y.cuda()
34 | return new_y
35 |
36 |
37 | def parse_args():
38 | parser = argparse.ArgumentParser('Model')
39 | parser.add_argument('--model', type=str, default='pointnet2_part_seg_ssg_psn', help='model name [default: pointnet2_part_seg_ssg_psn]')
40 |     parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
41 | parser.add_argument('--epoch', default=251, type=int, help='Epoch to run [default: 251]')
42 | parser.add_argument('--learning_rate', default=0.001, type=float, help='Initial learning rate [default: 0.001]')
43 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
44 | parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]')
45 | parser.add_argument('--log_dir', type=str, default=None, help='Log path [default: None]')
46 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]')
47 | parser.add_argument('--npoint', type=int, default=2048, help='Point Number [default: 2048]')
48 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
49 | parser.add_argument('--step_size', type=int, default=20, help='Decay step for lr decay [default: every 20 epochs]')
50 | parser.add_argument('--lr_decay', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]')
51 |
52 | return parser.parse_args()
53 |
54 | def main(args):
55 | def log_string(str):
56 | logger.info(str)
57 | print(str)
58 |
59 | '''HYPER PARAMETER'''
60 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
61 |
62 | '''CREATE DIR'''
63 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
64 | experiment_dir = Path('./log/')
65 | experiment_dir.mkdir(exist_ok=True)
66 | experiment_dir = experiment_dir.joinpath('part_seg')
67 | experiment_dir.mkdir(exist_ok=True)
68 | if args.log_dir is None:
69 | experiment_dir = experiment_dir.joinpath(timestr)
70 | else:
71 | experiment_dir = experiment_dir.joinpath(args.log_dir)
72 | experiment_dir.mkdir(exist_ok=True)
73 | checkpoints_dir = experiment_dir.joinpath('checkpoints/')
74 | checkpoints_dir.mkdir(exist_ok=True)
75 | log_dir = experiment_dir.joinpath('logs/')
76 | log_dir.mkdir(exist_ok=True)
77 |
78 | '''LOG'''
79 | args = parse_args()
80 | logger = logging.getLogger("Model")
81 | logger.setLevel(logging.INFO)
82 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
83 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
84 | file_handler.setLevel(logging.INFO)
85 | file_handler.setFormatter(formatter)
86 | logger.addHandler(file_handler)
87 | log_string('PARAMETER ...')
88 | log_string(args)
89 |
90 | root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'
91 |
92 | TRAIN_DATASET = PartNormalDataset(root = root, npoints=args.npoint, split='trainval', normal_channel=args.normal)
93 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size,shuffle=True, num_workers=4)
94 | TEST_DATASET = PartNormalDataset(root = root, npoints=args.npoint, split='test', normal_channel=args.normal)
95 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size,shuffle=False, num_workers=4)
96 | log_string("The number of training data is: %d" % len(TRAIN_DATASET))
97 | log_string("The number of test data is: %d" % len(TEST_DATASET))
98 | num_classes = 16
99 | num_part = 50
100 | '''MODEL LOADING'''
101 | MODEL = importlib.import_module(args.model)
102 | shutil.copy('models/%s.py' % args.model, str(experiment_dir))
103 | shutil.copy('models/pointnet_util_psn.py', str(experiment_dir))
104 |
105 | classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda()
106 | criterion = MODEL.get_loss().cuda()
107 |
108 |
109 | def weights_init(m):
110 | classname = m.__class__.__name__
111 | if classname.find('Conv2d') != -1:
112 | torch.nn.init.xavier_normal_(m.weight.data)
113 | torch.nn.init.constant_(m.bias.data, 0.0)
114 | elif classname.find('Linear') != -1:
115 | torch.nn.init.xavier_normal_(m.weight.data)
116 | torch.nn.init.constant_(m.bias.data, 0.0)
117 |
118 | try:
119 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
120 | start_epoch = checkpoint['epoch']
121 | classifier.load_state_dict(checkpoint['model_state_dict'])
122 |         log_string('Using pretrained model')
123 | except:
124 | log_string('No existing model, starting training from scratch...')
125 | start_epoch = 0
126 | classifier = classifier.apply(weights_init)
127 |
128 | if args.optimizer == 'Adam':
129 | optimizer = torch.optim.Adam(
130 | classifier.parameters(),
131 | lr=args.learning_rate,
132 | betas=(0.9, 0.999),
133 | eps=1e-08,
134 | weight_decay=args.decay_rate
135 | )
136 | else:
137 | optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
138 |
139 | def bn_momentum_adjust(m, momentum):
140 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
141 | m.momentum = momentum
142 |
143 | LEARNING_RATE_CLIP = 1e-7
144 | MOMENTUM_ORIGINAL = 0.1
145 |     MOMENTUM_DECAY = 0.5
146 |     MOMENTUM_DECAY_STEP = args.step_size
147 |
148 | best_acc = 0
149 | global_epoch = 0
150 | best_class_avg_iou = 0
151 |     best_instance_avg_iou = 0
152 |
153 | for epoch in range(start_epoch,args.epoch):
154 | log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
155 | '''Adjust learning rate and BN momentum'''
156 | lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
157 | log_string('Learning rate:%f' % lr)
158 | for param_group in optimizer.param_groups:
159 | param_group['lr'] = lr
160 | mean_correct = []
161 |         momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
162 | if momentum < 0.01:
163 | momentum = 0.01
164 | print('BN momentum updated to: %f' % momentum)
165 | classifier = classifier.apply(lambda x: bn_momentum_adjust(x,momentum))
166 |
167 | '''learning one epoch'''
168 | for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
169 | points, label, target = data
170 | points = points.data.numpy()
171 | points[:,:, 0:3] = provider.random_scale_point_cloud(points[:,:, 0:3])
172 | points[:,:, 0:3] = provider.shift_point_cloud(points[:,:, 0:3])
173 | points = torch.Tensor(points)
174 | points, label, target = points.float().cuda(),label.long().cuda(), target.long().cuda()
175 | points = points.transpose(2, 1)
176 | optimizer.zero_grad()
177 | classifier = classifier.train()
178 | seg_pred, trans_feat = classifier(points, to_categorical(label, num_classes), False)
179 | seg_pred = seg_pred.contiguous().view(-1, num_part)
180 | target = target.view(-1, 1)[:, 0]
181 | pred_choice = seg_pred.data.max(1)[1]
182 | correct = pred_choice.eq(target.data).cpu().sum()
183 | mean_correct.append(correct.item() / (args.batch_size * args.npoint))
184 | loss = criterion(seg_pred, target, trans_feat)
185 | loss.backward()
186 | optimizer.step()
187 | train_instance_acc = np.mean(mean_correct)
188 | log_string('Train accuracy is: %.5f' % train_instance_acc)
189 |
190 | with torch.no_grad():
191 | test_metrics = {}
192 | total_correct = 0
193 | total_seen = 0
194 | total_seen_class = [0 for _ in range(num_part)]
195 | total_correct_class = [0 for _ in range(num_part)]
196 | shape_ious = {cat: [] for cat in seg_classes.keys()}
197 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
198 | for cat in seg_classes.keys():
199 | for label in seg_classes[cat]:
200 | seg_label_to_cat[label] = cat
201 |
202 | for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
203 |                 # evaluate the current test batch
204 | cur_batch_size, NUM_POINT, _ = points.size()
205 | points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
206 | points = points.transpose(2, 1)
207 | classifier = classifier.eval()
208 | seg_pred, _ = classifier(points, to_categorical(label, num_classes), False)
209 | cur_pred_val = seg_pred.cpu().data.numpy()
210 | cur_pred_val_logits = cur_pred_val
211 | cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
212 | target = target.cpu().data.numpy()
213 | for i in range(cur_batch_size):
214 | cat = seg_label_to_cat[target[i, 0]]
215 | logits = cur_pred_val_logits[i, :, :]
216 | cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
217 | correct = np.sum(cur_pred_val == target)
218 | total_correct += correct
219 | total_seen += (cur_batch_size * NUM_POINT)
220 |
221 | for l in range(num_part):
222 | total_seen_class[l] += np.sum(target == l)
223 | total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))
224 |
225 | for i in range(cur_batch_size):
226 | segp = cur_pred_val[i, :]
227 | segl = target[i, :]
228 | cat = seg_label_to_cat[segl[0]]
229 | part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
230 | for l in seg_classes[cat]:
231 | if (np.sum(segl == l) == 0) and (
232 | np.sum(segp == l) == 0): # part is not present, no prediction as well
233 | part_ious[l - seg_classes[cat][0]] = 1.0
234 | else:
235 | part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
236 | np.sum((segl == l) | (segp == l)))
237 | shape_ious[cat].append(np.mean(part_ious))
238 |
239 | all_shape_ious = []
240 | for cat in shape_ious.keys():
241 | for iou in shape_ious[cat]:
242 | all_shape_ious.append(iou)
243 | shape_ious[cat] = np.mean(shape_ious[cat])
244 | mean_shape_ious = np.mean(list(shape_ious.values()))
245 | test_metrics['accuracy'] = total_correct / float(total_seen)
246 | test_metrics['class_avg_accuracy'] = np.mean(
247 |                 np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
248 | for cat in sorted(shape_ious.keys()):
249 | log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
250 | test_metrics['class_avg_iou'] = mean_shape_ious
251 |             test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)
252 |
253 |
254 |         log_string('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % (
255 |              epoch+1, test_metrics['accuracy'],test_metrics['class_avg_iou'],test_metrics['instance_avg_iou']))
256 |         if (test_metrics['instance_avg_iou'] >= best_instance_avg_iou):
257 | logger.info('Save model...')
258 | savepath = str(checkpoints_dir) + '/best_model.pth'
259 | log_string('Saving at %s'% savepath)
260 | state = {
261 | 'epoch': epoch,
262 | 'train_acc': train_instance_acc,
263 | 'test_acc': test_metrics['accuracy'],
264 | 'class_avg_iou': test_metrics['class_avg_iou'],
265 |                 'instance_avg_iou': test_metrics['instance_avg_iou'],
266 | 'model_state_dict': classifier.state_dict(),
267 | 'optimizer_state_dict': optimizer.state_dict(),
268 | }
269 | torch.save(state, savepath)
270 | log_string('Saving model....')
271 |
272 | if test_metrics['accuracy'] > best_acc:
273 | best_acc = test_metrics['accuracy']
274 | if test_metrics['class_avg_iou'] > best_class_avg_iou:
275 | best_class_avg_iou = test_metrics['class_avg_iou']
276 |         if test_metrics['instance_avg_iou'] > best_instance_avg_iou:
277 |             best_instance_avg_iou = test_metrics['instance_avg_iou']
278 | log_string('Best accuracy is: %.5f'%best_acc)
279 | log_string('Best class avg mIOU is: %.5f'%best_class_avg_iou)
280 |         log_string('Best instance avg mIOU is: %.5f'%best_instance_avg_iou)
281 | global_epoch+=1
282 |
283 | if __name__ == '__main__':
284 | args = parse_args()
285 | main(args)
286 |
287 |
--------------------------------------------------------------------------------
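
Note on the evaluation metric in train_partseg.py (and test_partseg.py) above: each shape's IoU is averaged over the part labels of its own category, with the convention that a part absent from both ground truth and prediction counts as IoU 1.0; instance mIoU then averages over shapes, and class mIoU averages over categories. A minimal NumPy sketch of the per-shape computation (toy labels only):

import numpy as np

def shape_part_iou(segp, segl, part_ids):
    """Mean IoU over the parts of one shape's category.

    segp / segl: per-point predicted / ground-truth part labels (global ids).
    part_ids: the global part ids belonging to this shape's category.
    """
    ious = []
    for l in part_ids:
        gt, pr = (segl == l), (segp == l)
        if not gt.any() and not pr.any():
            ious.append(1.0)              # part absent and never predicted: counts as perfect
        else:
            ious.append(np.sum(gt & pr) / float(np.sum(gt | pr)))
    return np.mean(ious)

# toy example with a 2-part category occupying global labels [4, 5]
segl = np.array([4, 4, 5, 5])
segp = np.array([4, 5, 5, 5])
print(shape_part_iou(segp, segl, [4, 5]))   # (1/2 + 2/3) / 2
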
/train_semseg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Benny
3 | Date: Nov 2019
4 | """
5 | import argparse
6 | import os
7 | from data_utils.S3DISDataLoader import S3DISDataset
8 | import torch
9 | import datetime
10 | import logging
11 | from pathlib import Path
12 | import sys
13 | import importlib
14 | import shutil
15 | from tqdm import tqdm
16 | import provider
17 | import numpy as np
18 | import time
19 |
20 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21 | ROOT_DIR = BASE_DIR
22 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
23 |
24 |
25 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter']
26 | class2label = {cls: i for i,cls in enumerate(classes)}
27 | seg_classes = class2label
28 | seg_label_to_cat = {}
29 | for i,cat in enumerate(seg_classes.keys()):
30 | seg_label_to_cat[i] = cat
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser('Model')
35 |     parser.add_argument('--model', type=str, default='pointnet2_sem_seg_psn', help='model name [default: pointnet2_sem_seg_psn]')
36 | parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
37 | parser.add_argument('--epoch', default=500, type=int, help='Epoch to run [default: 500]')
38 | parser.add_argument('--learning_rate', default=0.001, type=float, help='Initial learning rate [default: 0.001]')
39 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
40 | parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]')
41 | parser.add_argument('--log_dir', type=str, default=None, help='Log path [default: None]')
42 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]')
43 | parser.add_argument('--npoint', type=int, default=4096, help='Point Number [default: 4096]')
44 | parser.add_argument('--step_size', type=int, default=10, help='Decay step for lr decay [default: every 10 epochs]')
45 | parser.add_argument('--lr_decay', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
46 | parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
47 |
48 | return parser.parse_args()
49 |
50 | def main(args):
51 | def log_string(str):
52 | logger.info(str)
53 | print(str)
54 |
55 | '''HYPER PARAMETER'''
56 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
57 |
58 | '''CREATE DIR'''
59 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
60 | experiment_dir = Path('./log/')
61 | experiment_dir.mkdir(exist_ok=True)
62 | experiment_dir = experiment_dir.joinpath('sem_seg')
63 | experiment_dir.mkdir(exist_ok=True)
64 | if args.log_dir is None:
65 | experiment_dir = experiment_dir.joinpath(timestr)
66 | else:
67 | experiment_dir = experiment_dir.joinpath(args.log_dir)
68 | experiment_dir.mkdir(exist_ok=True)
69 | checkpoints_dir = experiment_dir.joinpath('checkpoints/')
70 | checkpoints_dir.mkdir(exist_ok=True)
71 | log_dir = experiment_dir.joinpath('logs/')
72 | log_dir.mkdir(exist_ok=True)
73 |
74 | '''LOG'''
75 | args = parse_args()
76 | logger = logging.getLogger("Model")
77 | logger.setLevel(logging.INFO)
78 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
79 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
80 | file_handler.setLevel(logging.INFO)
81 | file_handler.setFormatter(formatter)
82 | logger.addHandler(file_handler)
83 | log_string('PARAMETER ...')
84 | log_string(args)
85 |
86 | root = 'data/stanford_indoor3d/'
87 | NUM_CLASSES = 13
88 | NUM_POINT = args.npoint
89 | BATCH_SIZE = args.batch_size
90 |
91 | print("start loading training data ...")
92 | TRAIN_DATASET = S3DISDataset(split='train', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
93 | print("start loading test data ...")
94 | TEST_DATASET = S3DISDataset(split='test', data_root=root, num_point=NUM_POINT, test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
95 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, drop_last=True, worker_init_fn = lambda x: np.random.seed(x+int(time.time())))
96 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
97 | weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda()
98 |
99 | log_string("The number of training data is: %d" % len(TRAIN_DATASET))
100 | log_string("The number of test data is: %d" % len(TEST_DATASET))
101 |
102 | '''MODEL LOADING'''
103 | MODEL = importlib.import_module(args.model)
104 | shutil.copy('models/%s.py' % args.model, str(experiment_dir))
105 | shutil.copy('models/pointnet_util_psn.py', str(experiment_dir))
106 |
107 | classifier = MODEL.get_model(NUM_CLASSES).cuda()
108 | criterion = MODEL.get_loss().cuda()
109 |
110 | def weights_init(m):
111 | classname = m.__class__.__name__
112 | if classname.find('Conv2d') != -1:
113 | torch.nn.init.xavier_normal_(m.weight.data)
114 | torch.nn.init.constant_(m.bias.data, 0.0)
115 | elif classname.find('Linear') != -1:
116 | torch.nn.init.xavier_normal_(m.weight.data)
117 | torch.nn.init.constant_(m.bias.data, 0.0)
118 |
119 | try:
120 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
121 | start_epoch = checkpoint['epoch']
122 | classifier.load_state_dict(checkpoint['model_state_dict'])
123 |         log_string('Using pretrained model')
124 | except:
125 | log_string('No existing model, starting training from scratch...')
126 | start_epoch = 0
127 | classifier = classifier.apply(weights_init)
128 |
129 | if args.optimizer == 'Adam':
130 | optimizer = torch.optim.Adam(
131 | classifier.parameters(),
132 | lr=args.learning_rate,
133 | betas=(0.9, 0.999),
134 | eps=1e-08,
135 | weight_decay=args.decay_rate
136 | )
137 | else:
138 | optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
139 |
140 | def bn_momentum_adjust(m, momentum):
141 | if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
142 | m.momentum = momentum
143 |
144 | LEARNING_RATE_CLIP = 1e-5
145 | MOMENTUM_ORIGINAL = 0.1
146 |     MOMENTUM_DECAY = 0.5
147 |     MOMENTUM_DECAY_STEP = args.step_size
148 |
149 | global_epoch = 0
150 | best_iou = 0
151 |
152 | for epoch in range(start_epoch,args.epoch):
153 | '''Train on chopped scenes'''
154 | log_string('**** Epoch %d (%d/%s) ****' % (global_epoch + 1, epoch + 1, args.epoch))
155 | lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
156 | log_string('Learning rate:%f' % lr)
157 | for param_group in optimizer.param_groups:
158 | param_group['lr'] = lr
159 |         momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
160 | if momentum < 0.01:
161 | momentum = 0.01
162 | print('BN momentum updated to: %f' % momentum)
163 | classifier = classifier.apply(lambda x: bn_momentum_adjust(x,momentum))
164 | num_batches = len(trainDataLoader)
165 | total_correct = 0
166 | total_seen = 0
167 | loss_sum = 0
168 | for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
169 | points, target = data
170 | points = points.data.numpy()
171 | points[:,:, :3] = provider.rotate_point_cloud_z(points[:,:, :3])
172 | points = torch.Tensor(points)
173 | points, target = points.float().cuda(),target.long().cuda()
174 | points = points.transpose(2, 1)
175 | optimizer.zero_grad()
176 | classifier = classifier.train()
177 | seg_pred, trans_feat = classifier(points, False)
178 | seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)
179 | batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
180 | target = target.view(-1, 1)[:, 0]
181 | loss = criterion(seg_pred, target, trans_feat, weights)
182 | loss.backward()
183 | optimizer.step()
184 | pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
185 | correct = np.sum(pred_choice == batch_label)
186 | total_correct += correct
187 | total_seen += (BATCH_SIZE * NUM_POINT)
188 | loss_sum += loss
189 | log_string('Training mean loss: %f' % (loss_sum / num_batches))
190 | log_string('Training accuracy: %f' % (total_correct / float(total_seen)))
191 |
192 | if epoch % 5 == 0:
193 | logger.info('Save model...')
194 | savepath = str(checkpoints_dir) + '/model.pth'
195 | log_string('Saving at %s' % savepath)
196 | state = {
197 | 'epoch': epoch,
198 | 'model_state_dict': classifier.state_dict(),
199 | 'optimizer_state_dict': optimizer.state_dict(),
200 | }
201 | torch.save(state, savepath)
202 | log_string('Saving model....')
203 |
204 | '''Evaluate on chopped scenes'''
205 | with torch.no_grad():
206 | num_batches = len(testDataLoader)
207 | total_correct = 0
208 | total_seen = 0
209 | loss_sum = 0
210 | labelweights = np.zeros(NUM_CLASSES)
211 | total_seen_class = [0 for _ in range(NUM_CLASSES)]
212 | total_correct_class = [0 for _ in range(NUM_CLASSES)]
213 | total_iou_deno_class = [0 for _ in range(NUM_CLASSES)]
214 | log_string('---- EPOCH %03d EVALUATION ----' % (global_epoch + 1))
215 | for i, data in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
216 | points, target = data
217 | points = points.data.numpy()
218 | points = torch.Tensor(points)
219 | points, target = points.float().cuda(), target.long().cuda()
220 | points = points.transpose(2, 1)
221 | classifier = classifier.eval()
222 | seg_pred, trans_feat = classifier(points, False)
223 | pred_val = seg_pred.contiguous().cpu().data.numpy()
224 | seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)
225 | batch_label = target.cpu().data.numpy()
226 | target = target.view(-1, 1)[:, 0]
227 | loss = criterion(seg_pred, target, trans_feat, weights)
228 | loss_sum += loss
229 | pred_val = np.argmax(pred_val, 2)
230 | correct = np.sum((pred_val == batch_label))
231 | total_correct += correct
232 | total_seen += (BATCH_SIZE * NUM_POINT)
233 | tmp, _ = np.histogram(batch_label, range(NUM_CLASSES + 1))
234 | labelweights += tmp
235 | for l in range(NUM_CLASSES):
236 | total_seen_class[l] += np.sum((batch_label == l) )
237 | total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l) )
238 | total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)) )
239 | labelweights = labelweights.astype(np.float32) / np.sum(labelweights.astype(np.float32))
240 |             mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6))
241 | log_string('eval mean loss: %f' % (loss_sum / float(num_batches)))
242 | log_string('eval point avg class IoU: %f' % (mIoU))
243 | log_string('eval point accuracy: %f' % (total_correct / float(total_seen)))
244 | log_string('eval point avg class acc: %f' % (
245 |                 np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
246 | iou_per_class_str = '------- IoU --------\n'
247 | for l in range(NUM_CLASSES):
248 | iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % (
249 |                     seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l],
250 | total_correct_class[l] / float(total_iou_deno_class[l]))
251 |
252 | log_string(iou_per_class_str)
253 | log_string('Eval mean loss: %f' % (loss_sum / num_batches))
254 | log_string('Eval accuracy: %f' % (total_correct / float(total_seen)))
255 | if mIoU >= best_iou:
256 | best_iou = mIoU
257 | logger.info('Save model...')
258 | savepath = str(checkpoints_dir) + '/best_model.pth'
259 | log_string('Saving at %s' % savepath)
260 | state = {
261 | 'epoch': epoch,
262 | 'class_avg_iou': mIoU,
263 | 'model_state_dict': classifier.state_dict(),
264 | 'optimizer_state_dict': optimizer.state_dict(),
265 | }
266 | torch.save(state, savepath)
267 | log_string('Saving model....')
268 | log_string('Best mIoU: %f' % best_iou)
269 | global_epoch += 1
270 |
271 |
272 | if __name__ == '__main__':
273 | args = parse_args()
274 | main(args)
275 |
276 |
--------------------------------------------------------------------------------
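
Note on the schedules shared by train_partseg.py and train_semseg.py above: the learning rate is decayed by lr_decay every step_size epochs down to a fixed floor, and BatchNorm momentum is halved on the same step down to 0.01. A small standalone sketch using train_semseg.py's default constants (purely illustrative, not part of the repository):

# Epoch-wise schedules mirroring the manual updates in the training scripts above.
base_lr, lr_decay, step_size = 0.001, 0.7, 10
LEARNING_RATE_CLIP = 1e-5
MOMENTUM_ORIGINAL, MOMENTUM_DECAY = 0.1, 0.5

def lr_at(epoch):
    # step decay with a lower bound
    return max(base_lr * (lr_decay ** (epoch // step_size)), LEARNING_RATE_CLIP)

def bn_momentum_at(epoch):
    # BN momentum halves on the same schedule, floored at 0.01
    return max(MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // step_size)), 0.01)

for epoch in (0, 10, 50, 200):
    print(epoch, lr_at(epoch), bn_momentum_at(epoch))
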
/visualizer/build.sh:
--------------------------------------------------------------------------------
1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0
2 |
--------------------------------------------------------------------------------
/visualizer/eulerangles.py:
--------------------------------------------------------------------------------
1 | # emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
2 | # vi: set ft=python sts=4 ts=4 sw=4 et:
3 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
4 | #
5 | # See COPYING file distributed along with the NiBabel package for the
6 | # copyright and license terms.
7 | #
8 | ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9 | ''' Module implementing Euler angle rotations and their conversions
10 | See:
11 | * http://en.wikipedia.org/wiki/Rotation_matrix
12 | * http://en.wikipedia.org/wiki/Euler_angles
13 | * http://mathworld.wolfram.com/EulerAngles.html
14 | See also: *Representing Attitude with Euler Angles and Quaternions: A
15 | Reference* (2006) by James Diebel. A cached PDF link last found here:
16 | http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134
17 | Euler's rotation theorem tells us that any rotation in 3D can be
18 | described by 3 angles. Let's call the 3 angles the *Euler angle vector*
19 | and call the angles in the vector :math:`alpha`, :math:`beta` and
20 | :math:`gamma`. The vector is [ :math:`alpha`,
21 | :math:`beta`, :math:`gamma` ] and, in this description, the order of the
22 | parameters specifies the order in which the rotations occur (so the
23 | rotation corresponding to :math:`alpha` is applied first).
24 | In order to specify the meaning of an *Euler angle vector* we need to
25 | specify the axes around which each of the rotations corresponding to
26 | :math:`alpha`, :math:`beta` and :math:`gamma` will occur.
27 | There are therefore three axes for the rotations :math:`alpha`,
28 | :math:`beta` and :math:`gamma`; let's call them :math:`i` :math:`j`,
29 | :math:`k`.
30 | Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3
31 | rotation matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3
32 | matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the
33 | whole rotation expressed by the Euler angle vector [ :math:`alpha`,
34 | :math:`beta`, :math:`gamma` ], `R` is given by::
35 | R = np.dot(G, np.dot(B, A))
36 | See http://mathworld.wolfram.com/EulerAngles.html
37 | The order :math:`G B A` expresses the fact that the rotations are
38 | performed in the order of the vector (:math:`alpha` around axis `i` =
39 | `A` first).
40 | To convert a given Euler angle vector to a meaningful rotation, and a
41 | rotation matrix, we need to define:
42 | * the axes `i`, `j`, `k`
43 | * whether a rotation matrix should be applied on the left of a vector to
44 | be transformed (vectors are column vectors) or on the right (vectors
45 | are row vectors).
46 | * whether the rotations move the axes as they are applied (intrinsic
47 | rotations) - compared to the situation where the axes stay fixed and the
48 | vectors move within the axis frame (extrinsic)
49 | * the handedness of the coordinate system
50 | See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities
51 | We are using the following conventions:
52 | * axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus
53 | an Euler angle vector [ :math:`alpha`, :math:`beta`, :math:`gamma` ]
54 | in our convention implies a :math:`alpha` radian rotation around the
55 | `z` axis, followed by a :math:`beta` rotation around the `y` axis,
56 | followed by a :math:`gamma` rotation around the `x` axis.
57 | * the rotation matrix applies on the left, to column vectors on the
58 | right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix
59 | with N column vectors, the transformed vector set `vdash` is given by
60 | ``vdash = np.dot(R, v)``.
61 | * extrinsic rotations - the axes are fixed, and do not move with the
62 | rotations.
63 | * a right-handed coordinate system
64 | The convention of rotation around ``z``, followed by rotation around
65 | ``y``, followed by rotation around ``x``, is known (confusingly) as
66 | "xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles.
67 | '''
68 |
69 | import math
70 |
71 | import sys
72 | if sys.version_info >= (3,0):
73 | from functools import reduce
74 |
75 | import numpy as np
76 |
77 |
78 | _FLOAT_EPS_4 = np.finfo(float).eps * 4.0
79 |
80 |
81 | def euler2mat(z=0, y=0, x=0):
82 | ''' Return matrix for rotations around z, y and x axes
83 | Uses the z, then y, then x convention above
84 | Parameters
85 | ----------
86 | z : scalar
87 | Rotation angle in radians around z-axis (performed first)
88 | y : scalar
89 | Rotation angle in radians around y-axis
90 | x : scalar
91 | Rotation angle in radians around x-axis (performed last)
92 | Returns
93 | -------
94 | M : array shape (3,3)
95 | Rotation matrix giving same rotation as for given angles
96 | Examples
97 | --------
98 | >>> zrot = 1.3 # radians
99 | >>> yrot = -0.1
100 | >>> xrot = 0.2
101 | >>> M = euler2mat(zrot, yrot, xrot)
102 | >>> M.shape == (3, 3)
103 | True
104 | The output rotation matrix is equal to the composition of the
105 | individual rotations
106 | >>> M1 = euler2mat(zrot)
107 | >>> M2 = euler2mat(0, yrot)
108 | >>> M3 = euler2mat(0, 0, xrot)
109 | >>> composed_M = np.dot(M3, np.dot(M2, M1))
110 | >>> np.allclose(M, composed_M)
111 | True
112 | You can specify rotations by named arguments
113 | >>> np.all(M3 == euler2mat(x=xrot))
114 | True
115 | When applying M to a vector, the vector should be a column vector to the
116 | right of M. If the right hand side is a 2D array rather than a
117 | vector, then each column of the 2D array represents a vector.
118 | >>> vec = np.array([1, 0, 0]).reshape((3,1))
119 | >>> v2 = np.dot(M, vec)
120 | >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array
121 | >>> vecs2 = np.dot(M, vecs)
122 | Rotations are counter-clockwise.
123 | >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3))
124 | >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]])
125 | True
126 | >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3))
127 | >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]])
128 | True
129 | >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3))
130 | >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]])
131 | True
132 | Notes
133 | -----
134 | The direction of rotation is given by the right-hand rule (orient
135 | the thumb of the right hand along the axis around which the rotation
136 | occurs, with the end of the thumb at the positive end of the axis;
137 | curl your fingers; the direction your fingers curl is the direction
138 | of rotation). Therefore, the rotations are counterclockwise if
139 | looking along the axis of rotation from positive to negative.
140 | '''
141 | Ms = []
142 | if z:
143 | cosz = math.cos(z)
144 | sinz = math.sin(z)
145 | Ms.append(np.array(
146 | [[cosz, -sinz, 0],
147 | [sinz, cosz, 0],
148 | [0, 0, 1]]))
149 | if y:
150 | cosy = math.cos(y)
151 | siny = math.sin(y)
152 | Ms.append(np.array(
153 | [[cosy, 0, siny],
154 | [0, 1, 0],
155 | [-siny, 0, cosy]]))
156 | if x:
157 | cosx = math.cos(x)
158 | sinx = math.sin(x)
159 | Ms.append(np.array(
160 | [[1, 0, 0],
161 | [0, cosx, -sinx],
162 | [0, sinx, cosx]]))
163 | if Ms:
164 | return reduce(np.dot, Ms[::-1])
165 | return np.eye(3)
166 |
167 |
168 | def mat2euler(M, cy_thresh=None):
169 | ''' Discover Euler angle vector from 3x3 matrix
170 | Uses the conventions above.
171 | Parameters
172 | ----------
173 | M : array-like, shape (3,3)
174 | cy_thresh : None or scalar, optional
175 | threshold below which to give up on straightforward arctan for
176 | estimating x rotation. If None (default), estimate from
177 | precision of input.
178 | Returns
179 | -------
180 | z : scalar
181 | y : scalar
182 | x : scalar
183 | Rotations in radians around z, y, x axes, respectively
184 | Notes
185 | -----
186 | If there was no numerical error, the routine could be derived using
187 | Sympy expression for z then y then x rotation matrix, which is::
188 | [ cos(y)*cos(z), -cos(y)*sin(z), sin(y)],
189 | [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)],
190 | [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z), cos(x)*cos(y)]
191 | with the obvious derivations for z, y, and x
192 | z = atan2(-r12, r11)
193 | y = asin(r13)
194 | x = atan2(-r23, r33)
195 | Problems arise when cos(y) is close to zero, because both of::
196 | z = atan2(cos(y)*sin(z), cos(y)*cos(z))
197 | x = atan2(cos(y)*sin(x), cos(x)*cos(y))
198 | will be close to atan2(0, 0), and highly unstable.
199 | The ``cy`` fix for numerical instability below is from: *Graphics
200 | Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN:
201 | 0123361559. Specifically it comes from EulerAngles.c by Ken
202 | Shoemake, and deals with the case where cos(y) is close to zero:
203 | See: http://www.graphicsgems.org/
204 | The code appears to be licensed (from the website) as "can be used
205 | without restrictions".
206 | '''
207 | M = np.asarray(M)
208 | if cy_thresh is None:
209 | try:
210 | cy_thresh = np.finfo(M.dtype).eps * 4
211 | except ValueError:
212 | cy_thresh = _FLOAT_EPS_4
213 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat
214 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2)
215 | cy = math.sqrt(r33*r33 + r23*r23)
216 | if cy > cy_thresh: # cos(y) not close to zero, standard form
217 | z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z))
218 | y = math.atan2(r13, cy) # atan2(sin(y), cy)
219 | x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y))
220 | else: # cos(y) (close to) zero, so x -> 0.0 (see above)
221 | # so r21 -> sin(z), r22 -> cos(z) and
222 | z = math.atan2(r21, r22)
223 | y = math.atan2(r13, cy) # atan2(sin(y), cy)
224 | x = 0.0
225 | return z, y, x
226 |
227 |
228 | def euler2quat(z=0, y=0, x=0):
229 | ''' Return quaternion corresponding to these Euler angles
230 | Uses the z, then y, then x convention above
231 | Parameters
232 | ----------
233 | z : scalar
234 | Rotation angle in radians around z-axis (performed first)
235 | y : scalar
236 | Rotation angle in radians around y-axis
237 | x : scalar
238 | Rotation angle in radians around x-axis (performed last)
239 | Returns
240 | -------
241 | quat : array shape (4,)
242 | Quaternion in w, x, y z (real, then vector) format
243 | Notes
244 | -----
245 | We can derive this formula in Sympy using:
246 | 1. Formula giving quaternion corresponding to rotation of theta radians
247 | about arbitrary axis:
248 | http://mathworld.wolfram.com/EulerParameters.html
249 | 2. Generated formulae from 1.) for quaternions corresponding to
250 | theta radians rotations about ``x, y, z`` axes
251 | 3. Apply quaternion multiplication formula -
252 | http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to
253 | formulae from 2.) to give formula for combined rotations.
254 | '''
255 | z = z/2.0
256 | y = y/2.0
257 | x = x/2.0
258 | cz = math.cos(z)
259 | sz = math.sin(z)
260 | cy = math.cos(y)
261 | sy = math.sin(y)
262 | cx = math.cos(x)
263 | sx = math.sin(x)
264 | return np.array([
265 | cx*cy*cz - sx*sy*sz,
266 | cx*sy*sz + cy*cz*sx,
267 | cx*cz*sy - sx*cy*sz,
268 | cx*cy*sz + sx*cz*sy])
269 |
270 |
271 | def quat2euler(q):
272 | ''' Return Euler angles corresponding to quaternion `q`
273 | Parameters
274 | ----------
275 | q : 4 element sequence
276 | w, x, y, z of quaternion
277 | Returns
278 | -------
279 | z : scalar
280 | Rotation angle in radians around z-axis (performed first)
281 | y : scalar
282 | Rotation angle in radians around y-axis
283 | x : scalar
284 | Rotation angle in radians around x-axis (performed last)
285 | Notes
286 | -----
287 | It's possible to reduce the amount of calculation a little, by
288 | combining parts of the ``quat2mat`` and ``mat2euler`` functions, but
289 | the reduction in computation is small, and the code repetition is
290 | large.
291 | '''
292 | # delayed import to avoid cyclic dependencies
293 | import nibabel.quaternions as nq
294 | return mat2euler(nq.quat2mat(q))
295 |
296 |
297 | def euler2angle_axis(z=0, y=0, x=0):
298 | ''' Return angle, axis corresponding to these Euler angles
299 | Uses the z, then y, then x convention above
300 | Parameters
301 | ----------
302 | z : scalar
303 | Rotation angle in radians around z-axis (performed first)
304 | y : scalar
305 | Rotation angle in radians around y-axis
306 | x : scalar
307 | Rotation angle in radians around x-axis (performed last)
308 | Returns
309 | -------
310 | theta : scalar
311 | angle of rotation
312 | vector : array shape (3,)
313 | axis around which rotation occurs
314 | Examples
315 | --------
316 | >>> theta, vec = euler2angle_axis(0, 1.5, 0)
317 | >>> print(theta)
318 | 1.5
319 | >>> np.allclose(vec, [0, 1, 0])
320 | True
321 | '''
322 | # delayed import to avoid cyclic dependencies
323 | import nibabel.quaternions as nq
324 | return nq.quat2angle_axis(euler2quat(z, y, x))
325 |
326 |
327 | def angle_axis2euler(theta, vector, is_normalized=False):
328 | ''' Convert angle, axis pair to Euler angles
329 | Parameters
330 | ----------
331 | theta : scalar
332 | angle of rotation
333 | vector : 3 element sequence
334 | vector specifying axis for rotation.
335 | is_normalized : bool, optional
336 | True if vector is already normalized (has norm of 1). Default
337 | False
338 | Returns
339 | -------
340 | z : scalar
341 | y : scalar
342 | x : scalar
343 | Rotations in radians around z, y, x axes, respectively
344 | Examples
345 | --------
346 | >>> z, y, x = angle_axis2euler(0, [1, 0, 0])
347 | >>> np.allclose((z, y, x), 0)
348 | True
349 | Notes
350 | -----
351 | It's possible to reduce the amount of calculation a little, by
352 | combining parts of the ``angle_axis2mat`` and ``mat2euler``
353 | functions, but the reduction in computation is small, and the code
354 | repetition is large.
355 | '''
356 | # delayed import to avoid cyclic dependencies
357 | import nibabel.quaternions as nq
358 | M = nq.angle_axis2mat(theta, vector, is_normalized)
359 | return mat2euler(M)
--------------------------------------------------------------------------------
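The conversion helpers above can be checked against the z, then y, then x convention documented in ``mat2euler``: building the rotation matrix directly from the docstring formula, the input angles should be recovered, and ``euler2quat`` should return a unit quaternion. A minimal sanity-check sketch (not part of the repository; it assumes the repository root is on ``sys.path`` so that ``visualizer.eulerangles`` is importable, as in ``pc_utils.py``):

import math
import numpy as np
from visualizer.eulerangles import mat2euler, euler2quat

z, y, x = 0.3, -0.2, 0.5
cz, sz = math.cos(z), math.sin(z)
cy, sy = math.cos(y), math.sin(y)
cx, sx = math.cos(x), math.sin(x)

# Rotation matrix for a z, then y, then x rotation (formula from the mat2euler docstring)
M = np.array([
    [cy * cz,                -cy * sz,                 sy],
    [cx * sz + cz * sx * sy,  cx * cz - sx * sy * sz, -cy * sx],
    [sx * sz - cx * cz * sy,  cz * sx + cx * sy * sz,  cx * cy],
])

assert np.allclose(mat2euler(M), (z, y, x))                   # angles recovered
assert np.isclose(np.linalg.norm(euler2quat(z, y, x)), 1.0)   # unit quaternion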
/visualizer/pc_utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions for processing point clouds.
2 | Author: Charles R. Qi, Hao Su
3 | Date: November 2016
4 | """
5 |
6 | import os
7 | import sys
8 |
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | sys.path.append(BASE_DIR)
11 |
12 | # Draw point cloud
13 | from visualizer.eulerangles import euler2mat
14 |
15 | # Point cloud IO
16 | import numpy as np
17 | from visualizer.plyfile import PlyData, PlyElement
18 |
19 | # ----------------------------------------
20 | # Point Cloud/Volume Conversions
21 | # ----------------------------------------
22 |
23 | def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True):
24 | """ Input is BxNx3 batch of point cloud
25 | Output is Bx(vsize^3)
26 | """
27 | vol_list = []
28 | for b in range(point_clouds.shape[0]):
29 | vol = point_cloud_to_volume(np.squeeze(point_clouds[b, :, :]), vsize, radius)
30 | if flatten:
31 | vol_list.append(vol.flatten())
32 | else:
33 | vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0))
34 | if flatten:
35 | return np.vstack(vol_list)
36 | else:
37 | return np.concatenate(vol_list, 0)
38 |
39 |
40 | def point_cloud_to_volume(points, vsize, radius=1.0):
41 | """ input is Nx3 points.
42 | output is vsize*vsize*vsize
43 | assumes points are in range [-radius, radius]
44 | """
45 | vol = np.zeros((vsize, vsize, vsize))
46 | voxel = 2 * radius / float(vsize)
47 | locations = (points + radius) / voxel
48 | locations = locations.astype(int)
49 | vol[locations[:, 0], locations[:, 1], locations[:, 2]] = 1.0
50 | return vol
51 |
52 |
53 | # a = np.zeros((16,1024,3))
54 | # print point_cloud_to_volume_batch(a, 12, 1.0, False).shape
55 |
56 | def volume_to_point_cloud(vol):
57 | """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize
58 | return Nx3 numpy array.
59 | """
60 | vsize = vol.shape[0]
61 |     assert (vol.shape[1] == vsize and vol.shape[2] == vsize)
62 | points = []
63 | for a in range(vsize):
64 | for b in range(vsize):
65 | for c in range(vsize):
66 | if vol[a, b, c] == 1:
67 | points.append(np.array([a, b, c]))
68 | if len(points) == 0:
69 | return np.zeros((0, 3))
70 | points = np.vstack(points)
71 | return points
72 |
73 |
74 | # ----------------------------------------
75 | # Point cloud IO
76 | # ----------------------------------------
77 |
78 | def read_ply(filename):
79 | """ read XYZ point cloud from filename PLY file """
80 | plydata = PlyData.read(filename)
81 | pc = plydata['vertex'].data
82 | pc_array = np.array([[x, y, z] for x, y, z in pc])
83 | return pc_array
84 |
85 |
86 | def write_ply(points, filename, text=True):
87 | """ input: Nx3, write points to filename as PLY format. """
88 | points = [(points[i, 0], points[i, 1], points[i, 2]) for i in range(points.shape[0])]
89 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
90 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
91 | PlyData([el], text=text).write(filename)
92 |
93 |
94 | # ----------------------------------------
95 | # Simple Point cloud and Volume Renderers
96 | # ----------------------------------------
97 |
98 | def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25,
99 | xrot=0, yrot=0, zrot=0, switch_xyz=[0, 1, 2], normalize=True):
100 | """ Render point cloud to image with alpha channel.
101 | Input:
102 | points: Nx3 numpy array (+y is up direction)
103 | Output:
104 | gray image as numpy array of size canvasSizexcanvasSize
105 | """
106 | image = np.zeros((canvasSize, canvasSize))
107 | if input_points is None or input_points.shape[0] == 0:
108 | return image
109 |
110 | points = input_points[:, switch_xyz]
111 | M = euler2mat(zrot, yrot, xrot)
112 | points = (np.dot(M, points.transpose())).transpose()
113 |
114 | # Normalize the point cloud
115 | # We normalize scale to fit points in a unit sphere
116 | if normalize:
117 | centroid = np.mean(points, axis=0)
118 | points -= centroid
119 | furthest_distance = np.max(np.sqrt(np.sum(abs(points) ** 2, axis=-1)))
120 | points /= furthest_distance
121 |
122 | # Pre-compute the Gaussian disk
123 | radius = (diameter - 1) / 2.0
124 | disk = np.zeros((diameter, diameter))
125 | for i in range(diameter):
126 | for j in range(diameter):
127 | if (i - radius) * (i - radius) + (j - radius) * (j - radius) <= radius * radius:
128 | disk[i, j] = np.exp((-(i - radius) ** 2 - (j - radius) ** 2) / (radius ** 2))
129 | mask = np.argwhere(disk > 0)
130 | dx = mask[:, 0]
131 | dy = mask[:, 1]
132 | dv = disk[disk > 0]
133 |
134 | # Order points by z-buffer
135 | zorder = np.argsort(points[:, 2])
136 | points = points[zorder, :]
137 | points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2])))
138 | max_depth = np.max(points[:, 2])
139 |
140 | for i in range(points.shape[0]):
141 | j = points.shape[0] - i - 1
142 | x = points[j, 0]
143 | y = points[j, 1]
144 | xc = canvasSize / 2 + (x * space)
145 | yc = canvasSize / 2 + (y * space)
146 | xc = int(np.round(xc))
147 | yc = int(np.round(yc))
148 |
149 | px = dx + xc
150 | py = dy + yc
151 |
152 | image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3
153 |
154 | image = image / np.max(image)
155 | return image
156 |
157 |
158 | def point_cloud_three_views(points):
159 | """ input points Nx3 numpy array (+y is up direction).
160 |         return a numpy array gray image of size 500x1500. """
161 | # +y is up direction
162 | # xrot is azimuth
163 | # yrot is in-plane
164 | # zrot is elevation
165 | img1 = draw_point_cloud(points, zrot=110 / 180.0 * np.pi, xrot=45 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
166 | img2 = draw_point_cloud(points, zrot=70 / 180.0 * np.pi, xrot=135 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
167 | img3 = draw_point_cloud(points, zrot=180.0 / 180.0 * np.pi, xrot=90 / 180.0 * np.pi, yrot=0 / 180.0 * np.pi)
168 | image_large = np.concatenate([img1, img2, img3], 1)
169 | return image_large
170 |
171 |
172 | from PIL import Image
173 |
174 |
175 | def point_cloud_three_views_demo():
176 | """ Demo for draw_point_cloud function """
177 | DATA_PATH = '../data/ShapeNet/'
178 |     train_data, _, _, _, _, _ = load_data(DATA_PATH, classification=False)
179 | points = train_data[1]
180 | im_array = point_cloud_three_views(points)
181 | img = Image.fromarray(np.uint8(im_array * 255.0))
182 | img.save('example.jpg')
183 |
184 |
185 | if __name__ == "__main__":
186 | from data_utils.ShapeNetDataLoader import load_data
187 | point_cloud_three_views_demo()
188 |
189 | import matplotlib.pyplot as plt
190 |
191 |
192 | def pyplot_draw_point_cloud(points, output_filename):
193 | """ points is a Nx3 numpy array """
194 | fig = plt.figure()
195 | ax = fig.add_subplot(111, projection='3d')
196 | ax.scatter(points[:, 0], points[:, 1], points[:, 2])
197 | ax.set_xlabel('x')
198 | ax.set_ylabel('y')
199 | ax.set_zlabel('z')
200 | # savefig(output_filename)
201 |
202 |
203 | def pyplot_draw_volume(vol, output_filename):
204 | """ vol is of size vsize*vsize*vsize
205 | output an image to output_filename
206 | """
207 | points = volume_to_point_cloud(vol)
208 | pyplot_draw_point_cloud(points, output_filename)
209 |
210 |
211 |
212 |
--------------------------------------------------------------------------------
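A short usage sketch for the utilities above (illustrative only; the import path and output file names are assumptions based on the repository layout, and it should be run from the repository root):

import numpy as np
from PIL import Image
from visualizer.pc_utils import (point_cloud_to_volume, volume_to_point_cloud,
                                 point_cloud_three_views, write_ply)

pts = np.random.uniform(-1.0, 1.0, (1024, 3))            # toy cloud in [-radius, radius]^3

vol = point_cloud_to_volume(pts, vsize=32, radius=1.0)    # 32x32x32 occupancy grid
occupied = volume_to_point_cloud(vol)                     # indices of occupied voxels, Nx3
print(vol.shape, occupied.shape)

img = point_cloud_three_views(pts)                        # 500x1500 gray image in [0, 1]
Image.fromarray(np.uint8(img * 255.0)).save('three_views.jpg')

write_ply(pts, 'toy.ply')                                 # ASCII PLY with x, y, z vertices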
/visualizer/render_balls_so.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 | #include <vector>
3 | #include <algorithm>
4 | #include <math.h>
5 | using namespace std;
6 |
7 | struct PointInfo{
8 | int x,y,z;
9 | float r,g,b;
10 | };
11 |
12 | extern "C"{
13 |
14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
15 | r=max(r,1);
16 | vector<int> depth(h*w,-2100000000);
17 | vector<PointInfo> pattern;
18 | for (int dx=-r;dx<=r;dx++)
19 | for (int dy=-r;dy<=r;dy++)
20 | if (dx*dx+dy*dy<r*r){
--------------------------------------------------------------------------------
/visualizer/show3d_balls.py:
--------------------------------------------------------------------------------
94 | if magnifyBlue > 0:
95 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=0))
96 | if magnifyBlue >= 2:
97 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=0))
98 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], 1, axis=1))
99 | if magnifyBlue >= 2:
100 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(show[:, :, 0], -1, axis=1))
101 | if showrot:
102 | cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)), (30, showsz - 30), 0, 0.5,
103 | (0, 0, 255))  # red in BGR; cv2.cv.CV_RGB was removed in OpenCV >= 3
104 | cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)), (30, showsz - 50), 0, 0.5,
105 | (0, 0, 255))
106 | cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0, 0.5, (0, 0, 255))
107 |
108 | changed = True
109 | while True:
110 | if changed:
111 | render()
112 | changed = False
113 | cv2.imshow('show3d', show)
114 | if waittime == 0:
115 | cmd = cv2.waitKey(10) % 256
116 | else:
117 | cmd = cv2.waitKey(waittime) % 256
118 | if cmd == ord('q'):
119 | break
120 | elif cmd == ord('Q'):
121 | sys.exit(0)
122 |
123 | if cmd == ord('t') or cmd == ord('p'):
124 | if cmd == ord('t'):
125 | if c_gt is None:
126 | c0 = np.zeros((len(xyz),), dtype='float32') + 255
127 | c1 = np.zeros((len(xyz),), dtype='float32') + 255
128 | c2 = np.zeros((len(xyz),), dtype='float32') + 255
129 | else:
130 | c0 = c_gt[:, 0]
131 | c1 = c_gt[:, 1]
132 | c2 = c_gt[:, 2]
133 | else:
134 | if c_pred is None:
135 | c0 = np.zeros((len(xyz),), dtype='float32') + 255
136 | c1 = np.zeros((len(xyz),), dtype='float32') + 255
137 | c2 = np.zeros((len(xyz),), dtype='float32') + 255
138 | else:
139 | c0 = c_pred[:, 0]
140 | c1 = c_pred[:, 1]
141 | c2 = c_pred[:, 2]
142 | if normalizecolor:
143 | c0 /= (c0.max() + 1e-14) / 255.0
144 | c1 /= (c1.max() + 1e-14) / 255.0
145 | c2 /= (c2.max() + 1e-14) / 255.0
146 | c0 = np.require(c0, 'float32', 'C')
147 | c1 = np.require(c1, 'float32', 'C')
148 | c2 = np.require(c2, 'float32', 'C')
149 | changed = True
150 |
151 | if cmd == ord('n'):
152 | zoom *= 1.1
153 | changed = True
154 | elif cmd == ord('m'):
155 | zoom /= 1.1
156 | changed = True
157 | elif cmd == ord('r'):
158 | zoom = 1.0
159 | changed = True
160 | elif cmd == ord('s'):
161 | cv2.imwrite('show3d.png', show)
162 | if waittime != 0:
163 | break
164 | return cmd
165 |
166 |
167 | if __name__ == '__main__':
168 | import os
169 | import numpy as np
170 | import argparse
171 |
172 | parser = argparse.ArgumentParser()
173 | parser.add_argument('--dataset', type=str, default='../data/shapenet', help='dataset path')
174 | parser.add_argument('--category', type=str, default='Airplane', help='select category')
175 | parser.add_argument('--npoints', type=int, default=2500, help='resample points number')
176 | parser.add_argument('--ballradius', type=int, default=10, help='ballradius')
177 | opt = parser.parse_args()
178 | '''
179 | Airplane 02691156
180 | Bag 02773838
181 | Cap 02954340
182 | Car 02958343
183 | Chair 03001627
184 | Earphone 03261776
185 | Guitar 03467517
186 | Knife 03624134
187 | Lamp 03636649
188 | Laptop 03642806
189 | Motorbike 03790512
190 | Mug 03797390
191 | Pistol 03948459
192 | Rocket 04099429
193 | Skateboard 04225987
194 | Table 04379243'''
195 |
196 | cmap = np.array([[1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
197 | [3.12493437e-02, 1.00000000e+00, 1.31250131e-06],
198 | [0.00000000e+00, 6.25019688e-02, 1.00000000e+00],
199 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
200 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
201 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
202 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
203 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
204 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02],
205 | [1.00000000e+00, 0.00000000e+00, 9.37500000e-02]])
206 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
207 | ROOT_DIR = os.path.dirname(BASE_DIR)
208 | sys.path.append(BASE_DIR)
209 | sys.path.append(os.path.join(ROOT_DIR, 'data_utils'))
210 |
211 | from ShapeNetDataLoader import PartNormalDataset
212 | root = '../data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'
213 | dataset = PartNormalDataset(root = root, npoints=2048, split='test', normal_channel=False)
214 | idx = np.random.randint(0, len(dataset))
215 | data = dataset[idx]
216 | point_set, _, seg = data
217 | choice = np.random.choice(point_set.shape[0], opt.npoints, replace=True)
218 | point_set, seg = point_set[choice, :], seg[choice]
219 | seg = seg - seg.min()
220 | gt = cmap[seg, :]
221 | pred = cmap[seg, :]
222 | showpoints(point_set, gt, c_pred=pred, waittime=0, showrot=False, magnifyBlue=0, freezerot=False,
223 | background=(255, 255, 255), normalizecolor=True, ballradius=opt.ballradius)
224 |
225 |
226 |
--------------------------------------------------------------------------------
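Beyond the ShapeNet demo in the ``__main__`` block, the viewer can be pointed at any Nx3 array. A minimal sketch (assuming the shared library has been compiled with visualizer/build.sh and that an OpenCV GUI backend and display are available; the argument layout follows the call shown above):

import numpy as np
from visualizer.show3d_balls import showpoints

pts = np.random.uniform(-1.0, 1.0, (2048, 3)).astype('float32')
colors = np.random.uniform(0.0, 255.0, (2048, 3)).astype('float32')

# Key bindings from the loop above: 'q' quits, 'n'/'m' zoom in/out,
# 'r' resets the zoom, 's' writes a screenshot to show3d.png.
showpoints(pts, colors, waittime=0, showrot=False, magnifyBlue=0,
           freezerot=False, background=(255, 255, 255),
           normalizecolor=True, ballradius=10)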