├── LICENSE
├── README.md
├── data_utils
│   ├── ModelNetDataLoader.py
│   ├── bhcp_dataloader.py
│   ├── data_flower.py
│   ├── data_utils.py
│   ├── dataset.py
│   ├── keypointnet_dataloader.py
│   └── shapenet_seg_dataloader.py
├── images
│   ├── Pipeline.png
│   └── VisualizationResults.png
├── models
│   ├── __init__.py
│   ├── chamfer_distance.py
│   ├── model_weightchamfer.py
│   ├── pointnetpp
│   │   ├── pointnet.py
│   │   ├── pointnet2_cls_msg.py
│   │   ├── pointnet2_cls_ssg.py
│   │   ├── pointnet2_part_seg_msg.py
│   │   ├── pointnet2_part_seg_ssg.py
│   │   ├── pointnet2_sem_seg.py
│   │   ├── pointnet2_sem_seg_msg.py
│   │   ├── pointnet_cls.py
│   │   ├── pointnet_part_seg.py
│   │   ├── pointnet_sem_seg.py
│   │   └── pointnet_util.py
│   └── torch_pointnet_utils.py
├── pointnet2_ops_lib
│   ├── MANIFEST.in
│   ├── pointnet2_ops.egg-info
│   │   ├── PKG-INFO
│   │   ├── SOURCES.txt
│   │   ├── dependency_links.txt
│   │   ├── requires.txt
│   │   └── top_level.txt
│   ├── pointnet2_ops
│   │   ├── __init__.py
│   │   ├── _ext-src
│   │   │   ├── include
│   │   │   │   ├── ball_query.h
│   │   │   │   ├── cuda_utils.h
│   │   │   │   ├── group_points.h
│   │   │   │   ├── interpolate.h
│   │   │   │   ├── sampling.h
│   │   │   │   └── utils.h
│   │   │   └── src
│   │   │       ├── ball_query.cpp
│   │   │       ├── ball_query_gpu.cu
│   │   │       ├── bindings.cpp
│   │   │       ├── group_points.cpp
│   │   │       ├── group_points_gpu.cu
│   │   │       ├── interpolate.cpp
│   │   │       ├── interpolate_gpu.cu
│   │   │       ├── sampling.cpp
│   │   │       └── sampling_gpu.cu
│   │   ├── _version.py
│   │   ├── pointnet2_modules.py
│   │   └── pointnet2_utils.py
│   └── setup.py
├── provider.py
├── test_correspondence.py
├── train.py
└── utils
    ├── __init__.py
    ├── check_points_utils.py
    ├── logutils.py
    ├── mesh_utils.py
    └── point_cloud_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds
2 |
3 | ## Description
4 | This repository contains the code for our paper: **Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds**.
5 | > [**Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds**](https://doi.org/10.1016/j.neunet.2024.106158),
6 | > Yiqi Wu, Xingye Chen, Xuan Huang, Kelin Song, Dejun Zhang
 7 | > [BibTeX](https://github.com/Chenguoz/Keypoints#citation)
8 |
9 |
 10 | ![Pipeline](https://github.com/Chenguoz/Keypoints/blob/main/images/Pipeline.png)
11 |
12 |
13 |
 14 | ![VisualizationResults](https://github.com/Chenguoz/Keypoints/blob/main/images/VisualizationResults.png)
15 |
16 |
17 |
18 | ## Environment setup
19 | ```
20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
21 | pip install pointnet2_ops_lib/.
22 | ```
23 | ## Dataset
 24 | The training and testing data for correspondence are provided by [KeypointNet](https://github.com/qq456cvb/KeypointNet) and [ShapeNet](https://github.com/antao97/PointCloudDatasets).
25 |
26 |
27 | ## Citation
28 |
29 | ```bibtex
30 | @article{wu2024unsupervised,
31 | title={Unsupervised distribution-aware keypoints generation from 3D point clouds},
32 | author={Wu, Yiqi and Chen, Xingye and Huang, Xuan and Song, Kelin and Zhang, Dejun},
33 | journal={Neural Networks},
34 | pages={106158},
35 | year={2024},
36 | publisher={Elsevier}
37 | }
38 | ```
39 |
40 | ## Acknowledgment
41 |
 42 | Our implementation is mainly based on the following codebases. We sincerely thank the authors for their wonderful work.
43 |
44 | [3DStructurePoints](https://github.com/NolenChen/3DStructurePoints),
45 | [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
46 |
--------------------------------------------------------------------------------
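A quick sanity check after the environment setup above — a minimal sketch, assuming a CUDA-capable GPU and that the `pointnet2_ops` extension built successfully from `pointnet2_ops_lib`:

```python
import torch
from pointnet2_ops import pointnet2_utils

# Two random clouds of 1024 points each; furthest_point_sample returns
# the indices of 16 well-spread points per cloud (FPS runs on the GPU).
xyz = torch.rand(2, 1024, 3, device="cuda")
idx = pointnet2_utils.furthest_point_sample(xyz, 16)
print(idx.shape)  # torch.Size([2, 16])
```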
/data_utils/ModelNetDataLoader.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author: Xu Yan
3 | @file: ModelNet.py
4 | @time: 2021/3/19 15:51
5 | '''
6 | import argparse
7 | import os
8 | import numpy as np
9 | import warnings
10 | import pickle
11 |
12 | from tqdm import tqdm
13 | from torch.utils.data import Dataset
14 |
15 | warnings.filterwarnings('ignore')
16 |
17 |
18 | def pc_normalize(pc):
19 | centroid = np.mean(pc, axis=0)
20 | pc = pc - centroid
21 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
22 | pc = pc / m
23 | return pc
24 |
25 |
26 | def farthest_point_sample(point, npoint):
27 | """
28 | Input:
 29 |         point: pointcloud data, [N, D]
 30 |         npoint: number of samples
 31 |     Return:
 32 |         point: sampled pointcloud, [npoint, D]
33 | """
34 | N, D = point.shape
35 | xyz = point[:, :3]
36 | centroids = np.zeros((npoint,))
37 | distance = np.ones((N,)) * 1e10
38 | farthest = np.random.randint(0, N)
39 | for i in range(npoint):
40 | centroids[i] = farthest
41 | centroid = xyz[farthest, :]
42 | dist = np.sum((xyz - centroid) ** 2, -1)
43 | mask = dist < distance
44 | distance[mask] = dist[mask]
45 | farthest = np.argmax(distance, -1)
46 | point = point[centroids.astype(np.int32)]
47 | return point
48 |
49 |
50 | class ModelNetDataLoader(Dataset):
51 | def __init__(self, root, args, category='chair', split='train', process_data=False):
52 | self.root = root
53 | self.npoints = args.num_inputs
54 | self.process_data = process_data
55 | self.uniform = args.use_uniform_sample
56 | self.use_normals = args.use_normals
57 | self.num_category = args.num_category
58 | self.category = category
59 | if self.num_category == 10:
60 | self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
61 | else:
62 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
63 |
64 | self.cat = [line.rstrip() for line in open(self.catfile)]
65 | self.classes = dict(zip(self.cat, range(len(self.cat))))
66 |
67 | shape_ids = {}
68 | if self.num_category == 10:
69 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
70 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
71 | else:
72 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
73 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
74 |
75 | assert (split == 'train' or split == 'test')
76 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
77 |
78 | if not category == 'all':
79 | self.datapath = [
80 | (shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
81 | in range(len(shape_ids[split])) if shape_names[i] == category]
82 | else:
83 | self.datapath = [
84 | (shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
85 | in range(len(shape_ids[split]))]
86 | print('The size of %s data is %d' % (split, len(self.datapath)))
87 |
88 | if self.uniform:
89 | self.save_path = os.path.join(root,
90 | 'modelnet%d_%s_%d_%spts_fps.dat' % (self.num_category, split, self.npoints, self.category))
91 | else:
92 | self.save_path = os.path.join(root, 'modelnet%d_%s_%d_%spts.dat' % (self.num_category, split, self.npoints, self.category))
93 |
94 | if self.process_data:
95 | if not os.path.exists(self.save_path):
 96 |                 print('Processing data %s (only runs the first time)...' % self.save_path)
97 | self.list_of_points = [None] * len(self.datapath)
98 | self.list_of_labels = [None] * len(self.datapath)
99 |
100 | for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
101 | fn = self.datapath[index]
102 | cls = self.classes[self.datapath[index][0]]
103 | cls = np.array([cls]).astype(np.int32)
104 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
105 |
106 | if self.uniform:
107 | point_set = farthest_point_sample(point_set, self.npoints)
108 | else:
109 | point_set = point_set[0:self.npoints, :]
110 |
111 | self.list_of_points[index] = point_set
112 | self.list_of_labels[index] = cls
113 |
114 | with open(self.save_path, 'wb') as f:
115 | pickle.dump([self.list_of_points, self.list_of_labels], f)
116 | else:
117 | print('Load processed data from %s...' % self.save_path)
118 | with open(self.save_path, 'rb') as f:
119 | self.list_of_points, self.list_of_labels = pickle.load(f)
120 |
121 | def __len__(self):
122 | return len(self.datapath)
123 |
124 | def _get_item(self, index):
125 | if self.process_data:
126 | point_set, label = self.list_of_points[index], self.list_of_labels[index]
127 | else:
128 | fn = self.datapath[index]
129 | cls = self.classes[self.datapath[index][0]]
130 | label = np.array([cls]).astype(np.int32)
131 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
132 |
133 | if self.uniform:
134 | point_set = farthest_point_sample(point_set, self.npoints)
135 | else:
136 | point_set = point_set[0:self.npoints, :]
137 |
138 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
139 | if not self.use_normals:
140 | point_set = point_set[:, 0:3]
141 |
142 | return point_set, label[0]
143 |
144 | def __getitem__(self, index):
145 | return self._get_item(index)
146 |
147 |
148 | def parse_args():
149 | parser = argparse.ArgumentParser(
150 | description="Arguments",
151 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
152 | )
153 | parser.add_argument("-batch_size", type=int, default=20, help="Batch size")
154 | parser.add_argument(
155 | "-weight_decay", type=float, default=1e-5, help="L2 regularization coeff"
156 | )
157 | parser.add_argument("-lr", type=float, default=1e-2, help="Initial learning rate")
158 | parser.add_argument(
159 | "-lr_decay", type=float, default=0.7, help="Learning rate decay gamma"
160 | )
161 | parser.add_argument(
162 | "-decay_batch", type=float, default=20, help="Learning rate decay batch"
163 | )
164 | parser.add_argument(
165 | "-bn_momentum", type=float, default=0.5, help="Initial batch norm momentum"
166 | )
167 | parser.add_argument(
168 | "-bnm_decay", type=float, default=0.5, help="Batch norm momentum decay gamma"
169 | )
170 |
171 | parser.add_argument(
172 | "-checkpoint_save_step", type=int, default=50, help="Step for saving Checkpoint"
173 | )
174 |
175 | parser.add_argument(
176 | "-checkpoint", type=str, default=None
177 | , help="Checkpoint to start from"
178 | )
179 | parser.add_argument(
180 | "-num_of_transform", type=int, default=0,
181 | help="Number of transforms for rotation data augmentation. Useful when testing on shapes without alignment"
182 | )
183 |
184 | parser.add_argument(
 185 |         "-num_inputs", type=int, default=1024, help="Number of points to sample from the input point cloud"
186 | )
187 |
188 | parser.add_argument(
189 | "-num_structure_points", type=int, default=16
190 | , help="Number of structure points"
191 | )
192 | parser.add_argument(
193 | "-category", type=str, default='chair', help="Category of the objects to train"
194 | )
195 | parser.add_argument(
196 | "-data_dir", type=str, default="training_data/", help="Root of the training data"
197 | )
198 | parser.add_argument(
199 | "-test_data_dir", type=str, default="demo_data/", help="Root of the test data"
200 | )
201 | parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
202 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
203 |
 204 |     parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
205 |
206 | parser.add_argument(
207 | "-max_epochs", type=int, default=200, help="Number of epochs to train for"
208 | )
209 | parser.add_argument(
210 | "-log_dir", type=str, default=None, help="Root of the log"
211 | )
212 | parser.add_argument(
213 | "-multi_distribution", type=int, default=5, help="Multivariate normal distribution nums"
214 | )
215 | parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
216 | parser.add_argument('-model', default='PointSPN', help='model name [default: PointSPN]')
217 | args = parser.parse_args()
218 | return args
219 |
220 |
221 | if __name__ == '__main__':
222 | import torch
223 |
224 | args = parse_args()
225 | data = ModelNetDataLoader('../modelnet40/', args, split='train')
226 |
227 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True, num_workers=10)
228 | for point, label in DataLoader:
229 | print(point.shape)
230 | print(label.shape)
231 |
--------------------------------------------------------------------------------
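For reference, a minimal sketch of the two NumPy helpers defined above (the import path assumes the repository root is on `PYTHONPATH`):

```python
import numpy as np
from data_utils.ModelNetDataLoader import pc_normalize, farthest_point_sample

pts = np.random.rand(2048, 3).astype(np.float32)
pts = pc_normalize(pts)                 # center at the origin, scale into the unit sphere
sub = farthest_point_sample(pts, 1024)  # greedy FPS from 2048 down to 1024 points
print(sub.shape)                        # (1024, 3)
```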
/data_utils/bhcp_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 | import ntpath
4 | import pickle
5 | from .ModelNetDataLoader import pc_normalize
6 |
7 |
8 | class bhcp_dataloader(data.Dataset):
9 | def __init__(self, data_root, category, is_pts_aligned=False, preload_data=True, split='train'):
10 | super().__init__()
11 | if split == 'train':
12 | # self.data_path = os.path.join(data_root , category)
13 |
14 | self.data_path = os.path.join(data_root + '/training_data', category)
15 | else:
16 | self.data_path = os.path.join(data_root + '/test_data', category)
17 |
18 | self.sub_dirs = [ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()]
19 | self.data_num = len(self.sub_dirs)
20 | self.is_pts_aligned = is_pts_aligned
21 |
22 | self.meta_data_list = None
23 | if preload_data:
24 | self.meta_data_list = []
25 | for i in range(len(self.sub_dirs)):
26 | meta_fname = os.path.join(self.data_path, self.sub_dirs[i], 'meta.pkl')
27 | with open(meta_fname, 'rb') as f:
28 | meta_data = pickle.load(f)
29 | self.meta_data_list.append(meta_data)
30 |
31 | def __getitem__(self, idx):
32 | if self.meta_data_list is None:
33 | meta_fname = os.path.join(self.data_path, self.sub_dirs[idx], 'meta.pkl')
 34 |             with open(meta_fname, 'rb') as f:
 35 |                 meta_data = pickle.load(f)
36 | else:
37 | meta_data = self.meta_data_list[idx]
38 |
39 | if self.is_pts_aligned:
40 | if 'points_aligned' in meta_data:
41 | points = meta_data['points_aligned']
42 | else:
43 | points = meta_data['points']
44 | else:
45 | points = meta_data['points']
46 | points[:, 0:3] = pc_normalize(points[:, 0:3])
47 |
48 | res = {}
49 | res['points'] = points
50 | if 'feat_pts' in meta_data: # the labeled feature points on the bhcp dataset for computing correspondence accuracy
51 | if self.is_pts_aligned:
52 | res['feat_pts'] = meta_data['feat_pts_aligned']
53 | else:
54 | res['feat_pts'] = meta_data['feat_pts']
55 |
56 | res['data_id'] = idx
 57 |         return points, meta_data['feat_pts']
58 | # return res
59 |
60 | def __len__(self):
61 | return self.data_num
62 |
--------------------------------------------------------------------------------
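A usage sketch for the loader above. The `data_root` path is hypothetical; the constructor expects `training_data/` and `test_data/` subfolders, and `__getitem__` assumes the test meta files carry the labeled `feat_pts`:

```python
from torch.utils.data import DataLoader
from data_utils.bhcp_dataloader import bhcp_dataloader

ds = bhcp_dataloader('path/to/bhcp', category='chair', split='test')
points, feat_pts = ds[0]   # normalized points plus labeled feature points
loader = DataLoader(ds, batch_size=8, shuffle=False)
```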
/data_utils/data_flower.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Fri Apr 24 22:59:57 2020
4 |
5 | @author: eliphat
6 | """
7 | import os
8 | import random
9 | import numpy
10 | import h5py
11 |
12 |
13 | def load_h5(h5_filename, normalize=False, include_label=False):
14 | f = h5py.File(h5_filename, 'r')
15 | data = f['data'][:] # (n, 2048, 3)
16 | if normalize:
17 | # nmean = numpy.mean(data, axis=1, keepdims=True)
18 | # nstd = numpy.std(data, axis=1, keepdims=True)
19 | # nstd = numpy.mean(nstd, axis=-1, keepdims=True)
20 | dmin = data.min(axis=1, keepdims=True).min(axis=-1, keepdims=True)
21 | dmax = data.max(axis=1, keepdims=True).max(axis=-1, keepdims=True)
22 | data = (data - dmin) / (dmax - dmin)
23 | # data = (data - nmean) / nstd
24 | data = 2.0 * (data - 0.5)
25 | if include_label:
26 | label = f['label'][:]
27 | return data, label
28 | return data
29 |
30 |
31 | def all_h5(parent, normalize=False, include_label=False,
32 | subclasses=tuple(range(40)), sample=256):
33 | lazy = map(lambda x: load_h5(x, normalize, include_label),
34 | walk_files(parent))
35 | if include_label:
36 | xy = tuple(lazy)
37 | x = [x for x, y in xy]
38 | y = [y for x, y in xy]
39 | x = numpy.concatenate(x)
40 | y = numpy.concatenate(y)
41 | xf = []
42 | yf = []
43 | for xp, yp in zip(x, y):
44 | if yp[0] in subclasses:
45 | if sample is None:
46 | xf.append(xp)
47 | else:
48 | xf.append(random.choices(xp, k=sample))
49 | yf.append(numpy.eye(len(subclasses))[subclasses.index(yp[0])])
50 | return numpy.array(xf), numpy.array(yf)
51 | return numpy.concatenate(tuple(lazy))
52 |
53 |
54 | def walk_files(path):
55 | for r, ds, fs in os.walk(path):
56 | for f in fs:
57 | yield os.path.join(r, f)
58 |
59 |
60 | def last_dirname(file_path):
61 | return os.path.basename(os.path.dirname(file_path))
62 |
63 |
64 | def dataset_split(path):
65 | flist = list(walk_files(path))
66 | tr = filter(lambda p: 'train' in last_dirname(p), flist)
67 | te = filter(lambda p: 'test' in last_dirname(p), flist)
68 | return list(tr), list(te)
69 |
--------------------------------------------------------------------------------
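A short sketch of how these helpers compose. The HDF5 root path is hypothetical; `dataset_split` keys off `train`/`test` appearing in the parent directory name of each file:

```python
from data_utils.data_flower import dataset_split, load_h5

train_files, test_files = dataset_split('path/to/h5_root')
data = load_h5(train_files[0], normalize=True)  # (n, 2048, 3), rescaled to [-1, 1]
```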
/data_utils/data_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import sys
9 | sys.path.append("..")
10 | import torch
11 | import numpy as np
12 | from torchvision import transforms
13 | import utils.point_cloud_utils as point_cloud_utils
14 |
15 |
16 | def angle_axis_tensor(angle, axis):
 17 |     # type: (torch.Tensor, torch.Tensor) -> torch.Tensor
 18 |     r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle
19 |
20 | Parameters
21 | ----------
22 | angle : float
23 | Angle to rotate by
24 | axis: torch.Tensor
25 | Axis to rotate about
26 |
27 | Returns
28 | -------
29 | torch.Tensor
30 | 3x3 rotation matrix
31 | """
32 | u = axis / torch.norm(axis)
33 | cosval, sinval = torch.cos(angle), torch.sin(angle)
34 |
35 | # yapf: disable
36 | cross_prod_mat = torch.Tensor([[0.0, -u[2], u[1]],
37 | [u[2], 0.0, -u[0]],
38 | [-u[1], u[0], 0.0]]).type(torch.FloatTensor)
39 | R = cosval * torch.eye(3).type(torch.FloatTensor) + sinval * cross_prod_mat + (1.0 - cosval) * torch.ger(u, u)
40 | return R
41 |
42 |
43 | def angle_axis(angle, axis):
 44 |     # type: (float, np.ndarray) -> torch.Tensor
 45 |     r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle
46 |
47 | Parameters
48 | ----------
49 | angle : float
50 | Angle to rotate by
51 | axis: np.ndarray
52 | Axis to rotate about
53 |
54 | Returns
55 | -------
56 | torch.Tensor
57 | 3x3 rotation matrix
58 | """
59 | u = axis / np.linalg.norm(axis)
60 | cosval, sinval = np.cos(angle), np.sin(angle)
61 |
62 | # yapf: disable
63 | cross_prod_mat = np.array([[0.0, -u[2], u[1]],
64 | [u[2], 0.0, -u[0]],
65 | [-u[1], u[0], 0.0]])
66 |
67 | R = torch.from_numpy(
68 | cosval * np.eye(3)
69 | + sinval * cross_prod_mat
70 | + (1.0 - cosval) * np.outer(u, u)
71 | )
72 | # yapf: enable
73 | return R.float()
74 |
75 |
76 | class PointcloudRandomScale(object):
77 | def __init__(self, lo=0.8, hi=1.25):
78 | self.lo, self.hi = lo, hi
79 |
80 | def __call__(self, points):
81 | scaler = np.random.uniform(self.lo, self.hi)
82 | points[:, 0:3] *= scaler
83 | return points
84 |
85 |
86 | class PointcloudRandomRotate(object):
87 | def __init__(self, axis=np.array([0.0, 1.0, 0.0])):
88 | self.axis = axis
89 |
90 | def __call__(self, points):
91 | rotation_angle = np.random.uniform() * 2 * np.pi
92 | axis = np.array([np.random.uniform(), np.random.uniform(), np.random.uniform()])
93 | axis = axis / np.sqrt(np.sum(axis * axis))
94 | rotation_matrix = angle_axis(rotation_angle, axis)
95 |
96 | normals = points.size(1) > 3
97 | if not normals:
98 | return torch.matmul(points, rotation_matrix.t())
99 | else:
100 | pc_xyz = points[:, 0:3]
101 | pc_normals = points[:, 3:]
102 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
103 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())
104 |
105 | return points
106 |
107 |
108 | class PointcloudRandomRotatePerturbation(object):
109 | def __init__(self, angle_sigma=0.06, angle_clip=0.18):
110 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip
111 |
112 | def _get_angles(self):
113 | angles = np.clip(
114 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip
115 | )
116 |
117 | return angles
118 |
119 | def __call__(self, points):
120 | angles = self._get_angles()
121 | Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0]))
122 | Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0]))
123 | Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0]))
124 |
125 | rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx)
126 |
127 | normals = points.size(1) > 3
128 | if not normals:
129 | return torch.matmul(points, rotation_matrix.t())
130 | else:
131 | pc_xyz = points[:, 0:3]
132 | pc_normals = points[:, 3:]
133 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
134 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())
135 |
136 | return points
137 |
138 |
139 | class PointcloudJitter(object):
140 | def __init__(self, std=0.01, clip=0.05):
141 | self.std, self.clip = std, clip
142 |
143 | def __call__(self, points):
144 | jittered_data = (
145 | points.new(points.size(0), 3)
146 | .normal_(mean=0.0, std=self.std)
147 | .clamp_(-self.clip, self.clip)
148 | )
149 | points[:, 0:3] += jittered_data
150 | return points
151 |
152 |
153 | class PointcloudNormalize(object):
154 | def __init__(self, max_size=1.0):
155 |
156 | self.max_size = max_size
157 |
158 | def __call__(self, points):
159 |
160 | points_max, _ = torch.max(points, dim=0)
161 | points_min, _ = torch.min(points, dim=0)
162 | points_center = (points_max + points_min) / 2
163 | points = points - points_center[None, :]
164 | max_radius = torch.max(torch.sqrt(torch.sum(points * points, dim=1)))
165 | points = points / max_radius * self.max_size / 2.0
166 | return points
167 |
168 |
169 | class PointcloudRandomPermutation(object):
170 | def __call__(self, points):
171 | num = points.shape[0]
172 | idxs = torch.randperm(num).type(torch.LongTensor)
173 | points = torch.index_select(points, 0, idxs).clone()
174 | return points
175 |
176 |
177 | class PointcloudRandomTranslate(object):
178 | def __init__(self, translate_range=0.1):
179 | self.translate_range = translate_range
180 |
181 | def __call__(self, points):
182 | translation = np.random.uniform(-self.translate_range, self.translate_range)
183 | points[:, 0:3] += translation
184 | return points
185 |
186 |
187 | class PointcloudToTensor(object):
188 | def __call__(self, points):
189 | return torch.from_numpy(points).float()
190 |
191 |
192 | class PointcloudRandomInputDropout(object):
193 | def __init__(self, max_dropout_ratio=0.875):
194 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
195 | self.max_dropout_ratio = max_dropout_ratio
196 |
197 | def __call__(self, points):
198 | pc = points.numpy()
199 |
200 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875
201 | drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0]
202 | if len(drop_idx) > 0:
203 | pc[drop_idx] = pc[0] # set to the first point
204 |
205 | return torch.from_numpy(pc).float()
206 |
207 |
208 | class PointcloudTranslate(object):
209 | def __init__(self, translation=np.array([0.0, 0.1, 0.0])):
210 | '''
211 | :param translation: pytorch tensor, translation vector(x,y,z)
212 | '''
213 | self.translation = torch.from_numpy(translation)
214 |
215 |
216 | def __call__(self, points):
217 | '''
218 |
219 | :param points: ... , num_of_points, 3
220 | :return: points after trans
221 | '''
222 | translation = self.translation
223 |
224 | if points.is_cuda is True:
225 | translation = translation.cuda()
226 | translation.requires_grad = False
227 |
228 | respoints = points[..., 0:3] + translation
229 | return respoints
230 |
231 |
232 | class PointcloudScale(object):
233 | def __init__(self, scaler):
234 | self.scaler = scaler
235 |
236 | def __call__(self, points):
237 |
238 | respoints = points * self.scaler
239 | return respoints
240 |
241 |
242 | class PointcloudRotate(object):
 243 |     def __init__(self, angle_in_degree=np.pi, axis=np.array([0.0, 1.0, 0.0]), is_cuda=True):  # note: the angle is in radians despite the parameter name
244 | self.axis = axis
245 | self.angle_in_degree = angle_in_degree
246 | self.rotation_matrix_t = angle_axis(self.angle_in_degree, self.axis).t()
247 |
248 | def __call__(self, points):
249 |
250 | '''
251 | :param points: ... , num_of_points, 3
252 | :return: points after rotate
253 | '''
254 | rotation_matrix_t = self.rotation_matrix_t.clone()
255 | if points.is_cuda is True:
256 | rotation_matrix_t = rotation_matrix_t.cuda()
257 | tpoints = torch.matmul(points, rotation_matrix_t)
258 |
259 | return tpoints
260 |
261 |
262 | def GenPointcloudRandomTransformFunction(max_rot_angle=2*np.pi):
263 | scale_lo = 0.8
264 | scale_hi = 1.25
265 | scaler = np.random.uniform(scale_lo, scale_hi)
266 | scale_func = PointcloudScale(scaler)
267 |
268 | rotation_angle = np.random.uniform() * max_rot_angle
269 | rotation_axis = np.array([np.random.uniform(), np.random.uniform(), np.random.uniform()])
270 | rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)
271 | rotation_func = PointcloudRotate(rotation_angle, rotation_axis)
272 |
273 | trans_func = transforms.Compose([scale_func, rotation_func])
274 |
275 | return trans_func
276 |
277 |
278 | def AddTransformsToBatchPoints(points, num_of_trans, max_rot_angle=2*np.pi):
279 | '''
280 |
281 | :param points:bn, num_of_points, 3
282 | :return: points: (num_of_trans, bn, num_of_points, 3)
283 | transform
284 | '''
285 | transfunc_list = []
286 | res_points = None
287 | for trans_i in range(0, num_of_trans):
288 | transf = GenPointcloudRandomTransformFunction(max_rot_angle)
289 | transfunc_list.append(transf)
290 | tpoints = transf(points)
291 | if res_points is None:
292 | res_points = tpoints[None, :, :, :]
293 | else:
294 | res_points = torch.cat((res_points, tpoints[None, :, :, :]), dim=0)
295 |
296 | return res_points, transfunc_list
297 |
298 |
299 | class PointcloudRotateFuns(object):
300 | def __init__(self, rot_mats):
301 | '''
302 | :param rot_mats: bn, 3, 3
303 | '''
304 |
305 | self.rot_mats = rot_mats
306 |
307 | def __call__(self, points):
308 | '''
309 |
310 | :param points: bn, n , 3
311 | :return:
312 | '''
313 | if points.is_cuda is True:
314 | tmp_rot = self.rot_mats.cuda()
315 | else:
316 | tmp_rot = self.rot_mats
 317 |         transed_points = torch.transpose(torch.matmul(tmp_rot, torch.transpose(points, 1, 2)), 1, 2)
 318 |         return transed_points
319 |
320 |
321 | def AddPCATransformsToBatchPoints(points, num_of_trans):
322 | trans_points_all = None
323 | rot_mats_all = None
324 |
325 | transfunc_list = []
326 |
327 | for bi in range(points.shape[0]):
328 |
329 | np_points = points[bi].cpu().numpy()
330 | pca_axis_raw = point_cloud_utils.compute_pca(np_points)
331 | rot_mats = None
332 | trans_points = None
333 | for ti in range(num_of_trans):
334 | tmp_idx = np.array([0, 1, 2])
335 | pca_axis = pca_axis_raw[tmp_idx, :]
336 | tmp_sign = np.random.randint(2, size=2)
337 | tmp_sign[tmp_sign == 0] = -1
338 | pca_axis[0, :] = pca_axis[0, :] * tmp_sign[0]
339 | pca_axis[1, :] = pca_axis[1, :] * tmp_sign[1]
340 | pca_axis[2, :] = np.cross(pca_axis[0, :], pca_axis[1, :])
341 |
342 | rot_mat = torch.from_numpy(pca_axis)
343 | if points.is_cuda:
344 | rot_mat = rot_mat.cuda()
345 |
346 | if rot_mats is None:
347 | rot_mats = rot_mat[None, :, :]
348 | else:
349 | rot_mats = torch.cat((rot_mats, rot_mat[None, :, :]), dim=0)
350 |
351 | tmp_trans_points = torch.transpose(torch.matmul(rot_mat, torch.transpose(points[bi], 0, 1)), 0, 1)
352 |
353 | if trans_points is None:
354 | trans_points = tmp_trans_points[None, :, :]
355 | else:
356 | trans_points = torch.cat((trans_points, tmp_trans_points[None, :, :]), dim=0)
357 |
358 | if trans_points_all is None:
359 | trans_points_all = trans_points[:, None, :, :]
360 | else:
361 | trans_points_all = torch.cat((trans_points_all, trans_points[:, None, :, :]), dim=1)
362 |
363 | if rot_mats_all is None:
364 | rot_mats_all = rot_mats[:, None, :, :]
365 | else:
366 | rot_mats_all = torch.cat((rot_mats_all, rot_mats[:, None, :, :]), dim=1)
367 |
368 | for ti in range(num_of_trans):
369 | trans_func = PointcloudRotateFuns(rot_mats_all[ti, :, :, :])
370 | transfunc_list.append(trans_func)
371 |
372 | return trans_points_all, rot_mats_all, transfunc_list
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
--------------------------------------------------------------------------------
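A minimal sketch of the batch-augmentation entry point above (each transform is a random scale composed with a random rotation; the import path assumes the repository root is on `PYTHONPATH`):

```python
import torch
from data_utils.data_utils import AddTransformsToBatchPoints

pts = torch.rand(4, 1024, 3)  # (bn, num_of_points, 3)
aug, funcs = AddTransformsToBatchPoints(pts, num_of_trans=2)
print(aug.shape)  # torch.Size([2, 4, 1024, 3]): (num_of_trans, bn, n, 3)
```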
/data_utils/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Author: An Tao
5 | @Contact: ta19@mails.tsinghua.edu.cn
6 | @File: dataset.py
7 | @Time: 2020/1/2 10:26 AM
8 | """
9 |
10 | import os
11 | import torch
12 | import json
13 | import h5py
14 | from glob import glob
15 | import numpy as np
16 | import torch.utils.data as data
17 |
18 | shapenetpart_cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4,
19 | 'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9,
20 | 'motorbike': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15}
21 | shapenetpart_seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
22 | shapenetpart_seg_start_index = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]
23 |
24 |
25 | def translate_pointcloud(pointcloud):
26 | xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
27 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
28 |
29 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
30 | return translated_pointcloud
31 |
32 |
33 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
34 | N, C = pointcloud.shape
35 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
36 | return pointcloud
37 |
38 |
39 | def rotate_pointcloud(pointcloud):
40 | theta = np.pi * 2 * np.random.rand()
41 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
42 | pointcloud[:, [0, 2]] = pointcloud[:, [0, 2]].dot(rotation_matrix) # random rotation (x,z)
43 | return pointcloud
44 |
45 |
46 | class Dataset(data.Dataset):
47 | def __init__(self, root, dataset_name='modelnet40', class_choice=None,
48 | num_points=2048, split='train', load_name=True, load_file=True,
49 | segmentation=False, random_rotate=False, random_jitter=False,
50 | random_translate=False):
51 |
52 | assert dataset_name.lower() in ['shapenetcorev2', 'shapenetpart',
53 | 'modelnet10', 'modelnet40', 'shapenetpartpart']
54 | assert num_points <= 2048
55 |
56 | if dataset_name in ['shapenetcorev2', 'shapenetpart', 'shapenetpartpart']:
57 | assert split.lower() in ['train', 'test', 'val', 'trainval', 'all']
58 | else:
59 | assert split.lower() in ['train', 'test', 'all']
60 |
61 | if dataset_name not in ['shapenetpart'] and segmentation == True:
 62 |             raise AssertionError('segmentation is only supported for the shapenetpart dataset')
63 |
64 | self.root = os.path.join(root, dataset_name + '_' + '*hdf5_2048')
65 | self.dataset_name = dataset_name
66 | self.class_choice = class_choice
67 | self.num_points = num_points
68 | self.split = split
69 | self.load_name = load_name
70 | self.load_file = load_file
71 | self.segmentation = segmentation
72 | self.random_rotate = random_rotate
73 | self.random_jitter = random_jitter
74 | self.random_translate = random_translate
75 |
76 | self.path_h5py_all = []
77 | self.path_name_all = []
78 | self.path_file_all = []
79 |
80 | if self.split in ['train', 'trainval', 'all']:
81 | self.get_path('train')
82 | if self.dataset_name in ['shapenetcorev2', 'shapenetpart', 'shapenetpartpart']:
83 | if self.split in ['val', 'trainval', 'all']:
84 | self.get_path('val')
85 | if self.split in ['test', 'all']:
86 | self.get_path('test')
87 |
88 | self.path_h5py_all.sort()
89 | data, label, seg = self.load_h5py(self.path_h5py_all)
90 |
91 | if self.load_name or self.class_choice != None:
92 | self.path_name_all.sort()
93 | self.name = np.array(self.load_json(self.path_name_all)) # load label name
94 |
95 | if self.load_file:
96 | self.path_file_all.sort()
97 | self.file = np.array(self.load_json(self.path_file_all)) # load file name
98 |
99 | self.data = np.concatenate(data, axis=0)
100 | self.label = np.concatenate(label, axis=0)
101 | if self.segmentation:
102 | self.seg = np.concatenate(seg, axis=0)
103 |
104 | if self.class_choice != None:
105 | indices = (self.name == class_choice)
106 | self.data = self.data[indices]
107 | self.label = self.label[indices]
108 | self.name = self.name[indices]
109 | if self.segmentation:
110 | self.seg = self.seg[indices]
111 | id_choice = shapenetpart_cat2id[class_choice]
112 | self.seg_num_all = shapenetpart_seg_num[id_choice]
113 | self.seg_start_index = shapenetpart_seg_start_index[id_choice]
114 | if self.load_file:
115 | self.file = self.file[indices]
116 | elif self.segmentation:
117 | self.seg_num_all = 50
118 | self.seg_start_index = 0
119 |
120 | def get_path(self, type):
121 | path_h5py = os.path.join(self.root, '*%s*.h5' % type)
122 | self.path_h5py_all += glob(path_h5py)
123 | if self.load_name:
124 | path_json = os.path.join(self.root, '%s*_id2name.json' % type)
125 | self.path_name_all += glob(path_json)
126 | if self.load_file:
127 | path_json = os.path.join(self.root, '%s*_id2file.json' % type)
128 | self.path_file_all += glob(path_json)
129 | return
130 |
131 | def load_h5py(self, path):
132 | all_data = []
133 | all_label = []
134 | all_seg = []
135 | for h5_name in path:
 136 |             f = h5py.File(h5_name, 'r')
137 | data = f['data'][:].astype('float32')
138 | label = f['label'][:].astype('int64')
139 | if self.segmentation:
140 | seg = f['seg'][:].astype('int64')
141 | f.close()
142 | all_data.append(data)
143 | all_label.append(label)
144 | if self.segmentation:
145 | all_seg.append(seg)
146 | return all_data, all_label, all_seg
147 |
148 | def load_json(self, path):
149 | all_data = []
150 | for json_name in path:
 151 |             with open(json_name, 'r') as j:
 152 |                 data = json.load(j)
153 | all_data += data
154 | return all_data
155 |
156 | def __getitem__(self, item):
157 | point_set = self.data[item][:self.num_points]
158 | label = self.label[item]
159 | if self.load_name:
160 | name = self.name[item] # get label name
161 | if self.load_file:
162 | file = self.file[item] # get file name
163 |
164 | if self.random_rotate:
165 | point_set = rotate_pointcloud(point_set)
166 | if self.random_jitter:
167 | point_set = jitter_pointcloud(point_set)
168 | if self.random_translate:
169 | point_set = translate_pointcloud(point_set)
170 |
171 | # convert numpy array to pytorch Tensor
172 | point_set = torch.from_numpy(point_set)
173 | label = torch.from_numpy(np.array([label]).astype(np.int64))
174 | label = label.squeeze(0)
175 |
176 | if self.segmentation:
177 | seg = self.seg[item][:self.num_points]
178 | seg = torch.from_numpy(seg)
179 | return point_set, label, seg, name, file
180 | else:
181 | return point_set, label, name, file
182 |
183 | def __len__(self):
184 | return self.data.shape[0]
185 |
186 |
187 | if __name__ == '__main__':
188 | # root = os.getcwd()
189 |
190 | # choose dataset name from 'shapenetcorev2', 'shapenetpart', 'modelnet40' and 'modelnet10'
191 | dataset_name = 'shapenetpart'
192 |
193 | # choose split type from 'train', 'test', 'all', 'trainval' and 'val'
194 | # only shapenetcorev2 and shapenetpart dataset support 'trainval' and 'val'
195 | split = 'train'
196 | segmentation = True
197 | d = Dataset(root='../../', dataset_name=dataset_name, class_choice='motorbike', num_points=2048, split=split, segmentation=segmentation)
198 | print("datasize:", d.__len__())
199 |
200 | item = 1
201 | if segmentation:
202 | ps, lb, seg, n, f = d[item]
 203 |         print(ps.size(), ps.type(), seg.size(), seg.type(), lb.size(), lb.type(), n, f)
204 | print(seg)
205 | else:
206 | ps, lb, n, f = d[item]
207 | print(ps.size(), ps.type(), lb.size(), lb.type(), n, f)
208 |
--------------------------------------------------------------------------------
/data_utils/keypointnet_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from .ModelNetDataLoader import pc_normalize
3 | import json
4 | import numpy as np
5 |
6 |
7 | def naive_read_pcd(path):
 8 |     with open(path, 'r') as f: lines = f.readlines()
9 | idx = -1
10 | for i, line in enumerate(lines):
11 | if line.startswith('DATA ascii'):
12 | idx = i + 1
13 | break
14 | lines = lines[idx:]
15 | lines = [line.rstrip().split(' ') for line in lines]
16 | data = np.asarray(lines)
17 | pc = np.array(data[:, :3], dtype=float)
18 | return pc
19 |
20 |
21 | class KeyPointNetDataLoader(data.Dataset):
22 | def __init__(self, num_points, json_path, pcd_path, split='train'):
23 | super().__init__()
24 | self.num_points = num_points
25 | self.data_path = json.load(open(json_path))
26 | self.pcd_path = pcd_path
27 | if split == 'train':
28 | self.data_path = self.data_path[:-200]
29 | else:
30 | self.data_path = self.data_path[-200:]
31 |
32 | def __getitem__(self, idx):
33 | # res = {}
34 | class_id = self.data_path[idx]['class_id']
35 | model_id = self.data_path[idx]['model_id']
36 | points = naive_read_pcd(r'{}/{}/{}.pcd'.format(self.pcd_path, class_id, model_id)).astype(np.float32)
37 | points = pc_normalize(points)
38 | ground_truths = np.array(
39 | [points[point['pcd_info']['point_index']] for point in self.data_path[idx]['keypoints']])
40 | ground_truths_num = ground_truths.shape[0]
41 | ground_truths = np.pad(ground_truths, ((0, 18 - ground_truths_num), (0, 0,)), 'constant',
42 | constant_values=(0, 0))
 43 |         return points[:self.num_points, :], ground_truths, ground_truths_num, ground_truths_num  # note: the keypoint count is returned twice
44 |
45 | def __len__(self):
46 | return len(self.data_path)
47 |
48 |
49 | def paint(points_xyz, origin_xyz):
50 | import matplotlib.pyplot as plt
51 |
52 | x1 = points_xyz[:, 0]
53 | y1 = points_xyz[:, 1]
54 | z1 = points_xyz[:, 2]
55 |
56 | x2 = origin_xyz[:, 0]
57 | y2 = origin_xyz[:, 1]
58 | z2 = origin_xyz[:, 2]
59 |
60 | ax1 = plt.subplot(111, projection='3d')
61 | ax1.scatter(x1, y1, z1, c=COLOR_LIST[:points_xyz.shape[0], :] / 255, s=48)
62 | ax1.scatter(x2, y2, z2, c='#A9A9A9', s=1)
63 | ax1.axis('off')
64 | plt.show()
65 |
66 |
67 | def create_color_list(num):
68 | import random
69 | colors = np.ndarray(shape=(num, 3))
70 | for i in range(0, num):
71 | colors[i, 0] = random.randint(0, 255)
72 | colors[i, 1] = random.randint(0, 255)
73 | colors[i, 2] = random.randint(100, 255)
74 |
75 | colors[0, :] = np.array([0, 0, 0]).astype(int)
76 | colors[1, :] = np.array([146, 61, 10]).astype(int)
77 | colors[2, :] = np.array([102, 97, 0]).astype(int)
78 | colors[3, :] = np.array([255, 0, 0]).astype(int)
79 | colors[4, :] = np.array([113, 0, 17]).astype(int)
80 | colors[5, :] = np.array([255, 127, 39]).astype(int)
81 | colors[6, :] = np.array([255, 242, 0]).astype(int)
82 | colors[7, :] = np.array([0, 255, 0]).astype(int)
83 | colors[8, :] = np.array([0, 0, 255]).astype(int)
84 | colors[9, :] = np.array([15, 77, 33]).astype(int)
85 | colors[10, :] = np.array([163, 73, 164]).astype(int)
86 | colors[11, :] = np.array([255, 174, 201]).astype(int)
87 | colors[12, :] = np.array([255, 220, 14]).astype(int)
88 | colors[13, :] = np.array([181, 230, 29]).astype(int)
89 | colors[14, :] = np.array([153, 217, 234]).astype(int)
90 | colors[15, :] = np.array([112, 146, 190]).astype(int)
91 |
92 | return colors
93 |
94 |
95 | COLOR_LIST = create_color_list(200)
96 |
97 | if __name__ == '__main__':
98 | import torch
99 |
100 | data = KeyPointNetDataLoader(json_path='./keypointnet/annotations/mug.json', pcd_path='../keypointnet/pcds',
 101 |                                  split='val', num_points=1024)
102 |
103 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True, num_workers=10, drop_last=True)
104 | print(len(DataLoader))
 105 |     for point, label, ground_truths_num, _ in DataLoader:
106 | paint(label[0], point[0])
107 | # print(point.shape)
108 | # print(label.shape)
109 | # print(ground_truths_num)
110 |
--------------------------------------------------------------------------------
/data_utils/shapenet_seg_dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import torch.utils.data as data
9 | import os
10 | import ntpath
11 | import pickle
12 | import re
13 |
14 |
15 | def natural_sort_key(string_):
16 | """See http://www.codinghorror.com/blog/archives/001018.html"""
17 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]
18 |
19 |
20 | def sort_strs_by_number(strs):
21 | return sorted(strs, key=natural_sort_key)
22 |
23 |
24 | class ShapenetSegDataloader(data.Dataset):
25 | def __init__(self, data_root, category, preload_data=True, train=True):
26 | super().__init__()
27 | self.category = category
28 | self.data_root = data_root
29 | self.data_path = os.path.join(data_root, category)
30 | self.sub_dirs = [ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()]
31 | #self.sub_dirs = sort_strs_by_number([ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()])
32 | self.data_num = len(self.sub_dirs)
33 | self.train = train
34 |
35 | self.meta_data_list = None
36 | if preload_data:
37 | self.meta_data_list = []
38 | for i in range(self.data_num):
39 | meta_fname = os.path.join(self.data_path, self.sub_dirs[i], 'meta.pkl')
40 | with open(meta_fname, 'rb') as f:
41 | meta_data = pickle.load(f)
42 | self.meta_data_list.append(meta_data)
43 |
44 | def __getitem__(self, idx):
45 | if self.meta_data_list is None:
46 | meta_fname = os.path.join(self.data_path, self.sub_dirs[idx], 'meta.pkl')
 47 |             with open(meta_fname, 'rb') as f:
 48 |                 meta_data = pickle.load(f)
49 | else:
50 | meta_data = self.meta_data_list[idx]
51 |
52 | points = meta_data['points']
53 | gt_points = meta_data['gt_points']
54 | gt_points_labels = meta_data['gt_points_labels'] - 1
55 | meta_data = {'points': points}
56 |
57 | if self.train is not True:
58 | meta_data['gt_points'] = gt_points
59 | meta_data['gt_points_labels'] = gt_points_labels
60 | return meta_data
61 |
62 | def __len__(self):
63 | return self.data_num
--------------------------------------------------------------------------------
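A usage sketch for the loader above (the data root is hypothetical; in train mode only the points are returned, while test mode also exposes the labeled ground truth):

```python
from data_utils.shapenet_seg_dataloader import ShapenetSegDataloader

ds = ShapenetSegDataloader('path/to/shapenet_seg', category='chair', train=False)
sample = ds[0]
print(sample['points'].shape, sample['gt_points_labels'].min())
```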
/images/Pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/images/Pipeline.png
--------------------------------------------------------------------------------
/images/VisualizationResults.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/images/VisualizationResults.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/models/__init__.py
--------------------------------------------------------------------------------
/models/chamfer_distance.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import torch
9 | import torch.nn as nn
10 | from torch_pointnet_utils import query_ball_point, index_points
11 | import torch.nn.functional as F
12 |
13 |
14 | def query_KNN_tensor(points, query_pts, k):
15 | '''
16 |
17 | :param points: bn x n x 3
18 | :param query_pts: bn x m x 3
19 | :param k: num of neighbors
 20 |     :return: bn x m x k sorted_idxs, sorted_squared_dis
21 | '''
22 |
23 | diff = query_pts[:, :, None, :] - points[:, None, :, :]
24 |
25 | squared_dis = torch.sum(diff * diff, dim=3) # bn x m x n
26 | sorted_squared_dis, sorted_idxs = torch.sort(squared_dis, dim=2)
27 | sorted_idxs = sorted_idxs[:, :, :k]
28 | sorted_squared_dis = sorted_squared_dis[:, :, :k]
29 |
30 | return sorted_idxs, sorted_squared_dis
31 |
32 |
33 | def compute_chamfer_distance(p1, p2):
34 | '''
35 | Calculate Chamfer Distance between two point sets
36 | :param p1: size[bn, N, D]
37 | :param p2: size[bn, M, D]
 38 |     :return: per-point minimum squared distances in both directions: dist_min1 [bn, N], dist_min2 [bn, M]
39 | '''
40 |
41 | diff = p1[:, :, None, :] - p2[:, None, :, :]
42 | dist = torch.sum(diff * diff, dim=3)
43 | dist1 = dist
44 | dist2 = torch.transpose(dist, 1, 2)
45 |
46 | dist_min1, _ = torch.min(dist1, dim=2)
47 | dist_min2, _ = torch.min(dist2, dim=2)
48 |
49 | return dist_min1, dist_min2
50 |
51 |
90 | def compute_vector_distance(p1, p2=None):
91 |     '''
92 |     Cosine-alignment penalty over the pairwise difference vectors of p1, restricted
93 |     to point pairs selected as edges by compute_edge_distance(p1, p2).
94 |     :param p1: size[bn, N, D] predicted keypoints; :param p2: size[bn, M, D] reference cloud
95 |     :return: mean absolute cosine similarity over the selected vector pairs
96 |     '''
97 |
98 | vectors = p1[:, :, None, :] - p1[:, None, :, :]
99 |
100 | vectors = torch.flatten(vectors, start_dim=1, end_dim=2)
101 | vectors_dot = torch.sum(vectors[:, :, None, :] * vectors[:, None, :, :], dim=3)
102 | vectors_dot = torch.abs(vectors_dot)
103 |
104 | edge_dict = compute_edge_distance(p1, p2, maxdis=0.001)
105 |     idx = torch.zeros_like(vectors_dot)  # edge mask, on the same device as the dot matrix
106 |     idx_num = 1  # starts at 1 so the division below is safe when no edge is found
107 | for key, values in edge_dict.items():
108 | for value in values:
109 | idx[key, value[0], value[1]] = 1
110 | idx_num += 1
111 | vectors_abs = torch.sqrt(torch.sum(vectors * vectors + 1e-5, dim=2, keepdim=True))
112 | vectors_dot = vectors_dot * idx
113 | # vectors_abs = torch.sum(vectors * vectors, dim=2, keepdim=True)
114 | vectors_abs_dot = torch.sum(vectors_abs[:, :, None, :] * vectors_abs[:, None, :, :], dim=3)
115 | vectors_abs_dot[vectors_abs_dot == 0] = 1
116 | vectors_cos = vectors_dot / vectors_abs_dot
117 | return torch.sum(vectors_cos) / idx_num
118 |
119 |
120 | def compute_edge_distance(p1, p2, maxdis=0.001):
121 |     '''
122 |     For every keypoint pair (i, j) in p1, sample `count` points along the segment between
123 |     them and keep the pair as an edge when the mean Chamfer distance from those samples
124 |     to p2 is below maxdis (the segment lies on the shape); returns {batch: [[i, j], ...]}.
125 |     '''
126 |     point_num = p1.shape[1]
127 |     edge_list, batch_dict = [], {}
128 |     device = p1.device
129 |     for i in range(point_num):
130 |         for j in range(i):
131 |             start, end = p1[:, i, :], p1[:, j, :]
132 |             count = 5  # interpolation samples per segment
133 |             f_interp = torch.linspace(0.0, 1.0, count).unsqueeze(0).unsqueeze(-1).to(device)
134 |             b_interp = 1.0 - f_interp
135 |             K = start.unsqueeze(-2) * f_interp + end.unsqueeze(-2) * b_interp  # [bn, count, D]
136 |             dist1, _ = compute_chamfer_distance(K, p2)
137 |             cdis = torch.mean(dist1, dim=1)  # mean sample-to-shape distance per batch item
138 |             for x in torch.where(cdis < maxdis)[0].cpu().numpy():
139 |                 edge_list.append([x, i, j])
139 |
140 | for x in edge_list:
141 | if x[0] in batch_dict:
142 | batch_dict[x[0]].append(x[1:])
143 | else:
144 | batch_dict[x[0]] = [x[1:]]
145 |
146 | return batch_dict
147 |
148 |
149 | def compute_offset_distance(p1):
150 | return torch.sum(p1 * p1)
151 |
152 |
153 | def compute_vector_similarity_distance(p1, similarity_map, threshold):
154 |     '''
155 |     Cosine-alignment penalty over pairwise difference vectors of p1, weighted by
156 |     similarity_map [bn, N, N] and restricted to pairs with similarity >= threshold.
157 |     :param p1: size[bn, N, D] keypoints
158 |     :return: mean absolute cosine similarity over the selected vector pairs
159 |     '''
160 |
161 | B, N, D = p1.shape
162 | device = p1.device
163 | vectors = p1[:, :, None, :] - p1[:, None, :, :]
164 | similarity_map[similarity_map < threshold] = 0
165 | vectors = vectors * similarity_map.view(B, N, N, 1).to(device)
166 | vectors = torch.flatten(vectors, start_dim=1, end_dim=2)
167 | vectors_dot = torch.sum(vectors[:, :, None, :] * vectors[:, None, :, :], dim=3)
168 | idx_num = torch.sum(vectors_dot != 0)
169 | vectors_dot = torch.abs(vectors_dot)
170 | vectors_abs = torch.sqrt(torch.sum(vectors * vectors + 1e-9, dim=2, keepdim=True))
171 | vectors_abs_dot = torch.sum(vectors_abs[:, :, None, :] * vectors_abs[:, None, :, :], dim=3)
172 | vectors_abs_dot[vectors_abs_dot == 0] = 1
173 | vectors_cos = vectors_dot / vectors_abs_dot
174 | return torch.sum(vectors_cos) / idx_num
175 |
176 |
177 | def compute_end_distance(p1, p2, radius=0.1, nsample=16):
178 |     '''
179 |     For each point of p1, gather up to `nsample` points of p2 within `radius` and
180 |     penalize the norm of the summed offsets (pulls p1 toward local centroids of p2).
181 |     :param p1: size[bn, N, D]; :param p2: size[bn, M, D]
182 |     :return: mean offset norm over points and batches
183 |     '''
184 | B, S, C = p1.shape
185 | idx = query_ball_point(radius, nsample, p2, p1)
186 | p2 = index_points(p2, idx) # [B, npoint, nsample, C]
187 | p2_norm = p2 - p1.view(B, S, 1, C)
188 | # p2_norm = p2_norm / torch.norm(p2_norm, p=2, dim=3, keepdim=True)
189 | dist = torch.sum(p2_norm, dim=2)
190 | dist = torch.sqrt(torch.sum(dist * dist, dim=2))
191 | return torch.mean(dist)
192 |
193 |
194 | class ComputeCDLoss(nn.Module):
195 | def __init__(self):
196 | super(ComputeCDLoss, self).__init__()
197 |
198 | def forward(self, recon_points, gt_points):
199 | dist1, dist2 = compute_chamfer_distance(recon_points, gt_points)
200 | loss = (torch.sum(dist1) + torch.sum(dist2)) / ((recon_points.shape[0]) * gt_points.shape[1]) * 1024
201 | # print(torch.sum(dist1), torch.sum(dist2))
202 | # loss = (torch.mean(dist1) + torch.mean(dist2) * 32)
203 |
204 | return loss
205 |
206 |
207 | class ComputeVecLoss(nn.Module):
208 | def __init__(self):
209 | super(ComputeVecLoss, self).__init__()
210 |
211 | def forward(self, recon_points, gt_points):
212 | dist3 = compute_vector_distance(recon_points, gt_points)
213 | # loss_align = torch.sum(dist3) / (recon_points.shape[0])
214 | return dist3
215 |
216 |
217 | class ComputeEdgeLoss(nn.Module):
218 | def __init__(self):
219 | super(ComputeEdgeLoss, self).__init__()
220 |
221 | def forward(self, recon_points, gt_points):
222 | dist3 = compute_edge_distance(recon_points, gt_points)
223 | # loss_align = torch.sum(dist3) / (recon_points.shape[0])
224 | return dist3
225 |
226 |
227 | class ComputeOffsetLoss(nn.Module):
228 | def __init__(self):
229 | super(ComputeOffsetLoss, self).__init__()
230 |
231 | def forward(self, fps_points_offset):
232 | return compute_offset_distance(fps_points_offset) / (fps_points_offset.shape[0])
233 |
234 |
235 | class ComputeVecSimilarityLoss(nn.Module):
236 | def __init__(self):
237 | super(ComputeVecSimilarityLoss, self).__init__()
238 |
239 | def forward(self, gt_points, cos_similarity, threshold):
240 | return compute_vector_similarity_distance(gt_points, cos_similarity, threshold)
241 |
242 |
243 | class ComputeEndLoss(nn.Module):
244 | def __init__(self):
245 | super(ComputeEndLoss, self).__init__()
246 |
247 | def forward(self, recon_points, gt_points):
248 | dist1 = compute_end_distance(recon_points, gt_points, radius=0.1, nsample=16)
249 | loss = dist1 / recon_points.shape[1] * 24
250 |
251 | return loss
252 |
253 |
254 | if __name__ == '__main__':
255 | p1 = torch.rand((2, 16, 3))
256 | p2 = torch.rand((2, 128, 3))
257 | p3 = torch.rand(2, 16, 16)
258 | a = ComputeVecSimilarityLoss()
259 |
260 | print(a(p1, p3, 0.95))
261 |
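262 |     # Minimal usage sketch for the losses above (assumes CPU tensors are fine):
263 |     cd_loss = ComputeCDLoss()
264 |     print(cd_loss(p1, p2))  # scalar Chamfer loss between the two random sets
265 |     ids, sq_dis = query_KNN_tensor(p2, p1, k=4)
266 |     print(ids.shape, sq_dis.shape)  # both torch.Size([2, 16, 4])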
--------------------------------------------------------------------------------
/models/model_weightchamfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_pointnet_utils import PointNetSetAbstractionMsg
4 | from models import chamfer_distance
5 |
6 |
7 | class ComputeLoss3d(nn.Module):
8 | def __init__(self):
9 | super(ComputeLoss3d, self).__init__()
10 |
11 | def compute_chamfer_distance(self, p1, p2):
12 | '''
13 | Calculate Chamfer Distance between two point sets
14 | :param p1: size[bn, N, D]
15 | :param p2: size[bn, M, D]
16 |         :return: summed Chamfer Distance, averaged over the batch
17 | '''
18 |
19 | diff = p1[:, :, None, :] - p2[:, None, :, :]
20 | dist = torch.sum(diff * diff, dim=3)
21 | dist_min1, _ = torch.min(dist, dim=2)
22 | dist_min2, _ = torch.min(dist, dim=1)
23 |
24 | return (torch.sum(dist_min1) + torch.sum(dist_min2)) / (p1.shape[0])
25 |
26 | def forward(self, gt_points, structure_points, origin_points=None):
27 | if origin_points is None:
28 | return self.compute_chamfer_distance(gt_points, structure_points)
29 | else:
30 | return self.compute_chamfer_distance(gt_points, structure_points) + self.compute_chamfer_distance(origin_points, structure_points)
31 |
32 |
33 | class VecLoss(nn.Module):
34 | def __init__(self):
35 | super(VecLoss, self).__init__()
36 | self.vec_loss_fun = chamfer_distance.ComputeVecSimilarityLoss()
37 |
38 | def forward(self, structure_points, similarity_map, threshold=0.95):
39 | structure_points = structure_points.cuda()
40 | self.vec_loss = self.vec_loss_fun(structure_points, similarity_map, threshold) * 10
41 | return self.vec_loss
42 |
43 |
44 | class WeightedChamferLoss(nn.Module):
45 | def __init__(self):
46 | super(WeightedChamferLoss, self).__init__()
47 |
48 | def compute_chamfer_distance(self, p1, p2):
49 | '''
50 | Calculate Chamfer Distance between two point sets
51 | :param p1: size[bn, N, D]
52 | :param p2: size[bn, M, D]
53 |         :return: mean Chamfer Distance over both directions
54 | '''
55 |
56 | diff = p1[:, :, None, :] - p2[:, None, :, :]
57 | dist = torch.sum(diff * diff, dim=3)
58 | dist_min1, _ = torch.min(dist, dim=2)
59 | dist_min2, _ = torch.min(dist, dim=1)
60 |
61 | return (torch.mean(dist_min1) + torch.mean(dist_min2))
62 |
63 | def compute_end_distance(self, fps_points, structure_points, weight_map):
64 |         '''
65 |         Chamfer Distance between fps_points and structure_points, with the
66 |         fps-to-structure direction re-weighted by sigmoid(sum of weight_map) + 1.
67 |         :param fps_points: size[bn, N, D]; :param structure_points: size[bn, M, D]
68 |         :param weight_map: size[bn, M, N] structure-point probability map
69 |         '''
70 | # weight_map = torch.sigmoid(weight_map)
71 | weight_map = torch.sum(weight_map, dim=1)
72 | weight_map = torch.sigmoid(weight_map) + 1
73 | diff = structure_points[:, :, None, :] - fps_points[:, None, :, :]
74 | dist = torch.sum(diff * diff, dim=3)
75 | dist_min1, _ = torch.min(dist, dim=2)
76 | dist_min2, _ = torch.min(dist, dim=1)
77 | dist_min2 = dist_min2 * weight_map
78 | return (torch.mean(dist_min1) + torch.mean(dist_min2))
79 |
80 | def forward(self, fps_points, structure_points, weight_map, origin_points=None):
81 | if origin_points is None:
82 | return self.compute_end_distance(fps_points, structure_points, weight_map)
83 | else:
84 | return self.compute_end_distance(fps_points, structure_points, weight_map) + self.compute_chamfer_distance(origin_points, structure_points)
85 |
86 |
87 | class Conv1dProbLayer(nn.Module):
88 | def __init__(self, in_channels, out_channels, out=False, kernel_size=1, dropout=0.2, normalize=False):
89 | super(Conv1dProbLayer, self).__init__()
90 | self.out = out
91 | self.dropout_conv_bn_layer = nn.Sequential(
92 | nn.Dropout(dropout),
93 | nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size),
94 | nn.BatchNorm1d(num_features=out_channels),
95 | )
96 | self.relu = nn.ReLU()
97 | self.softmax = nn.Softmax(dim=2)
98 | self.normalize = normalize
99 | self.normalize_layer = Normalize(dim=1)
100 |
101 | def forward(self, x):
102 | x = self.dropout_conv_bn_layer(x)
103 | if self.normalize:
104 | x = self.normalize_layer(x)
105 | if self.out:
106 | x = self.softmax(x)
107 | else:
108 | x = self.relu(x)
109 | return x
110 |
111 |
112 | class RefineNet(nn.Module):
113 | def __init__(self, num_structure_points, in_channel=128 + 256 + 256, out=True, normalize=False):
114 | super(RefineNet, self).__init__()
115 |
116 |         # step the channel count from in_channel toward num_structure_points, one halving/doubling Conv1dProbLayer per block
117 |         conv1d_stpts_prob_modules = []
117 | if num_structure_points <= in_channel:
118 | conv1d_stpts_prob_modules.append(
119 | Conv1dProbLayer(in_channels=in_channel, out_channels=512, kernel_size=1, out=False))
120 | in_channels = 512
121 | while in_channels >= num_structure_points * 2:
122 | out_channels = int(in_channels / 2)
123 | conv1d_stpts_prob_modules.append(
124 | Conv1dProbLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, out=False))
125 | in_channels = out_channels
126 | conv1d_stpts_prob_modules.append(
127 | Conv1dProbLayer(in_channels=in_channels, out_channels=num_structure_points, kernel_size=1, out=out,
128 | normalize=normalize))
129 | else:
130 | conv1d_stpts_prob_modules.append(
131 | Conv1dProbLayer(in_channels=in_channel, out_channels=1024, kernel_size=1, out=False))
132 | in_channels = 1024
133 | while in_channels <= num_structure_points / 2:
134 | out_channels = int(in_channels * 2)
135 | conv1d_stpts_prob_modules.append(
136 | Conv1dProbLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, out=False))
137 | in_channels = out_channels
138 | conv1d_stpts_prob_modules.append(
139 | Conv1dProbLayer(in_channels=in_channels, out_channels=num_structure_points, kernel_size=1, out=out,
140 | normalize=normalize))
141 |
142 | self.conv1d_stpts_prob = nn.Sequential(*conv1d_stpts_prob_modules)
143 |
144 | def forward(self, features):
145 | return self.conv1d_stpts_prob(features)
146 |
147 |
150 | class FeatureMergeBlock(nn.Module):
151 | def __init__(self, num_structure_points, boom_rate=2):
152 | super(FeatureMergeBlock, self).__init__()
153 | self.relu = nn.ReLU(inplace=True)
154 | self.merge_layer = nn.Sequential(
155 | nn.Conv2d(num_structure_points, boom_rate * num_structure_points, kernel_size=1),
156 | nn.BatchNorm2d(boom_rate * num_structure_points),
157 | nn.ReLU(inplace=True),
158 | nn.Conv2d(boom_rate * num_structure_points, num_structure_points, kernel_size=1),
159 | nn.BatchNorm2d(num_structure_points),
160 | )
161 |
162 | def forward(self, features):
163 | return self.relu(self.merge_layer(features) + features)
164 | # return self.merge_layer(features) + features
165 |
166 |
167 | class Normalize(nn.Module):
168 | def __init__(self, dim):
169 | super().__init__()
170 | self.dim = dim
171 |
172 | def forward(self, x):
173 | norm = torch.norm(x, p=2, dim=self.dim, keepdim=True)
174 | return x / norm
175 |
176 |
177 | class Pointnet2StructurePointNet(nn.Module):
178 |
179 | def __init__(self, num_structure_points, input_channels=3, multi_distribution_num=1, offset=False,
180 | merge_block_num=1):
181 | super(Pointnet2StructurePointNet, self).__init__()
182 | self.point_dim = 3
183 | self.num_structure_points = num_structure_points
184 | self.offset = offset
185 | self.input_channels = input_channels
187 |
188 | self.sa1 = PointNetSetAbstractionMsg(npoint=512, radius_list=[0.1, 0.2, 0.4], nsample_list=[16, 32, 128],
189 | in_channel=0, mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
190 | self.sa2 = PointNetSetAbstractionMsg(npoint=128, radius_list=[0.2, 0.4, 0.8], nsample_list=[32, 64, 128],
191 | in_channel=64 + 128 + 128,
192 | mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
193 |
194 | self.multi_distri_layers = nn.ModuleList()
195 | for i in range(multi_distribution_num):
196 | self.multi_distri_layers.append(RefineNet(num_structure_points, in_channel=128 + 256 + 256, out=False))
197 |
198 | feature_merge = []
199 | for _ in range(merge_block_num):
200 | feature_merge.append(FeatureMergeBlock(num_structure_points=num_structure_points))
201 | self.feature_merge = nn.Sequential(*feature_merge)
202 | self.softmax = nn.Softmax(dim=2)
203 |
204 | def forward(self, pointcloud):
205 |         '''
206 |         :param pointcloud: input point cloud with shape (bn, num_of_pts, 3)
207 |         :return: structure points (bn, num_structure_points, 3), the abstracted xyz,
208 |                  cos_similarity (currently always None), and the probability map
209 |         '''
211 | B = pointcloud.shape[0]
212 | if pointcloud.shape[2] == 3:
213 | pointcloud = pointcloud.permute(0, 2, 1)
214 |
215 | if pointcloud.shape[1] > 3:
216 | xyz = pointcloud[:, :3, :]
217 | features = pointcloud[:, 3:, :]
218 | else:
219 | xyz = pointcloud
220 | features = None
221 | xyz, features = self.sa1(xyz, features)
222 | xyz, features = self.sa2(xyz, features)
223 |
224 | stpts_prob_map = []
225 | for i in range(len(self.multi_distri_layers)):
226 | stpts_prob_map.append(self.multi_distri_layers[i](features))
227 |
228 | stpts_prob_map = torch.stack(stpts_prob_map, dim=3)
229 | stpts_prob_map = torch.max(self.feature_merge(stpts_prob_map), dim=3)[0]
230 | stpts_prob_map = self.softmax(stpts_prob_map)
231 |
232 |         # probability-weighted average of xyz: (B, S, N) x (B, N, 3) -> (B, S, 3)
233 | if not xyz.shape[2] == 3:
234 | xyz = xyz.permute(0, 2, 1)
235 |
236 | structure_points = torch.sum(stpts_prob_map[:, :, :, None] * xyz[:, None, :, :], dim=2)
237 | # features = features * torch.sum(stpts_prob_map, dim=1, keepdim=True)
238 | # weighted_features = torch.sum(stpts_prob_map[:, None, :, :] * features[:, :, None, :], dim=3)
239 | # cos_similarity = SimilarityPoints(weighted_features)
240 | cos_similarity = None
241 | return structure_points, xyz, cos_similarity, stpts_prob_map
242 |
243 |
244 | if __name__ == '__main__':
245 | data = torch.rand(2, 3, 1024)
246 | print("===> testing pointSPN ...")
247 | model = Pointnet2StructurePointNet(num_structure_points=16, offset=False)
248 | data = data.cuda()
249 | model = model.cuda()
250 | structure_points, xyz, cos_similarity, stpts_prob_map = model(data)
251 | loss = ComputeLoss3d()
252 | loss = loss.cuda()
253 | print(loss(xyz, structure_points))
254 |
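255 |     # Minimal usage sketch for WeightedChamferLoss (assumed shapes: the weight
256 |     # map matches the stpts_prob_map returned by the model, [bn, M, N]):
257 |     w_loss = WeightedChamferLoss()
258 |     fps = torch.rand(2, 128, 3)
259 |     stp = torch.rand(2, 16, 3)
260 |     wmap = torch.rand(2, 16, 128)
261 |     print(w_loss(fps, stp, wmap))  # scalar weighted Chamfer loss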
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.utils.data
7 | import torch.nn.functional as F
8 |
9 |
10 | class STN3d(nn.Module):
11 | def __init__(self, channel):
12 | super(STN3d, self).__init__()
13 | self.conv1 = torch.nn.Conv1d(channel, 64, 1)
14 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
15 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
16 | self.fc1 = nn.Linear(1024, 512)
17 | self.fc2 = nn.Linear(512, 256)
18 | self.fc3 = nn.Linear(256, 9)
19 | self.relu = nn.ReLU()
20 |
21 | self.bn1 = nn.BatchNorm1d(64)
22 | self.bn2 = nn.BatchNorm1d(128)
23 | self.bn3 = nn.BatchNorm1d(1024)
24 | self.bn4 = nn.BatchNorm1d(512)
25 | self.bn5 = nn.BatchNorm1d(256)
26 |
27 | def forward(self, x):
28 | batchsize = x.size()[0]
29 | x = F.relu(self.bn1(self.conv1(x)))
30 | x = F.relu(self.bn2(self.conv2(x)))
31 | x = F.relu(self.bn3(self.conv3(x)))
32 | x = torch.max(x, 2, keepdim=True)[0]
33 | x = x.view(-1, 1024)
34 |
35 | x = F.relu(self.bn4(self.fc1(x)))
36 | x = F.relu(self.bn5(self.fc2(x)))
37 | x = self.fc3(x)
38 |
39 |         # identity 3x3 transform, created directly on the input's device
40 |         iden = torch.eye(3, device=x.device).flatten().view(1, 9).repeat(batchsize, 1)
43 | x = x + iden
44 | x = x.view(-1, 3, 3)
45 | return x
46 |
47 |
48 | class STNkd(nn.Module):
49 | def __init__(self, k=64):
50 | super(STNkd, self).__init__()
51 | self.conv1 = torch.nn.Conv1d(k, 64, 1)
52 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
53 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
54 | self.fc1 = nn.Linear(1024, 512)
55 | self.fc2 = nn.Linear(512, 256)
56 | self.fc3 = nn.Linear(256, k * k)
57 | self.relu = nn.ReLU()
58 |
59 | self.bn1 = nn.BatchNorm1d(64)
60 | self.bn2 = nn.BatchNorm1d(128)
61 | self.bn3 = nn.BatchNorm1d(1024)
62 | self.bn4 = nn.BatchNorm1d(512)
63 | self.bn5 = nn.BatchNorm1d(256)
64 |
65 | self.k = k
66 |
67 | def forward(self, x):
68 | batchsize = x.size()[0]
69 | x = F.relu(self.bn1(self.conv1(x)))
70 | x = F.relu(self.bn2(self.conv2(x)))
71 | x = F.relu(self.bn3(self.conv3(x)))
72 | x = torch.max(x, 2, keepdim=True)[0]
73 | x = x.view(-1, 1024)
74 |
75 | x = F.relu(self.bn4(self.fc1(x)))
76 | x = F.relu(self.bn5(self.fc2(x)))
77 | x = self.fc3(x)
78 |
79 |         # identity k x k transform, created directly on the input's device
80 |         iden = torch.eye(self.k, device=x.device).flatten().view(1, self.k * self.k).repeat(batchsize, 1)
83 | x = x + iden
84 | x = x.view(-1, self.k, self.k)
85 | return x
86 |
87 |
88 | class PointNetEncoder(nn.Module):
89 | def __init__(self, global_feat=True, feature_transform=False, channel=3):
90 | super(PointNetEncoder, self).__init__()
91 | self.stn = STN3d(channel)
92 | self.conv1 = torch.nn.Conv1d(channel, 64, 1)
93 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
94 | self.conv3 = torch.nn.Conv1d(128, 1024, 1)
95 | self.bn1 = nn.BatchNorm1d(64)
96 | self.bn2 = nn.BatchNorm1d(128)
97 | self.bn3 = nn.BatchNorm1d(1024)
98 | self.global_feat = global_feat
99 | self.feature_transform = feature_transform
100 | if self.feature_transform:
101 | self.fstn = STNkd(k=64)
102 |
103 | def forward(self, x):
104 | B, D, N = x.size()
105 | trans = self.stn(x)
106 | x = x.transpose(2, 1)
107 |         if D > 3:
108 |             x, feature = x.split(3, dim=2)
109 |         x = torch.bmm(x, trans)
110 |         if D > 3:
111 |             x = torch.cat([x, feature], dim=2)
112 | x = x.transpose(2, 1)
113 | x = F.relu(self.bn1(self.conv1(x)))
114 |
115 | if self.feature_transform:
116 | trans_feat = self.fstn(x)
117 | x = x.transpose(2, 1)
118 | x = torch.bmm(x, trans_feat)
119 | x = x.transpose(2, 1)
120 | else:
121 | trans_feat = None
122 |
123 | pointfeat = x
124 | x = F.relu(self.bn2(self.conv2(x)))
125 | x = self.bn3(self.conv3(x))
126 | x = torch.max(x, 2, keepdim=True)[0]
127 | x = x.view(-1, 1024)
128 | if self.global_feat:
129 | return x, trans, trans_feat
130 | else:
131 | x = x.view(-1, 1024, 1).repeat(1, 1, N)
132 | return torch.cat([x, pointfeat], 1), trans, trans_feat
133 |
134 |
135 | def feature_transform_reguliarzer(trans):
136 | d = trans.size()[1]
137 | I = torch.eye(d)[None, :, :]
138 | if trans.is_cuda:
139 | I = I.cuda()
140 |     # regularize feature transforms toward orthonormality: || A A^T - I ||
141 |     loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
141 | return loss
142 |
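143 | 
144 | if __name__ == '__main__':
145 |     # Quick sanity check (a sketch, assuming CPU execution): the STN emits one
146 |     # 3x3 transform per batch element, and the regularizer is ~0 for orthonormal
147 |     # transforms such as the identity.
148 |     stn = STN3d(channel=3).eval()
149 |     trans = stn(torch.rand(4, 3, 64))
150 |     print(trans.shape)  # torch.Size([4, 3, 3])
151 |     eye = torch.eye(3).unsqueeze(0).repeat(4, 1, 1)
152 |     print(feature_transform_reguliarzer(eye))  # tensor(0.)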
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_cls_msg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetSetAbstraction
4 |
5 |
6 | class get_model(nn.Module):
7 |     def __init__(self, num_class, normal_channel=True):
8 | super(get_model, self).__init__()
9 | in_channel = 3 if normal_channel else 0
10 | self.normal_channel = normal_channel
11 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
12 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
13 | self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
14 | self.fc1 = nn.Linear(1024, 512)
15 | self.bn1 = nn.BatchNorm1d(512)
16 | self.drop1 = nn.Dropout(0.4)
17 | self.fc2 = nn.Linear(512, 256)
18 | self.bn2 = nn.BatchNorm1d(256)
19 | self.drop2 = nn.Dropout(0.5)
20 | self.fc3 = nn.Linear(256, num_class)
21 |
22 | def forward(self, xyz):
23 | B, _, _ = xyz.shape
24 | if self.normal_channel:
25 | norm = xyz[:, 3:, :]
26 | xyz = xyz[:, :3, :]
27 | else:
28 | norm = None
29 | l1_xyz, l1_points = self.sa1(xyz, norm)
30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
32 | x = l3_points.view(B, 1024)
33 | x = self.drop1(F.relu(self.bn1(self.fc1(x))))
34 | x = self.drop2(F.relu(self.bn2(self.fc2(x))))
35 | x = self.fc3(x)
36 |         x = F.log_softmax(x, -1)
37 |         return x, l3_points
40 |
41 |
42 | class get_loss(nn.Module):
43 | def __init__(self):
44 | super(get_loss, self).__init__()
45 |
46 | def forward(self, pred, target, trans_feat):
47 | total_loss = F.nll_loss(pred, target)
48 |
49 | return total_loss
50 |
51 |
52 |
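53 | if __name__ == '__main__':
54 |     import torch
55 |     # Minimal smoke test, mirroring the sem_seg stubs: classify a random batch
56 |     # of clouds with normals; input layout is (B, 3 xyz + 3 normal channels, N).
57 |     model = get_model(num_class=40, normal_channel=True)
58 |     xyz = torch.rand(4, 6, 1024)
59 |     logits, _ = model(xyz)
60 |     print(logits.shape)  # torch.Size([4, 40])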
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_cls_ssg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from .pointnet_util import PointNetSetAbstraction
4 |
5 |
6 | class get_model(nn.Module):
7 |     def __init__(self, num_class, normal_channel=True):
8 | super(get_model, self).__init__()
9 | in_channel = 6 if normal_channel else 3
10 | self.normal_channel = normal_channel
11 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
12 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
13 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
14 | self.fc1 = nn.Linear(1024, 512)
15 | self.bn1 = nn.BatchNorm1d(512)
16 | self.drop1 = nn.Dropout(0.4)
17 | self.fc2 = nn.Linear(512, 256)
18 | self.bn2 = nn.BatchNorm1d(256)
19 | self.drop2 = nn.Dropout(0.4)
20 | self.fc3 = nn.Linear(256, num_class)
21 |
22 | def forward(self, xyz):
23 | B, _, _ = xyz.shape
24 | if self.normal_channel:
25 | norm = xyz[:, 3:, :]
26 | xyz = xyz[:, :3, :]
27 | else:
28 | norm = None
29 | l1_xyz, l1_points = self.sa1(xyz, norm)
30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
32 | x = l3_points.view(B, 1024)
33 | x = self.drop1(F.relu(self.bn1(self.fc1(x))))
34 | x = self.drop2(F.relu(self.bn2(self.fc2(x))))
35 | x = self.fc3(x)
36 |         x = F.log_softmax(x, -1)
37 |         return x, l3_points
40 |
41 |
42 |
43 | class get_loss(nn.Module):
44 | def __init__(self):
45 | super(get_loss, self).__init__()
46 |
47 | def forward(self, pred, target, trans_feat):
48 | total_loss = F.nll_loss(pred, target)
49 |
50 | return total_loss
51 |
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_part_seg_msg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from .pointnet_util import PointNetSetAbstractionMsg,PointNetSetAbstraction,PointNetFeaturePropagation
5 |
6 |
7 | class get_model(nn.Module):
8 | def __init__(self, num_classes, normal_channel=False):
9 | super(get_model, self).__init__()
10 | if normal_channel:
11 | additional_channel = 3
12 | else:
13 | additional_channel = 0
14 | self.normal_channel = normal_channel
15 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
16 | self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
17 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
18 | self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
19 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
20 | self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128])
21 | self.conv1 = nn.Conv1d(128, 128, 1)
22 | self.bn1 = nn.BatchNorm1d(128)
23 | self.drop1 = nn.Dropout(0.5)
24 | self.conv2 = nn.Conv1d(128, num_classes, 1)
25 |
26 | def forward(self, xyz, cls_label):
27 | # Set Abstraction layers
28 | B,C,N = xyz.shape
29 | if self.normal_channel:
30 | l0_points = xyz
31 | l0_xyz = xyz[:,:3,:]
32 | else:
33 | l0_points = xyz
34 | l0_xyz = xyz
35 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
37 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
38 | # Feature Propagation layers
39 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
40 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
41 |         cls_label_one_hot = cls_label.view(B, 16, 1).repeat(1, 1, N)  # 16 = number of ShapeNet object categories
42 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
43 | # FC layers
44 | feat = F.relu(self.bn1(self.conv1(l0_points)))
45 | x = self.drop1(feat)
46 | x = self.conv2(x)
47 | x = F.log_softmax(x, dim=1)
48 | x = x.permute(0, 2, 1)
49 | return x, l3_points
50 |
51 |
52 | class get_loss(nn.Module):
53 | def __init__(self):
54 | super(get_loss, self).__init__()
55 |
56 | def forward(self, pred, target, trans_feat):
57 | total_loss = F.nll_loss(pred, target)
58 |
59 | return total_loss
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_part_seg_ssg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | from .pointnet_util import PointNetSetAbstraction,PointNetFeaturePropagation
5 |
6 |
7 | class get_model(nn.Module):
8 | def __init__(self, num_classes, normal_channel=False):
9 | super(get_model, self).__init__()
10 | if normal_channel:
11 | additional_channel = 3
12 | else:
13 | additional_channel = 0
14 | self.normal_channel = normal_channel
15 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=6+additional_channel, mlp=[64, 64, 128], group_all=False)
16 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
17 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
18 | self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256])
19 | self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128])
20 | self.fp1 = PointNetFeaturePropagation(in_channel=128+16+6+additional_channel, mlp=[128, 128, 128])
21 | self.conv1 = nn.Conv1d(128, 128, 1)
22 | self.bn1 = nn.BatchNorm1d(128)
23 | self.drop1 = nn.Dropout(0.5)
24 | self.conv2 = nn.Conv1d(128, num_classes, 1)
25 |
26 | def forward(self, xyz, cls_label):
27 | # Set Abstraction layers
28 | B,C,N = xyz.shape
29 | if self.normal_channel:
30 | l0_points = xyz
31 | l0_xyz = xyz[:,:3,:]
32 | else:
33 | l0_points = xyz
34 | l0_xyz = xyz
35 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
37 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
38 | # Feature Propagation layers
39 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
40 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
41 |         cls_label_one_hot = cls_label.view(B, 16, 1).repeat(1, 1, N)  # 16 = number of ShapeNet object categories
42 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)
43 | # FC layers
44 | feat = F.relu(self.bn1(self.conv1(l0_points)))
45 | x = self.drop1(feat)
46 | x = self.conv2(x)
47 | x = F.log_softmax(x, dim=1)
48 | x = x.permute(0, 2, 1)
49 | return x, l3_points
50 |
51 |
52 | class get_loss(nn.Module):
53 | def __init__(self):
54 | super(get_loss, self).__init__()
55 |
56 | def forward(self, pred, target, trans_feat):
57 | total_loss = F.nll_loss(pred, target)
58 |
59 | return total_loss
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_sem_seg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from .pointnet_util import PointNetSetAbstraction,PointNetFeaturePropagation
4 |
5 |
6 | class get_model(nn.Module):
7 | def __init__(self, num_classes):
8 | super(get_model, self).__init__()
9 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
10 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
11 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
12 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
13 | self.fp4 = PointNetFeaturePropagation(768, [256, 256])
14 | self.fp3 = PointNetFeaturePropagation(384, [256, 256])
15 | self.fp2 = PointNetFeaturePropagation(320, [256, 128])
16 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
17 | self.conv1 = nn.Conv1d(128, 128, 1)
18 | self.bn1 = nn.BatchNorm1d(128)
19 | self.drop1 = nn.Dropout(0.5)
20 | self.conv2 = nn.Conv1d(128, num_classes, 1)
21 |
22 | def forward(self, xyz):
23 | l0_points = xyz
24 | l0_xyz = xyz[:,:3,:]
25 |
26 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
27 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
28 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
29 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)
30 |
31 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
32 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
33 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
34 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
35 |
36 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
37 | x = self.conv2(x)
38 | x = F.log_softmax(x, dim=1)
39 | x = x.permute(0, 2, 1)
40 | return x, l4_points
41 |
42 |
43 | class get_loss(nn.Module):
44 | def __init__(self):
45 | super(get_loss, self).__init__()
46 | def forward(self, pred, target, trans_feat, weight):
47 | total_loss = F.nll_loss(pred, target, weight=weight)
48 |
49 | return total_loss
50 |
51 | if __name__ == '__main__':
52 | import torch
53 | model = get_model(13)
54 | xyz = torch.rand(6, 9, 2048)
55 |     print(model(xyz))
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet2_sem_seg_msg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetFeaturePropagation
4 |
5 |
6 | class get_model(nn.Module):
7 | def __init__(self, num_classes):
8 | super(get_model, self).__init__()
9 |
10 | self.sa1 = PointNetSetAbstractionMsg(1024, [0.05, 0.1], [16, 32], 9, [[16, 16, 32], [32, 32, 64]])
11 | self.sa2 = PointNetSetAbstractionMsg(256, [0.1, 0.2], [16, 32], 32 + 64, [[64, 64, 128], [64, 96, 128]])
12 | self.sa3 = PointNetSetAbstractionMsg(64, [0.2, 0.4], [16, 32], 128 + 128, [[128, 196, 256], [128, 196, 256]])
13 | self.sa4 = PointNetSetAbstractionMsg(16, [0.4, 0.8], [16, 32], 256 + 256, [[256, 256, 512], [256, 384, 512]])
14 | self.fp4 = PointNetFeaturePropagation(512 + 512 + 256 + 256, [256, 256])
15 | self.fp3 = PointNetFeaturePropagation(128 + 128 + 256, [256, 256])
16 | self.fp2 = PointNetFeaturePropagation(32 + 64 + 256, [256, 128])
17 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
18 | self.conv1 = nn.Conv1d(128, 128, 1)
19 | self.bn1 = nn.BatchNorm1d(128)
20 | self.drop1 = nn.Dropout(0.5)
21 | self.conv2 = nn.Conv1d(128, num_classes, 1)
22 |
23 | def forward(self, xyz):
24 | l0_points = xyz
25 | l0_xyz = xyz[:, :3, :]
26 |
27 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
28 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
29 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
30 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)
31 |
32 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
33 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
34 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
35 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
36 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
37 | x = self.conv2(x)
38 | x = x.float()
39 | x = F.log_softmax(x, dim=1)
40 | x = x.permute(0, 2, 1)
41 |
42 | return x, l4_points
43 |
44 |
45 | class get_loss(nn.Module):
46 | def __init__(self):
47 | super(get_loss, self).__init__()
48 |
49 | def forward(self, pred, target, trans_feat, weight):
50 | total_loss = F.nll_loss(pred, target, weight=weight)
51 |
52 | return total_loss
53 |
54 |
55 | if __name__ == '__main__':
56 | import torch
57 |
58 | model = get_model(13)
59 | xyz = torch.rand(6, 9, 2048)
60 | print(model(xyz))
61 |
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet_cls.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.data
3 | import torch.nn.functional as F
4 | from .pointnet import PointNetEncoder, feature_transform_reguliarzer
5 |
6 | class get_model(nn.Module):
7 | def __init__(self, k=40, normal_channel=True):
8 | super(get_model, self).__init__()
9 | if normal_channel:
10 | channel = 6
11 | else:
12 | channel = 3
13 | self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel)
14 | self.fc1 = nn.Linear(1024, 512)
15 | self.fc2 = nn.Linear(512, 256)
16 | self.fc3 = nn.Linear(256, k)
17 | self.dropout = nn.Dropout(p=0.4)
18 | self.bn1 = nn.BatchNorm1d(512)
19 | self.bn2 = nn.BatchNorm1d(256)
20 | self.relu = nn.ReLU()
21 |
22 | def forward(self, x):
23 | x, trans, trans_feat = self.feat(x)
24 | x = F.relu(self.bn1(self.fc1(x)))
25 | x = F.relu(self.bn2(self.dropout(self.fc2(x))))
26 | x = self.fc3(x)
27 | x = F.log_softmax(x, dim=1)
28 | return x, trans_feat
29 |
30 | class get_loss(torch.nn.Module):
31 | def __init__(self, mat_diff_loss_scale=0.001):
32 | super(get_loss, self).__init__()
33 | self.mat_diff_loss_scale = mat_diff_loss_scale
34 |
35 | def forward(self, pred, target, trans_feat):
36 | loss = F.nll_loss(pred, target)
37 | mat_diff_loss = feature_transform_reguliarzer(trans_feat)
38 |
39 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
40 | return total_loss
41 |
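42 | 
43 | if __name__ == '__main__':
44 |     import torch
45 |     # Minimal smoke test, mirroring the sem_seg stub: the classifier maps a
46 |     # (B, channels, N) cloud to (B, k) log-probabilities.
47 |     model = get_model(k=40, normal_channel=False)
48 |     x = torch.rand(4, 3, 1024)
49 |     logp, trans_feat = model(x)
50 |     print(logp.shape)  # torch.Size([4, 40])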
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet_part_seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.utils.data
5 | import torch.nn.functional as F
6 | from .pointnet import STN3d, STNkd, feature_transform_reguliarzer
7 |
8 |
9 | class get_model(nn.Module):
10 | def __init__(self, part_num=50, normal_channel=True):
11 | super(get_model, self).__init__()
12 | if normal_channel:
13 | channel = 6
14 | else:
15 | channel = 3
16 | self.part_num = part_num
17 | self.stn = STN3d(channel)
18 | self.conv1 = torch.nn.Conv1d(channel, 64, 1)
19 | self.conv2 = torch.nn.Conv1d(64, 128, 1)
20 | self.conv3 = torch.nn.Conv1d(128, 128, 1)
21 | self.conv4 = torch.nn.Conv1d(128, 512, 1)
22 | self.conv5 = torch.nn.Conv1d(512, 2048, 1)
23 | self.bn1 = nn.BatchNorm1d(64)
24 | self.bn2 = nn.BatchNorm1d(128)
25 | self.bn3 = nn.BatchNorm1d(128)
26 | self.bn4 = nn.BatchNorm1d(512)
27 | self.bn5 = nn.BatchNorm1d(2048)
28 | self.fstn = STNkd(k=128)
29 | self.convs1 = torch.nn.Conv1d(4944, 256, 1)
30 | self.convs2 = torch.nn.Conv1d(256, 256, 1)
31 | self.convs3 = torch.nn.Conv1d(256, 128, 1)
32 | self.convs4 = torch.nn.Conv1d(128, part_num, 1)
33 | self.bns1 = nn.BatchNorm1d(256)
34 | self.bns2 = nn.BatchNorm1d(256)
35 | self.bns3 = nn.BatchNorm1d(128)
36 |
37 | def forward(self, point_cloud, label):
38 | B, D, N = point_cloud.size()
39 | trans = self.stn(point_cloud)
40 | point_cloud = point_cloud.transpose(2, 1)
41 | if D > 3:
42 | point_cloud, feature = point_cloud.split(3, dim=2)
43 | point_cloud = torch.bmm(point_cloud, trans)
44 | if D > 3:
45 | point_cloud = torch.cat([point_cloud, feature], dim=2)
46 |
47 | point_cloud = point_cloud.transpose(2, 1)
48 |
49 | out1 = F.relu(self.bn1(self.conv1(point_cloud)))
50 | out2 = F.relu(self.bn2(self.conv2(out1)))
51 | out3 = F.relu(self.bn3(self.conv3(out2)))
52 |
53 | trans_feat = self.fstn(out3)
54 | x = out3.transpose(2, 1)
55 | net_transformed = torch.bmm(x, trans_feat)
56 | net_transformed = net_transformed.transpose(2, 1)
57 |
58 | out4 = F.relu(self.bn4(self.conv4(net_transformed)))
59 | out5 = self.bn5(self.conv5(out4))
60 | out_max = torch.max(out5, 2, keepdim=True)[0]
61 | out_max = out_max.view(-1, 2048)
62 |
63 | out_max = torch.cat([out_max,label.squeeze(1)],1)
64 | expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N)
65 | concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)
66 | net = F.relu(self.bns1(self.convs1(concat)))
67 | net = F.relu(self.bns2(self.convs2(net)))
68 | net = F.relu(self.bns3(self.convs3(net)))
69 | net = self.convs4(net)
70 | net = net.transpose(2, 1).contiguous()
71 | net = F.log_softmax(net.view(-1, self.part_num), dim=-1)
72 | net = net.view(B, N, self.part_num) # [B, N, 50]
73 |
74 | return net, trans_feat
75 |
76 |
77 | class get_loss(torch.nn.Module):
78 | def __init__(self, mat_diff_loss_scale=0.001):
79 | super(get_loss, self).__init__()
80 | self.mat_diff_loss_scale = mat_diff_loss_scale
81 |
82 | def forward(self, pred, target, trans_feat):
83 | loss = F.nll_loss(pred, target)
84 | mat_diff_loss = feature_transform_reguliarzer(trans_feat)
85 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
86 | return total_loss
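87 | 
88 | 
89 | if __name__ == '__main__':
90 |     # Minimal smoke test, mirroring the sem_seg stub: part segmentation also
91 |     # expects a one-hot object-category label of shape (B, 1, 16).
92 |     model = get_model(part_num=50, normal_channel=False)
93 |     pc = torch.rand(2, 3, 1024)
94 |     label = torch.zeros(2, 1, 16)
95 |     label[:, 0, 0] = 1  # pretend each shape belongs to category 0
96 |     seg, trans_feat = model(pc, label)
97 |     print(seg.shape)  # torch.Size([2, 1024, 50])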
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet_sem_seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import torch.utils.data
5 | import torch.nn.functional as F
6 | from .pointnet import PointNetEncoder, feature_transform_reguliarzer
7 |
8 |
9 | class get_model(nn.Module):
10 | def __init__(self, num_class, with_rgb=True):
11 | super(get_model, self).__init__()
12 | if with_rgb:
13 | channel = 6
14 | else:
15 | channel = 3
16 | self.k = num_class
17 | self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=channel)
18 | self.conv1 = torch.nn.Conv1d(1088, 512, 1)
19 | self.conv2 = torch.nn.Conv1d(512, 256, 1)
20 | self.conv3 = torch.nn.Conv1d(256, 128, 1)
21 | self.conv4 = torch.nn.Conv1d(128, self.k, 1)
22 | self.bn1 = nn.BatchNorm1d(512)
23 | self.bn2 = nn.BatchNorm1d(256)
24 | self.bn3 = nn.BatchNorm1d(128)
25 |
26 | def forward(self, x):
27 | batchsize = x.size()[0]
28 | n_pts = x.size()[2]
29 | x, trans, trans_feat = self.feat(x)
30 | x = F.relu(self.bn1(self.conv1(x)))
31 | x = F.relu(self.bn2(self.conv2(x)))
32 | x = F.relu(self.bn3(self.conv3(x)))
33 | x = self.conv4(x)
34 | x = x.transpose(2,1).contiguous()
35 | x = F.log_softmax(x.view(-1,self.k), dim=-1)
36 | x = x.view(batchsize, n_pts, self.k)
37 | return x, trans_feat
38 |
39 | class get_loss(torch.nn.Module):
40 | def __init__(self, mat_diff_loss_scale=0.001):
41 | super(get_loss, self).__init__()
42 | self.mat_diff_loss_scale = mat_diff_loss_scale
43 |
44 | def forward(self, pred, target, trans_feat, weight):
45 | loss = F.nll_loss(pred, target, weight = weight)
46 | mat_diff_loss = feature_transform_reguliarzer(trans_feat)
47 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
48 | return total_loss
49 |
50 |
51 | if __name__ == '__main__':
52 | model = get_model(13, with_rgb=False)
53 | xyz = torch.rand(12, 3, 2048)
54 |     print(model(xyz))
--------------------------------------------------------------------------------
/models/pointnetpp/pointnet_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 | def timeit(tag, t):
8 | print("{}: {}s".format(tag, time() - t))
9 | return time()
10 |
11 | def pc_normalize(pc):
12 |     # center at the centroid and scale into the unit sphere
13 | centroid = np.mean(pc, axis=0)
14 | pc = pc - centroid
15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16 | pc = pc / m
17 | return pc
18 |
19 | def square_distance(src, dst):
20 | """
21 |     Calculate squared Euclidean distance between each pair of points.
22 |
23 | src^T * dst = xn * xm + yn * ym + zn * zm;
24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
28 |
29 | Input:
30 | src: source points, [B, N, C]
31 | dst: target points, [B, M, C]
32 | Output:
33 | dist: per-point square distance, [B, N, M]
34 | """
35 | B, N, _ = src.shape
36 | _, M, _ = dst.shape
37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
38 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
40 | return dist
41 |
42 |
43 | def index_points(points, idx):
44 | """
45 |
46 | Input:
47 | points: input points data, [B, N, C]
48 | idx: sample index data, [B, S]
49 | Return:
50 | new_points:, indexed points data, [B, S, C]
51 | """
52 | device = points.device
53 | B = points.shape[0]
54 | view_shape = list(idx.shape)
55 | view_shape[1:] = [1] * (len(view_shape) - 1)
56 | repeat_shape = list(idx.shape)
57 | repeat_shape[0] = 1
58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
59 | new_points = points[batch_indices, idx, :]
60 | return new_points
61 |
62 |
63 | def farthest_point_sample(xyz, npoint):
64 | """
65 | Input:
66 | xyz: pointcloud data, [B, N, 3]
67 | npoint: number of samples
68 | Return:
69 | centroids: sampled pointcloud index, [B, npoint]
70 | """
71 | device = xyz.device
72 | B, N, C = xyz.shape
73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
74 | distance = torch.ones(B, N).to(device) * 1e10
75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
76 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
77 | for i in range(npoint):
78 | centroids[:, i] = farthest
79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
80 | dist = torch.sum((xyz - centroid) ** 2, -1)
81 | mask = dist < distance
82 | distance[mask] = dist[mask]
83 | farthest = torch.max(distance, -1)[1]
84 | return centroids
85 |
86 |
87 | def query_ball_point(radius, nsample, xyz, new_xyz):
88 | """
89 | Input:
90 | radius: local region radius
91 | nsample: max sample number in local region
92 | xyz: all points, [B, N, 3]
93 | new_xyz: query points, [B, S, 3]
94 | Return:
95 | group_idx: grouped points index, [B, S, nsample]
96 | """
97 | device = xyz.device
98 | B, N, C = xyz.shape
99 | _, S, _ = new_xyz.shape
100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
101 | sqrdists = square_distance(new_xyz, xyz)
102 | group_idx[sqrdists > radius ** 2] = N
103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
105 | mask = group_idx == N
106 | group_idx[mask] = group_first[mask]
107 | return group_idx
108 |
109 |
110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
111 | """
112 | Input:
113 | npoint:
114 | radius:
115 | nsample:
116 | xyz: input points position data, [B, N, 3]
117 | points: input points data, [B, N, D]
118 | Return:
119 |         new_xyz: sampled points position data, [B, npoint, 3]
120 | new_points: sampled points data, [B, npoint, nsample, 3+D]
121 | """
122 | B, N, C = xyz.shape
123 | S = npoint
124 |     fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
125 | torch.cuda.empty_cache()
126 | new_xyz = index_points(xyz, fps_idx)
127 | torch.cuda.empty_cache()
128 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
129 | torch.cuda.empty_cache()
130 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
131 | torch.cuda.empty_cache()
132 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
133 | torch.cuda.empty_cache()
134 |
135 | if points is not None:
136 | grouped_points = index_points(points, idx)
137 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
138 | else:
139 | new_points = grouped_xyz_norm
140 | if returnfps:
141 | return new_xyz, new_points, grouped_xyz, fps_idx
142 | else:
143 | return new_xyz, new_points
144 |
145 |
146 | def sample_and_group_all(xyz, points):
147 | """
148 | Input:
149 | xyz: input points position data, [B, N, 3]
150 | points: input points data, [B, N, D]
151 | Return:
152 | new_xyz: sampled points position data, [B, 1, 3]
153 | new_points: sampled points data, [B, 1, N, 3+D]
154 | """
155 | device = xyz.device
156 | B, N, C = xyz.shape
157 | new_xyz = torch.zeros(B, 1, C).to(device)
158 | grouped_xyz = xyz.view(B, 1, N, C)
159 | if points is not None:
160 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
161 | else:
162 | new_points = grouped_xyz
163 | return new_xyz, new_points
164 |
165 |
166 | class PointNetSetAbstraction(nn.Module):
167 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
168 | super(PointNetSetAbstraction, self).__init__()
169 | self.npoint = npoint
170 | self.radius = radius
171 | self.nsample = nsample
172 | self.mlp_convs = nn.ModuleList()
173 | self.mlp_bns = nn.ModuleList()
174 | last_channel = in_channel
175 | for out_channel in mlp:
176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
178 | last_channel = out_channel
179 | self.group_all = group_all
180 |
181 | def forward(self, xyz, points):
182 | """
183 | Input:
184 | xyz: input points position data, [B, C, N]
185 | points: input points data, [B, D, N]
186 | Return:
187 | new_xyz: sampled points position data, [B, C, S]
188 | new_points_concat: sample points feature data, [B, D', S]
189 | """
190 | xyz = xyz.permute(0, 2, 1)
191 | if points is not None:
192 | points = points.permute(0, 2, 1)
193 |
194 | if self.group_all:
195 | new_xyz, new_points = sample_and_group_all(xyz, points)
196 | else:
197 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
198 | # new_xyz: sampled points position data, [B, npoint, C]
199 | # new_points: sampled points data, [B, npoint, nsample, C+D]
200 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
201 | for i, conv in enumerate(self.mlp_convs):
202 | bn = self.mlp_bns[i]
203 | new_points = F.relu(bn(conv(new_points)))
204 |
205 | new_points = torch.max(new_points, 2)[0]
206 | new_xyz = new_xyz.permute(0, 2, 1)
207 | return new_xyz, new_points
208 |
209 |
210 | class PointNetSetAbstractionMsg(nn.Module):
211 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
212 | super(PointNetSetAbstractionMsg, self).__init__()
213 | self.npoint = npoint
214 | self.radius_list = radius_list
215 | self.nsample_list = nsample_list
216 | self.conv_blocks = nn.ModuleList()
217 | self.bn_blocks = nn.ModuleList()
218 | for i in range(len(mlp_list)):
219 | convs = nn.ModuleList()
220 | bns = nn.ModuleList()
221 | last_channel = in_channel + 3
222 | for out_channel in mlp_list[i]:
223 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
224 | bns.append(nn.BatchNorm2d(out_channel))
225 | last_channel = out_channel
226 | self.conv_blocks.append(convs)
227 | self.bn_blocks.append(bns)
228 |
229 | def forward(self, xyz, points):
230 | """
231 | Input:
232 | xyz: input points position data, [B, C, N]
233 | points: input points data, [B, D, N]
234 | Return:
235 | new_xyz: sampled points position data, [B, C, S]
236 | new_points_concat: sample points feature data, [B, D', S]
237 | """
238 | xyz = xyz.permute(0, 2, 1)
239 | if points is not None:
240 | points = points.permute(0, 2, 1)
241 |
242 | B, N, C = xyz.shape
243 | S = self.npoint
244 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
245 | new_points_list = []
246 | for i, radius in enumerate(self.radius_list):
247 | K = self.nsample_list[i]
248 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
249 | grouped_xyz = index_points(xyz, group_idx)
250 | grouped_xyz -= new_xyz.view(B, S, 1, C)
251 | if points is not None:
252 | grouped_points = index_points(points, group_idx)
253 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
254 | else:
255 | grouped_points = grouped_xyz
256 |
257 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
258 | for j in range(len(self.conv_blocks[i])):
259 | conv = self.conv_blocks[i][j]
260 | bn = self.bn_blocks[i][j]
261 | grouped_points = F.relu(bn(conv(grouped_points)))
262 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
263 | new_points_list.append(new_points)
264 |
265 | new_xyz = new_xyz.permute(0, 2, 1)
266 | new_points_concat = torch.cat(new_points_list, dim=1)
267 | return new_xyz, new_points_concat
268 |
269 |
270 | class PointNetFeaturePropagation(nn.Module):
271 | def __init__(self, in_channel, mlp):
272 | super(PointNetFeaturePropagation, self).__init__()
273 | self.mlp_convs = nn.ModuleList()
274 | self.mlp_bns = nn.ModuleList()
275 | last_channel = in_channel
276 | for out_channel in mlp:
277 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
278 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
279 | last_channel = out_channel
280 |
281 | def forward(self, xyz1, xyz2, points1, points2):
282 | """
283 | Input:
284 | xyz1: input points position data, [B, C, N]
285 | xyz2: sampled input points position data, [B, C, S]
286 | points1: input points data, [B, D, N]
287 | points2: sampled input points data, [B, D, S]
288 | Return:
289 | new_points: upsampled points data, [B, D', N]
290 | """
291 | xyz1 = xyz1.permute(0, 2, 1)
292 | xyz2 = xyz2.permute(0, 2, 1)
293 |
294 | points2 = points2.permute(0, 2, 1)
295 | B, N, C = xyz1.shape
296 | _, S, _ = xyz2.shape
297 |
298 | if S == 1:
299 | interpolated_points = points2.repeat(1, N, 1)
300 | else:
301 | dists = square_distance(xyz1, xyz2)
302 | dists, idx = dists.sort(dim=-1)
303 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
304 |
305 | dist_recip = 1.0 / (dists + 1e-8)
306 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
307 | weight = dist_recip / norm
308 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
309 |
310 | if points1 is not None:
311 | points1 = points1.permute(0, 2, 1)
312 | new_points = torch.cat([points1, interpolated_points], dim=-1)
313 | else:
314 | new_points = interpolated_points
315 |
316 | new_points = new_points.permute(0, 2, 1)
317 | for i, conv in enumerate(self.mlp_convs):
318 | bn = self.mlp_bns[i]
319 | new_points = F.relu(bn(conv(new_points)))
320 | return new_points
321 |
322 |
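323 | if __name__ == '__main__':
324 |     # Quick sanity check (a sketch): square_distance should agree with the
325 |     # brute-force pairwise computation, per the identity documented above:
326 |     # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
327 |     src = torch.rand(2, 8, 3)
328 |     dst = torch.rand(2, 5, 3)
329 |     brute = torch.sum((src[:, :, None, :] - dst[:, None, :, :]) ** 2, dim=-1)
330 |     print(torch.allclose(square_distance(src, dst), brute, atol=1e-5))  # True
331 |     print(farthest_point_sample(src, 4).shape)  # torch.Size([2, 4])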
--------------------------------------------------------------------------------
/pointnet2_ops_lib/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft pointnet2_ops/_ext-src
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: pointnet2-ops
3 | Version: 3.0.0
4 | Summary: UNKNOWN
5 | Author: Erik Wijmans
6 | License: UNKNOWN
7 | Platform: UNKNOWN
8 |
9 | UNKNOWN
10 |
11 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | MANIFEST.in
2 | setup.py
3 | pointnet2_ops/__init__.py
4 | pointnet2_ops/_version.py
5 | pointnet2_ops/pointnet2_modules.py
6 | pointnet2_ops/pointnet2_utils.py
7 | pointnet2_ops.egg-info/PKG-INFO
8 | pointnet2_ops.egg-info/SOURCES.txt
9 | pointnet2_ops.egg-info/dependency_links.txt
10 | pointnet2_ops.egg-info/requires.txt
11 | pointnet2_ops.egg-info/top_level.txt
12 | pointnet2_ops/_ext-src/include/ball_query.h
13 | pointnet2_ops/_ext-src/include/cuda_utils.h
14 | pointnet2_ops/_ext-src/include/group_points.h
15 | pointnet2_ops/_ext-src/include/interpolate.h
16 | pointnet2_ops/_ext-src/include/sampling.h
17 | pointnet2_ops/_ext-src/include/utils.h
18 | pointnet2_ops/_ext-src/src/ball_query.cpp
19 | pointnet2_ops/_ext-src/src/ball_query_gpu.cu
20 | pointnet2_ops/_ext-src/src/bindings.cpp
21 | pointnet2_ops/_ext-src/src/group_points.cpp
22 | pointnet2_ops/_ext-src/src/group_points_gpu.cu
23 | pointnet2_ops/_ext-src/src/interpolate.cpp
24 | pointnet2_ops/_ext-src/src/interpolate_gpu.cu
25 | pointnet2_ops/_ext-src/src/sampling.cpp
26 | pointnet2_ops/_ext-src/src/sampling_gpu.cu
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | pointnet2_ops
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/__init__.py:
--------------------------------------------------------------------------------
1 | import pointnet2_ops.pointnet2_modules
2 | import pointnet2_ops.pointnet2_utils
3 | from pointnet2_ops._version import __version__
4 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5 | const int nsample);
6 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cmath>
7 |
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 |
11 | #include <vector>
12 |
13 | #define TOTAL_THREADS 512
14 |
15 | inline int opt_n_threads(int work_size) {
16 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17 |
18 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
19 | }
20 |
21 | inline dim3 opt_block_config(int x, int y) {
22 | const int x_threads = opt_n_threads(x);
23 | const int y_threads =
24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25 | dim3 block_config(x_threads, y_threads, 1);
26 |
27 | return block_config;
28 | }
29 |
30 | #define CUDA_CHECK_ERRORS() \
31 | do { \
32 | cudaError_t err = cudaGetLastError(); \
33 | if (cudaSuccess != err) { \
34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36 | __FILE__); \
37 | exit(-1); \
38 | } \
39 | } while (0)
40 |
41 | #endif
42 |
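opt_n_threads above picks the largest power of two that is at most work_size, clamped to [1, TOTAL_THREADS]. A minimal Python restatement of the same rule (a hypothetical helper for sanity-checking launch configurations offline, not part of the package):

    import math

    def opt_n_threads(work_size: int, total_threads: int = 512) -> int:
        # Largest power of two <= work_size, clamped to [1, total_threads].
        pow_2 = int(math.log(work_size) / math.log(2.0))
        return max(min(1 << pow_2, total_threads), 1)

    assert opt_n_threads(1000) == 512  # capped at TOTAL_THREADS
    assert opt_n_threads(100) == 64    # 2**6 is the largest power of two <= 100
    assert opt_n_threads(1) == 1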
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
8 | at::Tensor weight);
9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
10 | at::Tensor weight, const int m);
11 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
7 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <torch/extension.h>
4 |
5 | #define CHECK_CUDA(x) \
6 | do { \
7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
8 | } while (0)
9 |
10 | #define CHECK_CONTIGUOUS(x) \
11 | do { \
12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_IS_INT(x) \
16 | do { \
17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
18 | #x " must be an int tensor"); \
19 | } while (0)
20 |
21 | #define CHECK_IS_FLOAT(x) \
22 | do { \
23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
24 | #x " must be a float tensor"); \
25 | } while (0)
26 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "utils.h"
3 |
4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5 | int nsample, const float *new_xyz,
6 | const float *xyz, int *idx);
7 |
8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9 | const int nsample) {
10 | CHECK_CONTIGUOUS(new_xyz);
11 | CHECK_CONTIGUOUS(xyz);
12 | CHECK_IS_FLOAT(new_xyz);
13 | CHECK_IS_FLOAT(xyz);
14 |
15 | if (new_xyz.is_cuda()) {
16 | CHECK_CUDA(xyz);
17 | }
18 |
19 | at::Tensor idx =
20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
22 |
23 | if (new_xyz.is_cuda()) {
24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25 | radius, nsample, new_xyz.data_ptr<float>(),
26 | xyz.data_ptr<float>(), idx.data_ptr<int>());
27 | } else {
28 | AT_ASSERT(false, "CPU not supported");
29 | }
30 |
31 | return idx;
32 | }
33 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
8 | // output: idx(b, m, nsample)
9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10 | int nsample,
11 | const float *__restrict__ new_xyz,
12 | const float *__restrict__ xyz,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | xyz += batch_index * n * 3;
16 | new_xyz += batch_index * m * 3;
17 | idx += m * nsample * batch_index;
18 |
19 | int index = threadIdx.x;
20 | int stride = blockDim.x;
21 |
22 | float radius2 = radius * radius;
23 | for (int j = index; j < m; j += stride) {
24 | float new_x = new_xyz[j * 3 + 0];
25 | float new_y = new_xyz[j * 3 + 1];
26 | float new_z = new_xyz[j * 3 + 2];
27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28 | float x = xyz[k * 3 + 0];
29 | float y = xyz[k * 3 + 1];
30 | float z = xyz[k * 3 + 2];
31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32 | (new_z - z) * (new_z - z);
33 | if (d2 < radius2) {
34 | if (cnt == 0) {
35 | for (int l = 0; l < nsample; ++l) {
36 | idx[j * nsample + l] = k;
37 | }
38 | }
39 | idx[j * nsample + cnt] = k;
40 | ++cnt;
41 | }
42 | }
43 | }
44 | }
45 |
46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47 | int nsample, const float *new_xyz,
48 | const float *xyz, int *idx) {
49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50 | query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51 | b, n, m, radius, nsample, new_xyz, xyz, idx);
52 |
53 | CUDA_CHECK_ERRORS();
54 | }
55 |
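For each center, the kernel returns up to nsample indices of points within radius, pre-filling the whole row with the first hit so that short rows come out padded rather than zero-filled. A pure-PyTorch reference with the same semantics (a test-only sketch for small inputs; the shipped op itself is CUDA-only):

    import torch

    def ball_query_reference(new_xyz, xyz, radius, nsample):
        # new_xyz: (B, m, 3), xyz: (B, n, 3) -> idx: (B, m, nsample) int32
        B, m, _ = new_xyz.shape
        d2 = torch.cdist(new_xyz, xyz) ** 2          # squared pairwise distances
        idx = torch.zeros(B, m, nsample, dtype=torch.int32)
        for b in range(B):
            for j in range(m):
                hits = torch.nonzero(d2[b, j] < radius * radius).flatten()[:nsample]
                if hits.numel() > 0:
                    idx[b, j, :] = hits[0].int()     # pad the row with the first hit
                    idx[b, j, : hits.numel()] = hits.int()
        return idx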
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "group_points.h"
3 | #include "interpolate.h"
4 | #include "sampling.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("gather_points", &gather_points);
8 | m.def("gather_points_grad", &gather_points_grad);
9 | m.def("furthest_point_sampling", &furthest_point_sampling);
10 |
11 | m.def("three_nn", &three_nn);
12 | m.def("three_interpolate", &three_interpolate);
13 | m.def("three_interpolate_grad", &three_interpolate_grad);
14 |
15 | m.def("ball_query", &ball_query);
16 |
17 | m.def("group_points", &group_points);
18 | m.def("group_points_grad", &group_points_grad);
19 | }
20 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include "group_points.h"
2 | #include "utils.h"
3 |
4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5 | const float *points, const int *idx,
6 | float *out);
7 |
8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9 | int nsample, const float *grad_out,
10 | const int *idx, float *grad_points);
11 |
12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13 | CHECK_CONTIGUOUS(points);
14 | CHECK_CONTIGUOUS(idx);
15 | CHECK_IS_FLOAT(points);
16 | CHECK_IS_INT(idx);
17 |
18 | if (points.is_cuda()) {
19 | CHECK_CUDA(idx);
20 | }
21 |
22 | at::Tensor output =
23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24 | at::device(points.device()).dtype(at::ScalarType::Float));
25 |
26 | if (points.is_cuda()) {
27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28 | idx.size(1), idx.size(2),
29 | points.data_ptr<float>(), idx.data_ptr<int>(),
30 | output.data_ptr<float>());
31 | } else {
32 | AT_ASSERT(false, "CPU not supported");
33 | }
34 |
35 | return output;
36 | }
37 |
38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
39 | CHECK_CONTIGUOUS(grad_out);
40 | CHECK_CONTIGUOUS(idx);
41 | CHECK_IS_FLOAT(grad_out);
42 | CHECK_IS_INT(idx);
43 |
44 | if (grad_out.is_cuda()) {
45 | CHECK_CUDA(idx);
46 | }
47 |
48 | at::Tensor output =
49 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
50 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
51 |
52 | if (grad_out.is_cuda()) {
53 | group_points_grad_kernel_wrapper(
54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
55 | grad_out.data_ptr<float>(), idx.data_ptr<int>(),
56 | output.data_ptr<float>());
57 | } else {
58 | AT_ASSERT(false, "CPU not supported");
59 | }
60 |
61 | return output;
62 | }
63 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 | int nsample,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | int batch_index = blockIdx.x;
14 | points += batch_index * n * c;
15 | idx += batch_index * npoints * nsample;
16 | out += batch_index * npoints * nsample * c;
17 |
18 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 | const int stride = blockDim.y * blockDim.x;
20 | for (int i = index; i < c * npoints; i += stride) {
21 | const int l = i / npoints;
22 | const int j = i % npoints;
23 | for (int k = 0; k < nsample; ++k) {
24 | int ii = idx[j * nsample + k];
25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 | }
27 | }
28 | }
29 |
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 | const float *points, const int *idx,
32 | float *out) {
33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 |
35 | group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 | b, c, n, npoints, nsample, points, idx, out);
37 |
38 | CUDA_CHECK_ERRORS();
39 | }
40 |
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 | int nsample,
45 | const float *__restrict__ grad_out,
46 | const int *__restrict__ idx,
47 | float *__restrict__ grad_points) {
48 | int batch_index = blockIdx.x;
49 | grad_out += batch_index * npoints * nsample * c;
50 | idx += batch_index * npoints * nsample;
51 | grad_points += batch_index * n * c;
52 |
53 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 | const int stride = blockDim.y * blockDim.x;
55 | for (int i = index; i < c * npoints; i += stride) {
56 | const int l = i / npoints;
57 | const int j = i % npoints;
58 | for (int k = 0; k < nsample; ++k) {
59 | int ii = idx[j * nsample + k];
60 | atomicAdd(grad_points + l * n + ii,
61 | grad_out[(l * npoints + j) * nsample + k]);
62 | }
63 | }
64 | }
65 |
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 | int nsample, const float *grad_out,
68 | const int *idx, float *grad_points) {
69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 |
71 | group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 |
74 | CUDA_CHECK_ERRORS();
75 | }
76 |
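Functionally, the forward kernel computes out[b, c, j, k] = points[b, c, idx[b, j, k]]. The same gather can be written in plain PyTorch (a sketch one might use to unit-test the op, not part of the library):

    import torch

    def group_points_reference(points, idx):
        # points: (B, C, N), idx: (B, npoint, nsample) -> (B, C, npoint, nsample)
        B, C, N = points.shape
        _, npoint, nsample = idx.shape
        flat = idx.long().reshape(B, 1, npoint * nsample).expand(B, C, -1)
        return points.gather(2, flat).reshape(B, C, npoint, nsample)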
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 |
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 | const float *known, float *dist2, int *idx);
6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7 | const float *points, const int *idx,
8 | const float *weight, float *out);
9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10 | const float *grad_out,
11 | const int *idx, const float *weight,
12 | float *grad_points);
13 |
14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15 | CHECK_CONTIGUOUS(unknowns);
16 | CHECK_CONTIGUOUS(knows);
17 | CHECK_IS_FLOAT(unknowns);
18 | CHECK_IS_FLOAT(knows);
19 |
20 | if (unknowns.is_cuda()) {
21 | CHECK_CUDA(knows);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
27 | at::Tensor dist2 =
28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
30 |
31 | if (unknowns.is_cuda()) {
32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33 | unknowns.data_ptr<float>(), knows.data_ptr<float>(),
34 | dist2.data_ptr<float>(), idx.data_ptr<int>());
35 | } else {
36 | AT_ASSERT(false, "CPU not supported");
37 | }
38 |
39 | return {dist2, idx};
40 | }
41 |
42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43 | at::Tensor weight) {
44 | CHECK_CONTIGUOUS(points);
45 | CHECK_CONTIGUOUS(idx);
46 | CHECK_CONTIGUOUS(weight);
47 | CHECK_IS_FLOAT(points);
48 | CHECK_IS_INT(idx);
49 | CHECK_IS_FLOAT(weight);
50 |
51 | if (points.is_cuda()) {
52 | CHECK_CUDA(idx);
53 | CHECK_CUDA(weight);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
58 | at::device(points.device()).dtype(at::ScalarType::Float));
59 |
60 | if (points.is_cuda()) {
61 | three_interpolate_kernel_wrapper(
62 | points.size(0), points.size(1), points.size(2), idx.size(1),
63 | points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
64 | output.data_ptr<float>());
65 | } else {
66 | AT_ASSERT(false, "CPU not supported");
67 | }
68 |
69 | return output;
70 | }
71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72 | at::Tensor weight, const int m) {
73 | CHECK_CONTIGUOUS(grad_out);
74 | CHECK_CONTIGUOUS(idx);
75 | CHECK_CONTIGUOUS(weight);
76 | CHECK_IS_FLOAT(grad_out);
77 | CHECK_IS_INT(idx);
78 | CHECK_IS_FLOAT(weight);
79 |
80 | if (grad_out.is_cuda()) {
81 | CHECK_CUDA(idx);
82 | CHECK_CUDA(weight);
83 | }
84 |
85 | at::Tensor output =
86 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
87 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
88 |
89 | if (grad_out.is_cuda()) {
90 | three_interpolate_grad_kernel_wrapper(
91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92 | grad_out.data_ptr<float>(), idx.data_ptr<int>(),
93 | weight.data_ptr<float>(), output.data_ptr<float>());
94 | } else {
95 | AT_ASSERT(false, "CPU not supported");
96 | }
97 |
98 | return output;
99 | }
100 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 | const float *__restrict__ unknown,
11 | const float *__restrict__ known,
12 | float *__restrict__ dist2,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | unknown += batch_index * n * 3;
16 | known += batch_index * m * 3;
17 | dist2 += batch_index * n * 3;
18 | idx += batch_index * n * 3;
19 |
20 | int index = threadIdx.x;
21 | int stride = blockDim.x;
22 | for (int j = index; j < n; j += stride) {
23 | float ux = unknown[j * 3 + 0];
24 | float uy = unknown[j * 3 + 1];
25 | float uz = unknown[j * 3 + 2];
26 |
27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 | int besti1 = 0, besti2 = 0, besti3 = 0;
29 | for (int k = 0; k < m; ++k) {
30 | float x = known[k * 3 + 0];
31 | float y = known[k * 3 + 1];
32 | float z = known[k * 3 + 2];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 | if (d < best1) {
35 | best3 = best2;
36 | besti3 = besti2;
37 | best2 = best1;
38 | besti2 = besti1;
39 | best1 = d;
40 | besti1 = k;
41 | } else if (d < best2) {
42 | best3 = best2;
43 | besti3 = besti2;
44 | best2 = d;
45 | besti2 = k;
46 | } else if (d < best3) {
47 | best3 = d;
48 | besti3 = k;
49 | }
50 | }
51 | dist2[j * 3 + 0] = best1;
52 | dist2[j * 3 + 1] = best2;
53 | dist2[j * 3 + 2] = best3;
54 |
55 | idx[j * 3 + 0] = besti1;
56 | idx[j * 3 + 1] = besti2;
57 | idx[j * 3 + 2] = besti3;
58 | }
59 | }
60 |
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 | const float *known, float *dist2, int *idx) {
63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 | three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 | dist2, idx);
66 |
67 | CUDA_CHECK_ERRORS();
68 | }
69 |
70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71 | // output: out(b, c, n)
72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73 | const float *__restrict__ points,
74 | const int *__restrict__ idx,
75 | const float *__restrict__ weight,
76 | float *__restrict__ out) {
77 | int batch_index = blockIdx.x;
78 | points += batch_index * m * c;
79 |
80 | idx += batch_index * n * 3;
81 | weight += batch_index * n * 3;
82 |
83 | out += batch_index * n * c;
84 |
85 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
86 | const int stride = blockDim.y * blockDim.x;
87 | for (int i = index; i < c * n; i += stride) {
88 | const int l = i / n;
89 | const int j = i % n;
90 | float w1 = weight[j * 3 + 0];
91 | float w2 = weight[j * 3 + 1];
92 | float w3 = weight[j * 3 + 2];
93 |
94 | int i1 = idx[j * 3 + 0];
95 | int i2 = idx[j * 3 + 1];
96 | int i3 = idx[j * 3 + 2];
97 |
98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99 | points[l * m + i3] * w3;
100 | }
101 | }
102 |
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104 | const float *points, const int *idx,
105 | const float *weight, float *out) {
106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107 | three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108 | b, c, m, n, points, idx, weight, out);
109 |
110 | CUDA_CHECK_ERRORS();
111 | }
112 |
113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114 | // output: grad_points(b, c, m)
115 |
116 | __global__ void three_interpolate_grad_kernel(
117 | int b, int c, int n, int m, const float *__restrict__ grad_out,
118 | const int *__restrict__ idx, const float *__restrict__ weight,
119 | float *__restrict__ grad_points) {
120 | int batch_index = blockIdx.x;
121 | grad_out += batch_index * n * c;
122 | idx += batch_index * n * 3;
123 | weight += batch_index * n * 3;
124 | grad_points += batch_index * m * c;
125 |
126 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 | const int stride = blockDim.y * blockDim.x;
128 | for (int i = index; i < c * n; i += stride) {
129 | const int l = i / n;
130 | const int j = i % n;
131 | float w1 = weight[j * 3 + 0];
132 | float w2 = weight[j * 3 + 1];
133 | float w3 = weight[j * 3 + 2];
134 |
135 | int i1 = idx[j * 3 + 0];
136 | int i2 = idx[j * 3 + 1];
137 | int i3 = idx[j * 3 + 2];
138 |
139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142 | }
143 | }
144 |
145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146 | const float *grad_out,
147 | const int *idx, const float *weight,
148 | float *grad_points) {
149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151 | b, c, n, m, grad_out, idx, weight, grad_points);
152 |
153 | CUDA_CHECK_ERRORS();
154 | }
155 |
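Taken together, three_nn finds the three nearest known points (returning squared distances) and three_interpolate blends their features with the given weights. A plain-PyTorch sketch of both, useful for testing against the kernels (shapes follow the comments above):

    import torch

    def three_nn_reference(unknown, known):
        # unknown: (B, n, 3), known: (B, m, 3) -> dist2: (B, n, 3), idx: (B, n, 3)
        d2 = torch.cdist(unknown, known) ** 2
        dist2, idx = torch.topk(d2, k=3, dim=2, largest=False)
        return dist2, idx.int()

    def three_interpolate_reference(points, idx, weight):
        # points: (B, c, m), idx: (B, n, 3), weight: (B, n, 3) -> (B, c, n)
        B, c, m = points.shape
        n = idx.shape[1]
        flat = idx.long().reshape(B, 1, n * 3).expand(B, c, -1)
        gathered = points.gather(2, flat).reshape(B, c, n, 3)
        return (gathered * weight.unsqueeze(1)).sum(dim=3)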
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 |
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 | const float *points, const int *idx,
6 | float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 | const float *grad_out, const int *idx,
9 | float *grad_points);
10 |
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 | const float *dataset, float *temp,
13 | int *idxs);
14 |
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 | CHECK_CONTIGUOUS(points);
17 | CHECK_CONTIGUOUS(idx);
18 | CHECK_IS_FLOAT(points);
19 | CHECK_IS_INT(idx);
20 |
21 | if (points.is_cuda()) {
22 | CHECK_CUDA(idx);
23 | }
24 |
25 | at::Tensor output =
26 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 | at::device(points.device()).dtype(at::ScalarType::Float));
28 |
29 | if (points.is_cuda()) {
30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 | idx.size(1), points.data_ptr<float>(),
32 | idx.data_ptr<int>(), output.data_ptr<float>());
33 | } else {
34 | AT_ASSERT(false, "CPU not supported");
35 | }
36 |
37 | return output;
38 | }
39 |
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 | const int n) {
42 | CHECK_CONTIGUOUS(grad_out);
43 | CHECK_CONTIGUOUS(idx);
44 | CHECK_IS_FLOAT(grad_out);
45 | CHECK_IS_INT(idx);
46 |
47 | if (grad_out.is_cuda()) {
48 | CHECK_CUDA(idx);
49 | }
50 |
51 | at::Tensor output =
52 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 |
55 | if (grad_out.is_cuda()) {
56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 | idx.size(1), grad_out.data_ptr<float>(),
58 | idx.data_ptr<int>(),
59 | output.data_ptr<float>());
60 | } else {
61 | AT_ASSERT(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67 | CHECK_CONTIGUOUS(points);
68 | CHECK_IS_FLOAT(points);
69 |
70 | at::Tensor output =
71 | torch::zeros({points.size(0), nsamples},
72 | at::device(points.device()).dtype(at::ScalarType::Int));
73 |
74 | at::Tensor tmp =
75 | torch::full({points.size(0), points.size(1)}, 1e10,
76 | at::device(points.device()).dtype(at::ScalarType::Float));
77 |
78 | if (points.is_cuda()) {
79 | furthest_point_sampling_kernel_wrapper(
80 | points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81 | tmp.data_ptr<float>(), output.data_ptr<int>());
82 | } else {
83 | AT_ASSERT(false, "CPU not supported");
84 | }
85 |
86 | return output;
87 | }
88 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 | const float *__restrict__ points,
10 | const int *__restrict__ idx,
11 | float *__restrict__ out) {
12 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 | int a = idx[i * m + j];
16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 | }
18 | }
19 | }
20 | }
21 |
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 | const float *points, const int *idx,
24 | float *out) {
25 | gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 | at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 | points, idx, out);
28 |
29 | CUDA_CHECK_ERRORS();
30 | }
31 |
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 | const float *__restrict__ grad_out,
36 | const int *__restrict__ idx,
37 | float *__restrict__ grad_points) {
38 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 | int a = idx[i * m + j];
42 | atomicAdd(grad_points + (i * c + l) * n + a,
43 | grad_out[(i * c + l) * m + j]);
44 | }
45 | }
46 | }
47 | }
48 |
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 | const float *grad_out, const int *idx,
51 | float *grad_points) {
52 | gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 | at::cuda::getCurrentCUDAStream()>>>(
54 | b, c, n, npoints, grad_out, idx, grad_points);
55 |
56 | CUDA_CHECK_ERRORS();
57 | }
58 |
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 | int idx1, int idx2) {
61 | const float v1 = dists[idx1], v2 = dists[idx2];
62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 | dists[idx1] = max(v1, v2);
64 | dists_i[idx1] = v2 > v1 ? i2 : i1;
65 | }
66 |
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 | int b, int n, int m, const float *__restrict__ dataset,
72 | float *__restrict__ temp, int *__restrict__ idxs) {
73 | if (m <= 0) return;
74 | __shared__ float dists[block_size];
75 | __shared__ int dists_i[block_size];
76 |
77 | int batch_index = blockIdx.x;
78 | dataset += batch_index * n * 3;
79 | temp += batch_index * n;
80 | idxs += batch_index * m;
81 |
82 | int tid = threadIdx.x;
83 | const int stride = block_size;
84 |
85 | int old = 0;
86 | if (threadIdx.x == 0) idxs[0] = old;
87 |
88 | __syncthreads();
89 | for (int j = 1; j < m; j++) {
90 | int besti = 0;
91 | float best = -1;
92 | float x1 = dataset[old * 3 + 0];
93 | float y1 = dataset[old * 3 + 1];
94 | float z1 = dataset[old * 3 + 2];
95 | for (int k = tid; k < n; k += stride) {
96 | float x2, y2, z2;
97 | x2 = dataset[k * 3 + 0];
98 | y2 = dataset[k * 3 + 1];
99 | z2 = dataset[k * 3 + 2];
100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 | if (mag <= 1e-3) continue;
102 |
103 | float d =
104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 |
106 | float d2 = min(d, temp[k]);
107 | temp[k] = d2;
108 | besti = d2 > best ? k : besti;
109 | best = d2 > best ? d2 : best;
110 | }
111 | dists[tid] = best;
112 | dists_i[tid] = besti;
113 | __syncthreads();
114 |
115 | if (block_size >= 512) {
116 | if (tid < 256) {
117 | __update(dists, dists_i, tid, tid + 256);
118 | }
119 | __syncthreads();
120 | }
121 | if (block_size >= 256) {
122 | if (tid < 128) {
123 | __update(dists, dists_i, tid, tid + 128);
124 | }
125 | __syncthreads();
126 | }
127 | if (block_size >= 128) {
128 | if (tid < 64) {
129 | __update(dists, dists_i, tid, tid + 64);
130 | }
131 | __syncthreads();
132 | }
133 | if (block_size >= 64) {
134 | if (tid < 32) {
135 | __update(dists, dists_i, tid, tid + 32);
136 | }
137 | __syncthreads();
138 | }
139 | if (block_size >= 32) {
140 | if (tid < 16) {
141 | __update(dists, dists_i, tid, tid + 16);
142 | }
143 | __syncthreads();
144 | }
145 | if (block_size >= 16) {
146 | if (tid < 8) {
147 | __update(dists, dists_i, tid, tid + 8);
148 | }
149 | __syncthreads();
150 | }
151 | if (block_size >= 8) {
152 | if (tid < 4) {
153 | __update(dists, dists_i, tid, tid + 4);
154 | }
155 | __syncthreads();
156 | }
157 | if (block_size >= 4) {
158 | if (tid < 2) {
159 | __update(dists, dists_i, tid, tid + 2);
160 | }
161 | __syncthreads();
162 | }
163 | if (block_size >= 2) {
164 | if (tid < 1) {
165 | __update(dists, dists_i, tid, tid + 1);
166 | }
167 | __syncthreads();
168 | }
169 |
170 | old = dists_i[0];
171 | if (tid == 0) idxs[j] = old;
172 | }
173 | }
174 |
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 | const float *dataset, float *temp,
177 | int *idxs) {
178 | unsigned int n_threads = opt_n_threads(n);
179 |
180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 |
182 | switch (n_threads) {
183 | case 512:
184 | furthest_point_sampling_kernel<512>
185 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 | break;
187 | case 256:
188 | furthest_point_sampling_kernel<256>
189 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 | break;
191 | case 128:
192 | furthest_point_sampling_kernel<128>
193 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 | break;
195 | case 64:
196 | furthest_point_sampling_kernel<64>
197 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 | break;
199 | case 32:
200 | furthest_point_sampling_kernel<32>
201 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 | break;
203 | case 16:
204 | furthest_point_sampling_kernel<16>
205 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 | break;
207 | case 8:
208 | furthest_point_sampling_kernel<8>
209 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 | break;
211 | case 4:
212 | furthest_point_sampling_kernel<4>
213 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 | break;
215 | case 2:
216 | furthest_point_sampling_kernel<2>
217 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 | break;
219 | case 1:
220 | furthest_point_sampling_kernel<1>
221 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 | break;
223 | default:
224 | furthest_point_sampling_kernel<512>
225 | <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 | }
227 |
228 | CUDA_CHECK_ERRORS();
229 | }
230 |
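The kernel implements greedy farthest point sampling: start from index 0, keep a per-point running minimum squared distance to the chosen set (temp), and repeatedly pick the argmax. A readable single-cloud reference (a sketch; it omits the kernel's mag <= 1e-3 guard, which skips near-origin padding points):

    import torch

    def furthest_point_sample_reference(xyz, m):
        # xyz: (n, 3) -> idxs: (m,) indices of the sampled points
        n = xyz.shape[0]
        idxs = torch.zeros(m, dtype=torch.int64)   # idxs[0] = 0, as in the kernel
        temp = torch.full((n,), 1e10)              # running min squared distance
        old = 0
        for j in range(1, m):
            d = ((xyz - xyz[old]) ** 2).sum(dim=1)
            temp = torch.minimum(temp, d)
            old = int(torch.argmax(temp))
            idxs[j] = old
        return idxs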
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "3.0.0"
2 |
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from pointnet2_ops import pointnet2_utils
7 |
8 |
9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
10 | layers = []
11 | for i in range(1, len(mlp_spec)):
12 | layers.append(
13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
14 | )
15 | if bn:
16 | layers.append(nn.BatchNorm2d(mlp_spec[i]))
17 | layers.append(nn.ReLU(True))
18 |
19 | return nn.Sequential(*layers)
20 |
21 |
22 | class _PointnetSAModuleBase(nn.Module):
23 | def __init__(self):
24 | super(_PointnetSAModuleBase, self).__init__()
25 | self.npoint = None
26 | self.groupers = None
27 | self.mlps = None
28 |
29 | def forward(
30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor]
31 | ) -> Tuple[torch.Tensor, torch.Tensor]:
32 | r"""
33 | Parameters
34 | ----------
35 | xyz : torch.Tensor
36 | (B, N, 3) tensor of the xyz coordinates of the features
37 | features : torch.Tensor
38 | (B, C, N) tensor of the descriptors of the features
39 |
40 | Returns
41 | -------
42 | new_xyz : torch.Tensor
43 | (B, npoint, 3) tensor of the new features' xyz
44 | new_features : torch.Tensor
45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
46 | """
47 |
48 | new_features_list = []
49 |
50 | xyz_flipped = xyz.transpose(1, 2).contiguous()
51 | new_xyz = (
52 | pointnet2_utils.gather_operation(
53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
54 | )
55 | .transpose(1, 2)
56 | .contiguous()
57 | if self.npoint is not None
58 | else None
59 | )
60 |
61 | for i in range(len(self.groupers)):
62 | new_features = self.groupers[i](
63 | xyz, new_xyz, features
64 | ) # (B, C, npoint, nsample)
65 |
66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
67 | new_features = F.max_pool2d(
68 | new_features, kernel_size=[1, new_features.size(3)]
69 | ) # (B, mlp[-1], npoint, 1)
70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
71 |
72 | new_features_list.append(new_features)
73 |
74 | return new_xyz, torch.cat(new_features_list, dim=1)
75 |
76 |
77 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
78 | r"""Pointnet set abstrction layer with multiscale grouping
79 |
80 | Parameters
81 | ----------
82 | npoint : int
83 | Number of features
84 | radii : list of float32
85 | list of radii to group with
86 | nsamples : list of int32
87 | Number of samples in each ball query
88 | mlps : list of list of int32
89 | Spec of the pointnet before the global max_pool for each scale
90 | bn : bool
91 | Use batchnorm
92 | """
93 |
94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
96 | super(PointnetSAModuleMSG, self).__init__()
97 |
98 | assert len(radii) == len(nsamples) == len(mlps)
99 |
100 | self.npoint = npoint
101 | self.groupers = nn.ModuleList()
102 | self.mlps = nn.ModuleList()
103 | for i in range(len(radii)):
104 | radius = radii[i]
105 | nsample = nsamples[i]
106 | self.groupers.append(
107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
108 | if npoint is not None
109 | else pointnet2_utils.GroupAll(use_xyz)
110 | )
111 | mlp_spec = mlps[i]
112 | if use_xyz:
113 | mlp_spec[0] += 3
114 |
115 | self.mlps.append(build_shared_mlp(mlp_spec, bn))
116 |
117 |
118 | class PointnetSAModule(PointnetSAModuleMSG):
119 | r"""Pointnet set abstrction layer
120 |
121 | Parameters
122 | ----------
123 | npoint : int
124 | Number of features
125 | radius : float
126 | Radius of ball
127 | nsample : int
128 | Number of samples in the ball query
129 | mlp : list
130 | Spec of the pointnet before the global max_pool
131 | bn : bool
132 | Use batchnorm
133 | """
134 |
135 | def __init__(
136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
137 | ):
138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
139 | super(PointnetSAModule, self).__init__(
140 | mlps=[mlp],
141 | npoint=npoint,
142 | radii=[radius],
143 | nsamples=[nsample],
144 | bn=bn,
145 | use_xyz=use_xyz,
146 | )
147 |
148 |
149 | class PointnetFPModule(nn.Module):
150 | r"""Propigates the features of one set to another
151 |
152 | Parameters
153 | ----------
154 | mlp : list
155 | Pointnet module parameters
156 | bn : bool
157 | Use batchnorm
158 | """
159 |
160 | def __init__(self, mlp, bn=True):
161 | # type: (PointnetFPModule, List[int], bool) -> None
162 | super(PointnetFPModule, self).__init__()
163 | self.mlp = build_shared_mlp(mlp, bn=bn)
164 |
165 | def forward(self, unknown, known, unknow_feats, known_feats):
166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
167 | r"""
168 | Parameters
169 | ----------
170 | unknown : torch.Tensor
171 | (B, n, 3) tensor of the xyz positions of the unknown features
172 | known : torch.Tensor
173 | (B, m, 3) tensor of the xyz positions of the known features
174 | unknow_feats : torch.Tensor
175 | (B, C1, n) tensor of the features to be propagated to
176 | known_feats : torch.Tensor
177 | (B, C2, m) tensor of features to be propagated
178 |
179 | Returns
180 | -------
181 | new_features : torch.Tensor
182 | (B, mlp[-1], n) tensor of the features of the unknown features
183 | """
184 |
185 | if known is not None:
186 | dist, idx = pointnet2_utils.three_nn(unknown, known)
187 | dist_recip = 1.0 / (dist + 1e-8)
188 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
189 | weight = dist_recip / norm
190 |
191 | interpolated_feats = pointnet2_utils.three_interpolate(
192 | known_feats, idx, weight
193 | )
194 | else:
195 | interpolated_feats = known_feats.expand(
196 | *(list(known_feats.size()[0:2]) + [unknown.size(1)])
197 | )
198 |
199 | if unknow_feats is not None:
200 | new_features = torch.cat(
201 | [interpolated_feats, unknow_feats], dim=1
202 | ) # (B, C2 + C1, n)
203 | else:
204 | new_features = interpolated_feats
205 |
206 | new_features = new_features.unsqueeze(-1)
207 | new_features = self.mlp(new_features)
208 |
209 | return new_features.squeeze(-1)
210 |
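A minimal usage sketch for the modules above (hypothetical shapes; a CUDA device is required since the underlying ops have no CPU path). Note that with use_xyz=True the module adds 3 to the first MLP width, so mlp=[0, 64, 128] means "no input features beyond xyz":

    import torch
    from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

    xyz = torch.randn(2, 1024, 3).cuda()                            # (B, N, 3)
    sa = PointnetSAModule(mlp=[0, 64, 128], npoint=256, radius=0.2, nsample=32).cuda()
    new_xyz, feats = sa(xyz, None)               # (2, 256, 3), (2, 128, 256)

    fp = PointnetFPModule(mlp=[128, 64]).cuda()
    dense_feats = fp(xyz, new_xyz, None, feats)  # (2, 64, 1024)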
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import warnings
4 | from torch.autograd import Function
5 | from typing import *
6 |
7 | try:
8 | import pointnet2_ops._ext as _ext
9 | except ImportError:
10 | from torch.utils.cpp_extension import load
11 | import glob
12 | import os.path as osp
13 | import os
14 |
15 | warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
16 |
17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
19 | osp.join(_ext_src_root, "src", "*.cu")
20 | )
21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
22 |
23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
24 | _ext = load(
25 | "_ext",
26 | sources=_ext_sources,
27 | extra_include_paths=[osp.join(_ext_src_root, "include")],
28 | extra_cflags=["-O3"],
29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
30 | with_cuda=True,
31 | )
32 |
33 |
34 | class FurthestPointSampling(Function):
35 | @staticmethod
36 | def forward(ctx, xyz, npoint):
37 | # type: (Any, torch.Tensor, int) -> torch.Tensor
38 | r"""
39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest
40 | minimum distance
41 |
42 | Parameters
43 | ----------
44 | xyz : torch.Tensor
45 | (B, N, 3) tensor where N > npoint
46 | npoint : int32
47 | number of features in the sampled set
48 |
49 | Returns
50 | -------
51 | torch.Tensor
52 | (B, npoint) tensor containing the set
53 | """
54 | out = _ext.furthest_point_sampling(xyz, npoint)
55 |
56 | ctx.mark_non_differentiable(out)
57 |
58 | return out
59 |
60 | @staticmethod
61 | def backward(ctx, grad_out):
62 | return ()
63 |
64 |
65 | furthest_point_sample = FurthestPointSampling.apply
66 |
67 |
68 | class GatherOperation(Function):
69 | @staticmethod
70 | def forward(ctx, features, idx):
71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
72 | r"""
73 |
74 | Parameters
75 | ----------
76 | features : torch.Tensor
77 | (B, C, N) tensor
78 |
79 | idx : torch.Tensor
80 | (B, npoint) tensor of the features to gather
81 |
82 | Returns
83 | -------
84 | torch.Tensor
85 | (B, C, npoint) tensor
86 | """
87 |
88 | ctx.save_for_backward(idx, features)
89 |
90 | return _ext.gather_points(features, idx)
91 |
92 | @staticmethod
93 | def backward(ctx, grad_out):
94 | idx, features = ctx.saved_tensors
95 | N = features.size(2)
96 |
97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
98 | return grad_features, None
99 |
100 |
101 | gather_operation = GatherOperation.apply
102 |
103 |
104 | class ThreeNN(Function):
105 | @staticmethod
106 | def forward(ctx, unknown, known):
107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
108 | r"""
109 | Find the three nearest neighbors of unknown in known
110 | Parameters
111 | ----------
112 | unknown : torch.Tensor
113 | (B, n, 3) tensor of unknown features
114 | known : torch.Tensor
115 | (B, m, 3) tensor of known features
116 |
117 | Returns
118 | -------
119 | dist : torch.Tensor
120 | (B, n, 3) l2 distance to the three nearest neighbors
121 | idx : torch.Tensor
122 | (B, n, 3) index of 3 nearest neighbors
123 | """
124 | dist2, idx = _ext.three_nn(unknown, known)
125 | dist = torch.sqrt(dist2)
126 |
127 | ctx.mark_non_differentiable(dist, idx)
128 |
129 | return dist, idx
130 |
131 | @staticmethod
132 | def backward(ctx, grad_dist, grad_idx):
133 | return ()
134 |
135 |
136 | three_nn = ThreeNN.apply
137 |
138 |
139 | class ThreeInterpolate(Function):
140 | @staticmethod
141 | def forward(ctx, features, idx, weight):
142 | # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
143 | r"""
144 | Performs weighted linear interpolation on 3 features
145 | Parameters
146 | ----------
147 | features : torch.Tensor
148 | (B, c, m) Features descriptors to be interpolated from
149 | idx : torch.Tensor
150 | (B, n, 3) three nearest neighbors of the target features in features
151 | weight : torch.Tensor
152 | (B, n, 3) weights
153 |
154 | Returns
155 | -------
156 | torch.Tensor
157 | (B, c, n) tensor of the interpolated features
158 | """
159 | ctx.save_for_backward(idx, weight, features)
160 |
161 | return _ext.three_interpolate(features, idx, weight)
162 |
163 | @staticmethod
164 | def backward(ctx, grad_out):
165 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
166 | r"""
167 | Parameters
168 | ----------
169 | grad_out : torch.Tensor
170 | (B, c, n) tensor with gradients of outputs
171 |
172 | Returns
173 | -------
174 | grad_features : torch.Tensor
175 | (B, c, m) tensor with gradients of features
176 |
177 | None
178 |
179 | None
180 | """
181 | idx, weight, features = ctx.saved_tensors
182 | m = features.size(2)
183 |
184 | grad_features = _ext.three_interpolate_grad(
185 | grad_out.contiguous(), idx, weight, m
186 | )
187 |
188 | return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
189 |
190 |
191 | three_interpolate = ThreeInterpolate.apply
192 |
193 |
194 | class GroupingOperation(Function):
195 | @staticmethod
196 | def forward(ctx, features, idx):
197 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
198 | r"""
199 |
200 | Parameters
201 | ----------
202 | features : torch.Tensor
203 | (B, C, N) tensor of features to group
204 | idx : torch.Tensor
205 | (B, npoint, nsample) tensor containing the indices of features to group with
206 |
207 | Returns
208 | -------
209 | torch.Tensor
210 | (B, C, npoint, nsample) tensor
211 | """
212 | ctx.save_for_backward(idx, features)
213 |
214 | return _ext.group_points(features, idx)
215 |
216 | @staticmethod
217 | def backward(ctx, grad_out):
218 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
219 | r"""
220 |
221 | Parameters
222 | ----------
223 | grad_out : torch.Tensor
224 | (B, C, npoint, nsample) tensor of the gradients of the output from forward
225 |
226 | Returns
227 | -------
228 | torch.Tensor
229 | (B, C, N) gradient of the features
230 | None
231 | """
232 | idx, features = ctx.saved_tensors
233 | N = features.size(2)
234 |
235 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
236 |
237 | return grad_features, torch.zeros_like(idx)
238 |
239 |
240 | grouping_operation = GroupingOperation.apply
241 |
242 |
243 | class BallQuery(Function):
244 | @staticmethod
245 | def forward(ctx, radius, nsample, xyz, new_xyz):
246 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
247 | r"""
248 |
249 | Parameters
250 | ----------
251 | radius : float
252 | radius of the balls
253 | nsample : int
254 | maximum number of features in the balls
255 | xyz : torch.Tensor
256 | (B, N, 3) xyz coordinates of the features
257 | new_xyz : torch.Tensor
258 | (B, npoint, 3) centers of the ball query
259 |
260 | Returns
261 | -------
262 | torch.Tensor
263 | (B, npoint, nsample) tensor with the indices of the features that form the query balls
264 | """
265 | output = _ext.ball_query(new_xyz, xyz, radius, nsample)
266 |
267 | ctx.mark_non_differentiable(output)
268 |
269 | return output
270 |
271 | @staticmethod
272 | def backward(ctx, grad_out):
273 | return ()
274 |
275 |
276 | ball_query = BallQuery.apply
277 |
278 |
279 | class QueryAndGroup(nn.Module):
280 | r"""
281 | Groups with a ball query of radius
282 |
283 | Parameters
284 | ---------
285 | radius : float32
286 | Radius of ball
287 | nsample : int32
288 | Maximum number of features to gather in the ball
289 | """
290 |
291 | def __init__(self, radius, nsample, use_xyz=True):
292 | # type: (QueryAndGroup, float, int, bool) -> None
293 | super(QueryAndGroup, self).__init__()
294 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
295 |
296 | def forward(self, xyz, new_xyz, features=None):
297 | # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
298 | r"""
299 | Parameters
300 | ----------
301 | xyz : torch.Tensor
302 | xyz coordinates of the features (B, N, 3)
303 | new_xyz : torch.Tensor
304 | centroids (B, npoint, 3)
305 | features : torch.Tensor
306 | Descriptors of the features (B, C, N)
307 |
308 | Returns
309 | -------
310 | new_features : torch.Tensor
311 | (B, 3 + C, npoint, nsample) tensor
312 | """
313 |
314 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
315 | xyz_trans = xyz.transpose(1, 2).contiguous()
316 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
317 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
318 |
319 | if features is not None:
320 | grouped_features = grouping_operation(features, idx)
321 | if self.use_xyz:
322 | new_features = torch.cat(
323 | [grouped_xyz, grouped_features], dim=1
324 | ) # (B, C + 3, npoint, nsample)
325 | else:
326 | new_features = grouped_features
327 | else:
328 | assert (
329 | self.use_xyz
330 | ), "Cannot have not features and not use xyz as a feature!"
331 | new_features = grouped_xyz
332 |
333 | return new_features
334 |
335 |
336 | class GroupAll(nn.Module):
337 | r"""
338 | Groups all features
339 |
340 | Parameters
341 | ---------
342 | """
343 |
344 | def __init__(self, use_xyz=True):
345 | # type: (GroupAll, bool) -> None
346 | super(GroupAll, self).__init__()
347 | self.use_xyz = use_xyz
348 |
349 | def forward(self, xyz, new_xyz, features=None):
350 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
351 | r"""
352 | Parameters
353 | ----------
354 | xyz : torch.Tensor
355 | xyz coordinates of the features (B, N, 3)
356 | new_xyz : torch.Tensor
357 | Ignored
358 | features : torch.Tensor
359 | Descriptors of the features (B, C, N)
360 |
361 | Returns
362 | -------
363 | new_features : torch.Tensor
364 | (B, C + 3, 1, N) tensor
365 | """
366 |
367 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
368 | if features is not None:
369 | grouped_features = features.unsqueeze(2)
370 | if self.use_xyz:
371 | new_features = torch.cat(
372 | [grouped_xyz, grouped_features], dim=1
373 | ) # (B, 3 + C, 1, N)
374 | else:
375 | new_features = grouped_features
376 | else:
377 | new_features = grouped_xyz
378 |
379 | return new_features
380 |
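The same ops can also be used functionally. A short sketch (CUDA tensors assumed; shapes follow the docstrings above):

    import torch
    from pointnet2_ops import pointnet2_utils

    xyz = torch.randn(2, 1024, 3).cuda()                    # (B, N, 3)
    idx = pointnet2_utils.furthest_point_sample(xyz, 128)   # (B, 128) int
    centers = (
        pointnet2_utils.gather_operation(xyz.transpose(1, 2).contiguous(), idx)
        .transpose(1, 2)
        .contiguous()
    )                                                       # (B, 128, 3)

    grouper = pointnet2_utils.QueryAndGroup(radius=0.2, nsample=32)
    neighborhoods = grouper(xyz, centers)                   # (B, 3, 128, 32)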
--------------------------------------------------------------------------------
/pointnet2_ops_lib/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path as osp
4 |
5 | from setuptools import find_packages, setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | this_dir = osp.dirname(osp.abspath(__file__))
9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
11 | osp.join(_ext_src_root, "src", "*.cu")
12 | )
13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
14 |
15 | requirements = ["torch>=1.4"]
16 |
17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read())
18 |
19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
20 | setup(
21 | name="pointnet2_ops",
22 | version=__version__,
23 | author="Erik Wijmans",
24 | packages=find_packages(),
25 | install_requires=requirements,
26 | ext_modules=[
27 | CUDAExtension(
28 | name="pointnet2_ops._ext",
29 | sources=_ext_sources,
30 | extra_compile_args={
31 | "cxx": ["-O3"],
32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
33 | },
34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
35 | )
36 | ],
37 | cmdclass={"build_ext": BuildExtension},
38 | include_package_data=True,
39 | )
40 |
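After building the extension (for example, pip install pointnet2_ops_lib/. from the repository root), a quick smoke test is simply importing the package; if the prebuilt extension is missing, pointnet2_utils falls back to JIT compilation as shown above:

    import pointnet2_ops
    print(pointnet2_ops.__version__)  # expected: "3.0.0"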
--------------------------------------------------------------------------------
/provider.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def normalize_data(batch_data):
4 | """ Normalize the batch data, use coordinates of the block centered at origin,
5 | Input:
6 | BxNxC array
7 | Output:
8 | BxNxC array
9 | """
10 | B, N, C = batch_data.shape
11 | normal_data = np.zeros((B, N, C))
12 | for b in range(B):
13 | pc = batch_data[b]
14 | centroid = np.mean(pc, axis=0)
15 | pc = pc - centroid
16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
17 | pc = pc / m
18 | normal_data[b] = pc
19 | return normal_data
20 |
21 |
22 | def shuffle_data(data, labels):
23 | """ Shuffle data and labels.
24 | Input:
25 | data: B,N,... numpy array
26 | label: B,... numpy array
27 | Return:
28 | shuffled data, label and shuffle indices
29 | """
30 | idx = np.arange(len(labels))
31 | np.random.shuffle(idx)
32 | return data[idx, ...], labels[idx], idx
33 |
34 |
35 | def shuffle_points(batch_data):
36 | """ Shuffle orders of points in each point cloud -- changes FPS behavior.
37 | Use the same shuffling idx for the entire batch.
38 | Input:
39 | BxNxC array
40 | Output:
41 | BxNxC array
42 | """
43 | idx = np.arange(batch_data.shape[1])
44 | np.random.shuffle(idx)
45 | return batch_data[:, idx, :]
46 |
47 |
48 | def rotate_point_cloud(batch_data):
49 | """ Randomly rotate the point clouds to augument the dataset
50 | rotation is per shape based along up direction
51 | Input:
52 | BxNx3 array, original batch of point clouds
53 | Return:
54 | BxNx3 array, rotated batch of point clouds
55 | """
56 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
57 | for k in range(batch_data.shape[0]):
58 | rotation_angle = np.random.uniform() * 2 * np.pi
59 | cosval = np.cos(rotation_angle)
60 | sinval = np.sin(rotation_angle)
61 | rotation_matrix = np.array([[cosval, 0, sinval],
62 | [0, 1, 0],
63 | [-sinval, 0, cosval]])
64 | shape_pc = batch_data[k, ...]
65 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
66 | return rotated_data
67 |
68 |
69 | def rotate_point_cloud_z(batch_data):
70 | """ Randomly rotate the point clouds to augument the dataset
71 | rotation is per shape based along up direction
72 | Input:
73 | BxNx3 array, original batch of point clouds
74 | Return:
75 | BxNx3 array, rotated batch of point clouds
76 | """
77 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
78 | for k in range(batch_data.shape[0]):
79 | rotation_angle = np.random.uniform() * 2 * np.pi
80 | cosval = np.cos(rotation_angle)
81 | sinval = np.sin(rotation_angle)
82 | rotation_matrix = np.array([[cosval, sinval, 0],
83 | [-sinval, cosval, 0],
84 | [0, 0, 1]])
85 | shape_pc = batch_data[k, ...]
86 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
87 | return rotated_data
88 |
89 |
90 | def rotate_point_cloud_with_normal(batch_xyz_normal):
91 | ''' Randomly rotate XYZ, normal point cloud.
92 | Input:
93 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are normals
94 | Output:
95 | B,N,6, rotated XYZ, normal point cloud
96 | '''
97 | for k in range(batch_xyz_normal.shape[0]):
98 | rotation_angle = np.random.uniform() * 2 * np.pi
99 | cosval = np.cos(rotation_angle)
100 | sinval = np.sin(rotation_angle)
101 | rotation_matrix = np.array([[cosval, 0, sinval],
102 | [0, 1, 0],
103 | [-sinval, 0, cosval]])
104 | shape_pc = batch_xyz_normal[k, :, 0:3]
105 | shape_normal = batch_xyz_normal[k, :, 3:6]
106 | batch_xyz_normal[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
107 | batch_xyz_normal[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
108 | return batch_xyz_normal
109 |
110 |
111 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
112 | """ Randomly perturb the point clouds by small rotations
113 | Input:
114 | BxNx6 array, original batch of point clouds and point normals
115 | Return:
116 | BxNx6 array, rotated batch of point clouds and point normals
117 | """
118 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
119 | for k in range(batch_data.shape[0]):
120 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
121 | Rx = np.array([[1, 0, 0],
122 | [0, np.cos(angles[0]), -np.sin(angles[0])],
123 | [0, np.sin(angles[0]), np.cos(angles[0])]])
124 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
125 | [0, 1, 0],
126 | [-np.sin(angles[1]), 0, np.cos(angles[1])]])
127 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
128 | [np.sin(angles[2]), np.cos(angles[2]), 0],
129 | [0, 0, 1]])
130 | R = np.dot(Rz, np.dot(Ry, Rx))
131 | shape_pc = batch_data[k, :, 0:3]
132 | shape_normal = batch_data[k, :, 3:6]
133 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
134 | rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
135 | return rotated_data
136 |
137 |
138 | def rotate_point_cloud_by_angle(batch_data, rotation_angle):
139 | """ Rotate the point cloud along up direction with certain angle.
140 | Input:
141 | BxNx3 array, original batch of point clouds
142 | Return:
143 | BxNx3 array, rotated batch of point clouds
144 | """
145 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
146 | for k in range(batch_data.shape[0]):
147 | # rotation_angle = np.random.uniform() * 2 * np.pi
148 | cosval = np.cos(rotation_angle)
149 | sinval = np.sin(rotation_angle)
150 | rotation_matrix = np.array([[cosval, 0, sinval],
151 | [0, 1, 0],
152 | [-sinval, 0, cosval]])
153 | shape_pc = batch_data[k, :, 0:3]
154 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
155 | return rotated_data
156 |
157 |
158 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
159 | """ Rotate the point cloud along up direction with certain angle.
160 | Input:
161 | BxNx6 array, original batch of point clouds with normal
162 | scalar, angle of rotation
163 | Return:
164 | BxNx6 array, rotated batch of point clouds with normal
165 | """
166 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
167 | for k in range(batch_data.shape[0]):
168 | # rotation_angle = np.random.uniform() * 2 * np.pi
169 | cosval = np.cos(rotation_angle)
170 | sinval = np.sin(rotation_angle)
171 | rotation_matrix = np.array([[cosval, 0, sinval],
172 | [0, 1, 0],
173 | [-sinval, 0, cosval]])
174 | shape_pc = batch_data[k, :, 0:3]
175 | shape_normal = batch_data[k, :, 3:6]
176 | rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
177 | rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
178 | return rotated_data
179 |
180 |
181 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
182 | """ Randomly perturb the point clouds by small rotations
183 | Input:
184 | BxNx3 array, original batch of point clouds
185 | Return:
186 | BxNx3 array, rotated batch of point clouds
187 | """
188 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
189 | for k in range(batch_data.shape[0]):
190 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
191 | Rx = np.array([[1, 0, 0],
192 | [0, np.cos(angles[0]), -np.sin(angles[0])],
193 | [0, np.sin(angles[0]), np.cos(angles[0])]])
194 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
195 | [0, 1, 0],
196 | [-np.sin(angles[1]), 0, np.cos(angles[1])]])
197 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
198 | [np.sin(angles[2]), np.cos(angles[2]), 0],
199 | [0, 0, 1]])
200 | R = np.dot(Rz, np.dot(Ry, Rx))
201 | shape_pc = batch_data[k, ...]
202 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
203 | return rotated_data
204 |
205 |
206 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
207 | """ Randomly jitter points. jittering is per point.
208 | Input:
209 | BxNx3 array, original batch of point clouds
210 | Return:
211 | BxNx3 array, jittered batch of point clouds
212 | """
213 | B, N, C = batch_data.shape
214 | assert (clip > 0)
215 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
216 | jittered_data += batch_data
217 | return jittered_data
218 |
219 |
220 | def shift_point_cloud(batch_data, shift_range=0.1):
221 | """ Randomly shift point cloud. Shift is per point cloud.
222 | Input:
223 | BxNx3 array, original batch of point clouds
224 | Return:
225 | BxNx3 array, shifted batch of point clouds
226 | """
227 | B, N, C = batch_data.shape
228 | shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
229 | for batch_index in range(B):
230 | batch_data[batch_index, :, :] += shifts[batch_index, :]
231 | return batch_data
232 |
233 |
234 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
235 | """ Randomly scale the point cloud. Scale is per point cloud.
236 | Input:
237 | BxNx3 array, original batch of point clouds
238 | Return:
239 | BxNx3 array, scaled batch of point clouds
240 | """
241 | B, N, C = batch_data.shape
242 | scales = np.random.uniform(scale_low, scale_high, B)
243 | for batch_index in range(B):
244 | batch_data[batch_index, :, :] *= scales[batch_index]
245 | return batch_data
246 |
247 |
248 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
249 | ''' batch_pc: BxNx3 '''
250 | for b in range(batch_pc.shape[0]):
251 | dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875
252 | drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0]
253 | if len(drop_idx) > 0:
254 | batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point
255 | return batch_pc
256 |
--------------------------------------------------------------------------------
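
A minimal usage sketch for the augmentation helpers above. train.py chains only random_scale_point_cloud and shift_point_cloud; the fuller pipeline below is illustrative, with an assumed 8x1024x3 batch:

    import numpy as np
    import provider

    # Illustrative batch: 8 clouds of 1024 xyz points.
    batch = np.random.randn(8, 1024, 3).astype(np.float32)

    # Each helper takes and returns a BxNx3 array, so they compose by chaining.
    batch = provider.random_scale_point_cloud(batch, scale_low=0.8, scale_high=1.25)
    batch = provider.shift_point_cloud(batch, shift_range=0.1)
    batch = provider.rotate_perturbation_point_cloud(batch, angle_sigma=0.06, angle_clip=0.18)
    batch = provider.jitter_point_cloud(batch, sigma=0.01, clip=0.05)
    batch = provider.random_point_dropout(batch, max_dropout_ratio=0.875)
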
/test_correspondence.py:
--------------------------------------------------------------------------------
2 | import importlib
3 | import sys
4 | import torch
5 | import os
6 | import argparse
7 | import utils.check_points_utils as checkpoint_util
8 | import numpy as np
9 | from tqdm import tqdm
10 | from data_utils.keypointnet_dataloader import KeyPointNetDataLoader
11 | from models.torch_pointnet_utils import knn_point
12 |
13 |
14 | class PointcloudJitter(object):
15 | def __init__(self, std=0.01, clip=0.05):
16 | self.std, self.clip = std, clip
17 |
18 | def __call__(self, points):
19 | jittered_data = (
20 | points.new(points.size(0), 3)
21 | .normal_(mean=0.0, std=self.std)
22 | .clamp_(-self.clip, self.clip)
23 | )
24 | points[:, 0:3] += jittered_data
25 | return points
26 |
27 |
28 | def main(args):
29 | experiment_dir = 'log/' + args.log_dir
30 | sys.path.append(experiment_dir)
31 | model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
32 |
33 | epoch, iters, checkpoint = checkpoint_util.load_checkpoint(model_3d=None, filename=str(
34 | experiment_dir) + '/checkpoints/' + args.model)
35 | category = checkpoint['category']
36 | num_structure_points = checkpoint['num_structure_points']
37 | multi_distribution = checkpoint['multi_distribution']
38 | offset = checkpoint['offset']
39 |
40 | model = importlib.import_module(model_name)
41 | model = model.Pointnet2StructurePointNet(num_structure_points=num_structure_points, input_channels=0,
42 | multi_distribution_num=multi_distribution,
43 | offset=offset)
44 | model.load_state_dict(checkpoint['model_state_3d'])
45 |
46 | model.cuda()
47 | model.eval()
48 |
49 |     if not os.path.exists(args.output_dir):
50 | os.makedirs(args.output_dir)
51 |
52 | test_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, category + '.json'),
53 | pcd_path=args.pcd_path, split='val')
54 |
55 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False,
56 | num_workers=2, pin_memory=True,
57 | persistent_workers=True)
58 |
59 |     thresholds = np.linspace(0., 1)  # unused: the thresholds actually used are rebuilt in compute_correspondence_accuracy()
60 |
61 | model_datas = []
62 | point_cloud_jitter = PointcloudJitter(std=args.gauss_rate, clip=0.1)
63 |
64 | for batch_id, (batch_points, key_points, key_points_num, _) in tqdm(enumerate(testDataLoader, 0),
65 | total=len(testDataLoader),
66 | smoothing=0.9):
67 |
68 | with torch.no_grad():
69 | batch_points_jitter = torch.Tensor(batch_points)
70 |             if args.gauss_rate != 0:
71 | for pc in range(batch_points.shape[0]):
72 | batch_points_jitter[pc] = point_cloud_jitter(batch_points[pc])
73 |
74 | structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points_jitter.cuda())
75 | for i in range(0, batch_points.shape[0]):
76 | diameter_shape = torch.sqrt(torch.sum((torch.max(batch_points[i], dim=0)[0] - torch.min(batch_points[i], dim=0)[0]) ** 2))
77 |
78 | model_datas.append({'structure_pts': structure_points[i] + (torch.rand(structure_points[i].shape) * args.noise_rate).cuda(),
79 | 'gt_feat_pts': key_points[i][:key_points_num[i], :].cuda(), 'diameter_shape': diameter_shape})
80 |
81 | dis_ratios, dis_thresholds = compute_correspondence_accuracy(model_datas)
82 |
83 | if not os.path.exists('corrs_model/'):
84 | os.mkdir('corrs_model/')
85 | np.savez('corrs_model/' + args.corres_name, dis_ratios, dis_thresholds)
86 |
87 |
88 | def compute_correspondence_dis(model_data_a, model_data_b):
89 | structure_pts_a = model_data_a['structure_pts']
90 | gt_feat_pts_a = model_data_a['gt_feat_pts']
91 | structure_pts_b = model_data_b['structure_pts']
92 | gt_feat_pts_b = model_data_b['gt_feat_pts']
93 | diameter_shape_b = model_data_b['diameter_shape']
94 | res_dis = []
95 |
96 |     if gt_feat_pts_a.shape[0] != gt_feat_pts_b.shape[0]:
97 | return res_dis
98 |
99 | # knn_a_idxs, knn_a_dis = query_KNN_tensor(structure_pts_a, gt_feat_pts_a, 1)
100 | knn_a_idxs = knn_point(1, structure_pts_a[None, :, :], gt_feat_pts_a[None, :, :])
101 |
102 | corres_pts_in_b = structure_pts_b[knn_a_idxs[0, :, 0], :]
103 | diff = corres_pts_in_b - gt_feat_pts_b
104 | tmp_dis = torch.sqrt(torch.sum(diff * diff, dim=1)) / diameter_shape_b
105 |
106 | for i in range(tmp_dis.shape[0]):
107 | # nan means this feature point is missing on groundtruth model
108 |         if not torch.isnan(gt_feat_pts_a[i, 0]) and not torch.isnan(gt_feat_pts_b[i, 0]):
109 | res_dis.append(tmp_dis[i].item())
110 |
111 | return res_dis
112 |
113 |
114 | def compute_correspondence_accuracy(model_datas):
115 | dis_list = []
116 | for i in tqdm(range(len(model_datas)), total=len(model_datas)):
117 | for j in range(len(model_datas)):
118 | if i == j:
119 | continue
120 | model_data_i = model_datas[i]
121 | model_data_j = model_datas[j]
122 | corres_dis = compute_correspondence_dis(model_data_i, model_data_j)
123 | dis_list = dis_list + corres_dis
124 |
125 | dis_array = np.array(dis_list)
126 |
127 | dis_thresholds = np.arange(0, 0.26, 0.01)
128 | dis_ratios = []
129 |
130 | for i in range(dis_thresholds.shape[0]):
131 | threshold = dis_thresholds[i]
132 | ratio = dis_array[dis_array <= threshold].shape[0] / dis_array.shape[0]
133 | dis_ratios.append(ratio)
134 |
135 | dis_ratios = np.array(dis_ratios)
136 |
137 | return dis_ratios, dis_thresholds
138 |
139 |
140 | if __name__ == '__main__':
141 | parser = argparse.ArgumentParser(description="Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
142 | parser.add_argument("-num_inputs", type=int, default=1024, help="sample points from initial point cloud")
143 | parser.add_argument("-log_dir", type=str, default='4.8.2', help="path to the trained model log")
144 | parser.add_argument("-model", type=str, default='model_min_test_loss',
145 | help="the trained model[default: model_min_test_loss]")
146 | parser.add_argument("-output_dir", type=str, default='out', help="output dir")
147 | parser.add_argument('-prediction_output', type=str, default='merger_prediction.npz',
148 | help='Output file where prediction results are written.')
149 | parser.add_argument('-pcd_path', type=str, default='./keypointnet/pcds',
150 | help='Point cloud file folder path from KeypointNet dataset.')
151 | parser.add_argument('-json_path', default='./keypointnet/annotations/', help='')
152 | parser.add_argument('-gauss_rate', type=float, default=0, help='')
153 | parser.add_argument('-noise_rate', type=float, default=0, help='')
154 | parser.add_argument('-corres_name', type=str, help='')
155 |
156 | args = parser.parse_args()
157 | main(args)
158 |
--------------------------------------------------------------------------------
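
Because the accuracy curve is saved with positional np.savez arguments, it reads back under the default arr_0/arr_1 keys. A minimal sketch for inspecting a saved run ('example.npz' stands in for whatever -corres_name was passed):

    import numpy as np

    data = np.load('corrs_model/example.npz')
    dis_ratios, dis_thresholds = data['arr_0'], data['arr_1']

    # Fraction of correspondences falling within each normalized distance threshold.
    for t, r in zip(dis_thresholds, dis_ratios):
        print('threshold %.2f -> accuracy %.3f' % (t, r))
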
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import shutil
3 | import torch
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | import os
7 | import argparse
8 | import gc
9 | import importlib
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | ROOT_DIR = BASE_DIR
13 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
14 |
15 | import utils.check_points_utils as checkpoint_util
16 | from tqdm import tqdm
17 | import logging
18 | import datetime
19 | from pathlib import Path
20 | import provider
21 | from data_utils.dataset import Dataset
22 | from data_utils.keypointnet_dataloader import KeyPointNetDataLoader
23 |
24 | torch.backends.cudnn.enabled = True
25 | torch.backends.cudnn.benchmark = True
26 | scaler = torch.cuda.amp.GradScaler()
27 | autocast = torch.cuda.amp.autocast
28 |
29 |
30 | def train_one_epoch(model, optimizer, data_loader, current_iter, criterions, lr_scheduler, num_of_trans,
31 | num_inputs, logger):
32 | model.train()
33 | loss_dict = {}
34 | loss_dict['Loss'] = 0
35 | for l in criterions.keys():
36 | loss_dict[l] = 0
37 | count = 0
38 |
39 | for batch_id, (batch_points, _, _, _) in tqdm(enumerate(data_loader, 0), total=len(data_loader),
40 | smoothing=0.9):
41 | optimizer.zero_grad()
42 | # print(batch_points.shape)
43 | batch_points = batch_points.data.numpy()
44 | batch_points[:, :, 0:3] = provider.random_scale_point_cloud(batch_points[:, :, 0:3])
45 | batch_points[:, :, 0:3] = provider.shift_point_cloud(batch_points[:, :, 0:3])
46 | batch_points = torch.Tensor(batch_points)
47 | batch_points = batch_points.cuda()
48 |
49 | if args.use_half:
50 | with autocast():
51 | structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
52 |
53 | ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
54 | WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
55 | # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
56 |                 loss = ComputeLoss3dLoss + WeightedChamferLoss
57 |
58 | scaler.scale(loss).backward()
59 | scaler.step(optimizer)
60 | scaler.update()
61 | else:
62 | structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
63 |
64 | ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
65 | WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
66 | # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
67 |             loss = ComputeLoss3dLoss + WeightedChamferLoss
68 |
69 | loss.backward()
70 | optimizer.step()
71 |
72 | current_iter += 1
73 | loss_dict['Loss'] += loss.item()
74 | loss_dict['ComputeLoss3d'] += ComputeLoss3dLoss.item()
75 | loss_dict['WeightedChamferLoss'] += WeightedChamferLoss.item()
76 | # loss_dict['VecLoss'] += loss_Vec.item()
77 |
78 |         # duplicate 'current_iter += 1' removed; the counter is already incremented above
79 | # gc.collect()
80 | count += 1
81 |
82 | lr_scheduler.step()
83 | for k in loss_dict.keys():
84 | loss_dict[k] /= count
85 |
86 | return loss_dict, current_iter
87 |
88 |
89 | def test(model, data_loader, criterions):
90 | model.eval()
91 | count = 0
92 | loss_dict = {}
93 | loss_dict['Loss'] = 0
94 | for l in criterions.keys():
95 | loss_dict[l] = 0
96 | for batch_id, (batch_points, _, _, _) in tqdm(enumerate(data_loader, 0), total=len(data_loader),
97 | smoothing=0.9):
98 | batch_points = batch_points.cuda()
99 | structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
100 |
101 | ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
102 | WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
103 | # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
104 | loss = WeightedChamferLoss
105 |
106 | loss_dict['Loss'] += loss.item()
107 | loss_dict['ComputeLoss3d'] += ComputeLoss3dLoss.item()
108 | loss_dict['WeightedChamferLoss'] += WeightedChamferLoss.item()
109 | # loss_dict['VecLoss'] += loss_Vec.item()
110 |
111 | count += 1
112 |
113 | for k in loss_dict.keys():
114 | loss_dict[k] /= count
115 | return loss_dict
116 |
117 |
118 | def create_logger(args):
119 | '''CREATE DIR'''
120 | timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
121 | exp_dir = Path('./log/')
122 | exp_dir.mkdir(exist_ok=True)
123 | # exp_dir = exp_dir.joinpath('')
124 | # exp_dir.mkdir(exist_ok=True)
125 | if args.log_dir is None:
126 | exp_dir = exp_dir.joinpath(timestr)
127 | else:
128 | exp_dir = exp_dir.joinpath(args.log_dir)
129 | exp_dir.mkdir(exist_ok=True)
130 | checkpoints_dir = exp_dir.joinpath('checkpoints/')
131 | checkpoints_dir.mkdir(exist_ok=True)
132 | log_dir = exp_dir.joinpath('logs/')
133 | log_dir.mkdir(exist_ok=True)
134 |
135 | '''LOG'''
136 | logger = logging.getLogger("Model")
137 | logger.setLevel(logging.INFO)
138 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
139 | file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model), mode='w+')
140 | file_handler.setLevel(logging.INFO)
141 | file_handler.setFormatter(formatter)
142 | logger.addHandler(file_handler)
143 |
144 | return logger, checkpoints_dir, exp_dir
145 |
146 |
147 | def log_string(msg, logger):
148 |     logger.info(msg)
149 |     print(msg)
150 |
151 |
152 | def train(args):
153 | # lr_clip = 1e-5
154 | # bnm_clip = 1e-2
155 | torch.cuda.empty_cache()
156 |     logger, checkpoints_dir, exp_dir = create_logger(args)
157 | log_string('PARAMETER ...', logger=logger)
158 | log_string(args, logger=logger)
159 |
160 | '''DATA LOADING'''
161 | log_string('Load dataset ...', logger=logger)
162 |
163 | # train_dataset = bhcp_dataloader(args.data_path, args.category, is_pts_aligned=False, split='train')
164 | # test_dataset = bhcp_dataloader(args.data_path, args.category, is_pts_aligned=False, split='test')
165 | # train_dataset = KeyPointNetDataLoader(json_path=cmd_args.json_path, pcd_path=cmd_args.pcd_path, split='train')
166 | # test_dataset = KeyPointNetDataLoader(json_path=cmd_args.json_path, pcd_path=cmd_args.pcd_path, split='val')
167 |
168 | if args.dataset_name == 'keypointnet':
169 | train_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, args.category + '.json'),
170 | pcd_path=args.pcd_path, split='train')
171 | test_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, args.category + '.json'),
172 | pcd_path=args.pcd_path, split='val')
173 | else:
174 | train_dataset = Dataset(root=args.data_path, dataset_name=args.dataset_name, class_choice=args.category,
175 | num_points=args.num_inputs, split='train',
176 | segmentation=args.segmentation)
177 | test_dataset = Dataset(root=args.data_path, dataset_name=args.dataset_name, class_choice=args.category,
178 | num_points=args.num_inputs, split='test',
179 | segmentation=args.segmentation)
180 |
181 | trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
182 | num_workers=args.num_workers, drop_last=True, pin_memory=True,
183 | persistent_workers=True)
184 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
185 | num_workers=args.num_workers, pin_memory=True, persistent_workers=True)
186 |
187 | shutil.copy('./models/%s.py' % args.model, str(exp_dir))
188 | shutil.copy('./models/chamfer_distance.py', str(exp_dir))
189 | shutil.copy('./models/torch_pointnet_utils.py', str(exp_dir))
190 | shutil.copy('./train.py', str(exp_dir))
191 |
192 | '''MODEL LOADING'''
193 | model = importlib.import_module(args.model)
194 | criterions = {'ComputeLoss3d': model.ComputeLoss3d(), 'WeightedChamferLoss': model.WeightedChamferLoss(),
195 | 'VecLoss': model.VecLoss()}
196 | # criterions = {'ComputeLoss3d': model.ComputeLoss3d(), 'VecLoss': model.VecLoss()}
197 |
198 | model = model.Pointnet2StructurePointNet(num_structure_points=args.num_structure_points, input_channels=0,
199 | multi_distribution_num=args.multi_distribution,
200 | offset=args.offset)
201 |
202 | model.cuda()
203 |
204 | optimizer = optim.Adam(
205 | model.parameters(),
206 | lr=args.lr,
207 | betas=(0.9, 0.999),
208 | eps=1e-08,
209 | weight_decay=args.weight_decay
210 | )
211 |
212 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_batch, gamma=args.lr_decay)
213 |
214 | iters = -1
215 | min_test_loss = float('inf')
216 |
217 | # load status from checkpoint
218 | start_epoch = 0
219 | if args.checkpoint is not None:
220 |         start_epoch, iters, _ = checkpoint_util.load_checkpoint(model_3d=model, optimizer=optimizer,
221 |                                                                  filename=args.checkpoint)
222 | start_epoch += 1
223 |
224 | log_string('Start Training Unsupervised Structure Points for %s...' % args.dataset_name, logger=logger)
225 | iters = max(iters, 0)
226 |
227 | for epoch_i in range(start_epoch, args.max_epochs):
228 | log_string('-------------------------------------------', logger=logger)
229 |         log_string('Epoch %d/%s, Learning Rate %f:' % (epoch_i + 1, args.max_epochs, lr_scheduler.get_last_lr()[0]),
230 | logger=logger)
231 |
232 | loss_dict, iters = train_one_epoch(model,
233 | optimizer,
234 | trainDataLoader,
235 | iters,
236 | criterions,
237 | lr_scheduler,
238 | num_of_trans=args.num_of_transform,
239 | num_inputs=args.num_inputs,
240 | logger=logger)
241 | loss_str = ''
242 | for i in loss_dict.keys():
243 | loss_str += '%s: %f \t' % (i, loss_dict[i])
244 | log_string(loss_str, logger)
245 |
246 | with torch.no_grad():
247 | loss_dict = test(model,
248 | data_loader=testDataLoader,
249 | criterions=criterions,
250 | )
251 | loss_str = ''
252 | for i in loss_dict.keys():
253 | loss_str += '%s: %f \t' % (i, loss_dict[i])
254 | log_string(loss_str, logger)
255 |
256 | if loss_dict['Loss'] < min_test_loss:
257 | min_test_loss = loss_dict['Loss']
258 | log_string('Min Test Loss: %f' % (loss_dict['Loss']), logger=logger)
259 |
260 | log_string('Save model...', logger=logger)
261 | fname = os.path.join(checkpoints_dir, 'model_min_test_loss')
262 | checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
263 | epoch=epoch_i, category=args.category,
264 | num_structure_points=args.num_structure_points,
265 | multi_distribution=args.multi_distribution,
266 | offset=args.offset)
267 | else:
268 | log_string('Min Test Loss: %f' % (min_test_loss), logger=logger)
269 |
270 | if (epoch_i + 1) % 50 == 0:
271 | fname = os.path.join(checkpoints_dir, 'model_%d' % (epoch_i + 1))
272 | checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
273 | epoch=epoch_i, category=args.category,
274 | num_structure_points=args.num_structure_points,
275 | multi_distribution=args.multi_distribution,
276 | offset=args.offset)
277 | fname = os.path.join(checkpoints_dir, 'model')
278 | checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
279 | epoch=epoch_i, category=args.category,
280 | num_structure_points=args.num_structure_points,
281 | multi_distribution=args.multi_distribution,
282 | offset=args.offset)
283 |
284 |
285 | def parse_args():
286 | parser = argparse.ArgumentParser(description="Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
287 | parser.add_argument("-batch_size", type=int, default=36, help="Batch size")
288 | parser.add_argument("-weight_decay", type=float, default=1e-5, help="L2 regularization coeff")
289 | parser.add_argument("-num_inputs", type=int, default=1024, help="sample points from initial point cloud")
290 | parser.add_argument("-num_structure_points", type=int, default=12, help="Number of structure points")
291 | parser.add_argument("-category", type=str, default='laptop', help="Category of the objects to train")
292 | parser.add_argument("-dataset_name", type=str, default='shapenetpart', help="keypointnet,shapenetpart")
293 | parser.add_argument("-data_path", type=str, default='../', help="")
294 | parser.add_argument('-segmentation', action='store_true', default=False, help='')
295 | parser.add_argument('-offset', action='store_true', default=False, help='')
296 | parser.add_argument("-max_epochs", type=int, default=100, help="Number of epochs to train for")
297 | parser.add_argument("-log_dir", type=str, default=None, help="Root of the log")
298 | parser.add_argument("-multi_distribution", type=int, default=3, help="Multivariate normal distribution nums")
299 | parser.add_argument('-num_workers', type=int, default=4, help='dataload num worker')
300 |     parser.add_argument('-model', default='model_weightchamfer', help='model name [default: model_weightchamfer, alternative: Structure_pointnet]')
301 |     parser.add_argument('-use_half', action='store_true', default=True, help='use mixed-precision (half) training; note: with default=True this store_true flag is always on')
302 | parser.add_argument('-json_path', default='./keypointnet/annotations/', help='')
303 | parser.add_argument('-pcd_path', type=str, default='./keypointnet/pcds',
304 | help='Point cloud file folder path from KeypointNet dataset.')
305 | parser.add_argument("-lr", type=float, default=1e-3, help="Initial learning rate")
306 | parser.add_argument("-lr_decay", type=float, default=0.7, help="Learning rate decay gamma")
307 |     parser.add_argument("-decay_batch", type=int, default=20, help="Epochs between learning rate decay steps")
308 | parser.add_argument("-bn_momentum", type=float, default=0.5, help="Initial batch norm momentum")
309 | parser.add_argument("-bnm_decay", type=float, default=0.5, help="Batch norm momentum decay gamma")
310 | parser.add_argument("-checkpoint_save_step", type=int, default=50, help="Step for saving Checkpoint")
311 | parser.add_argument("-checkpoint", type=str, default=None, help="Checkpoint to start from")
312 | parser.add_argument("-num_of_transform", type=int, default=0,
313 | help="Number of transforms for rotation data augmentation. Useful when testing on shapes without alignment")
314 | args = parser.parse_args()
315 | return args
316 |
317 |
318 | if __name__ == "__main__":
319 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
320 | args = parse_args()
321 | import platform
322 |
323 |     system_name = platform.system()
324 |     if system_name == "Windows":
325 | args.batch_size = 2
326 |
327 | train(args=args)
328 |
--------------------------------------------------------------------------------
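
With the defaults above (-lr 1e-3, -decay_batch 20, -lr_decay 0.7) and lr_scheduler.step() called once per epoch, the learning rate decays to lr * 0.7 ** (epoch // 20). A small sketch that reproduces the schedule in isolation, using a throwaway parameter just to build the optimizer:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    for epoch in range(100):
        optimizer.step()      # training steps would happen here
        scheduler.step()      # mirrors the once-per-epoch call in train_one_epoch()
        if (epoch + 1) % 20 == 0:
            print('epoch %d: lr = %g' % (epoch + 1, scheduler.get_last_lr()[0]))
    # Prints 0.0007, 0.00049, 0.000343, ...
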
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 | division,
3 | absolute_import,
4 | with_statement,
5 | print_function,
6 | unicode_literals,
7 | )
8 | import sys
9 | sys.path.append("..")
--------------------------------------------------------------------------------
/utils/check_points_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 |
5 | def save_checkpoint(filename, model_3d=None, model_2d=None, optimizer=None, iters=None, epoch=None, meta_data=None, category=None, num_structure_points=None, multi_distribution=None, offset=None):
6 | print("save checkpoint '{}'".format(filename))
7 | optim_state = optimizer.state_dict() if optimizer is not None else None
8 | if model_3d is not None:
9 | if isinstance(model_3d, torch.nn.DataParallel):
10 | model_state_3d = model_3d.module.state_dict()
11 | else:
12 | model_state_3d = model_3d.state_dict()
13 | else:
14 | model_state_3d = None
15 |
16 | if model_2d is not None:
17 | if isinstance(model_2d, torch.nn.DataParallel):
18 | model_state_2d = model_2d.module.state_dict()
19 | else:
20 | model_state_2d = model_2d.state_dict()
21 | else:
22 | model_state_2d = None
23 |
24 | state = {
25 | 'iter': iters,
26 | 'epoch': epoch,
27 | 'model_state_2d': model_state_2d,
28 | 'model_state_3d': model_state_3d,
29 | 'optimizer_state': optim_state,
30 | 'meta_data': meta_data,
31 | 'category': category,
32 | 'num_structure_points': num_structure_points,
33 | 'multi_distribution': multi_distribution,
34 |         'offset': offset
35 | }
36 | torch.save(state, filename)
37 |
38 |
39 | def load_checkpoint(filename, model_3d=None, optimizer=None, meta_data=None):
40 | if os.path.isfile(filename):
41 | checkpoint = torch.load(filename)
42 |         iters = checkpoint.get('iter', 0)
43 | epoch = checkpoint['epoch']
44 |
45 | if model_3d is not None and checkpoint['model_state_3d'] is not None:
46 | model_3d.load_state_dict(checkpoint['model_state_3d'])
47 | if optimizer is not None and checkpoint['optimizer_state'] is not None:
48 | optimizer.load_state_dict(checkpoint['optimizer_state'])
49 |
50 | if meta_data is not None and 'meta_data' in checkpoint:
51 | for key in checkpoint['meta_data']:
52 | meta_data[key] = checkpoint['meta_data'][key]
53 | return epoch, iters, checkpoint
54 | else:
55 | print("==> Checkpoint '{}' not found".format(filename))
56 | return None
57 |
--------------------------------------------------------------------------------
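
A minimal round-trip sketch for the two helpers above, using a throwaway nn.Linear in place of the real network; 'demo_checkpoint' is a placeholder file name:

    import torch
    import utils.check_points_utils as checkpoint_util

    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    checkpoint_util.save_checkpoint('demo_checkpoint', model_3d=model, optimizer=optimizer,
                                    iters=100, epoch=5, category='chair',
                                    num_structure_points=12, multi_distribution=3, offset=False)

    # load_checkpoint restores the weights in place and returns (epoch, iters, checkpoint).
    epoch, iters, checkpoint = checkpoint_util.load_checkpoint('demo_checkpoint',
                                                               model_3d=model, optimizer=optimizer)
    print(epoch, iters, checkpoint['category'])
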
/utils/logutils.py:
--------------------------------------------------------------------------------
1 | class LogUtils():
2 | def __init__(self, fname, filemode):
3 | self.logf = open(fname, filemode)
4 |
5 | def write(self, text, need_display=True):
6 |         if need_display:
7 | print(text)
8 |
9 | self.logf.write(text + '\n')
10 | self.logf.flush()
11 |
12 | def close(self):
13 | self.logf.close()
14 |
15 | def write_args(self, cmd_args):
16 | self.logf.write('cmd arguments:\n')
17 | for k in cmd_args.__dict__:
18 | val = cmd_args.__dict__[k]
19 | self.logf.write('{0}: {1}\n'.format(k, val))
20 |
21 | def write_correspondence_accuracy(fname, dis_ratio, dis_threshold):
22 | with open(fname, 'w') as f:
23 | f.write('Correspondence Accuracy:\n')
24 | f.write('distance ratio: ')
25 | for i in range(dis_ratio.shape[0]):
26 | f.write(' {0}'.format(dis_ratio[i]))
27 | f.write('\n')
28 |
29 | f.write('distance threshold: ')
30 | for i in range(dis_threshold.shape[0]):
31 | f.write(' {0}'.format(dis_threshold[i]))
32 | f.write('\n')
33 |         # no explicit close needed; the 'with' block closes the file automatically
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
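
A short usage sketch, assuming write_correspondence_accuracy is module-level (it takes no self); all file names are placeholders:

    import numpy as np
    from utils.logutils import LogUtils, write_correspondence_accuracy

    log = LogUtils('run_log.txt', 'w')
    log.write('starting run')          # prints and appends a newline to the log file
    log.close()

    # Two equal-length 1-D arrays become the two labelled rows of the report.
    ratios = np.array([0.10, 0.45, 0.80])
    thresholds = np.array([0.00, 0.05, 0.10])
    write_correspondence_accuracy('corr_acc.txt', ratios, thresholds)
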
/utils/mesh_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def read_obj_file(fname):
3 | vertices = []
4 | faces = []
5 | try:
6 | f = open(fname)
7 |
8 | for line in f:
9 | if line[:2] == "v ":
10 | strs = line.split()
11 | v0 = float(strs[1])
12 | v1 = float(strs[2])
13 | v2 = float(strs[3])
14 | vertex = [v0, v1, v2]
15 | vertices.append(vertex)
16 |
17 | elif line[0] == "f":
18 | strs = line.split()
19 | f0 = int(strs[1].split('/')[0])-1
20 | f1 = int(strs[2].split('/')[0])-1
21 | f2 = int(strs[3].split('/')[0])-1
22 | face = [f0, f1, f2]
23 |
24 | faces.append(face)
25 |
26 | f.close()
27 | except IOError:
28 | print(".obj file not found.")
29 |
30 | vertices = np.array(vertices)
31 | faces = np.array(faces)
32 |
33 | return vertices, faces
34 |
35 |
36 | def rotate_obj_file(fname, rot_mat):
37 |
38 | lines = []
39 | with open(fname) as fin:
40 |
41 | for line in fin:
42 | lines.append(line)
43 |
44 | for i in range(len(lines)):
45 | line = lines[i]
46 | if line[:2] == "v ":
47 | strs = line.split(' ')
48 | vertex = np.array([float(strs[1]), float(strs[2]), float(strs[3])])
49 | vertex = np.matmul(rot_mat, vertex)
50 | line = "v {0} {1} {2}\n".format(vertex[0], vertex[1], vertex[2])
51 | lines[i] = line
52 | elif line[:3] == 'vn ':
53 | strs = line.split(' ')
54 | vn = np.array([float(strs[1]), float(strs[2]), float(strs[3])])
55 | vn = np.matmul(rot_mat, vn)
56 | line = "vn {0} {1} {2}\n".format(vn[0], vn[1], vn[2])
57 | lines[i] = line
58 |
59 | with open(fname, 'w') as f:
60 | for line in lines:
61 | f.write(line)
62 |
63 |
64 | def write_off_file(fname, vertices, faces):
65 | with open(fname, 'w') as f:
66 | vnum = len(vertices)
67 | fnum = len(faces)
68 | f.write('COFF\n')
69 | f.write('{0} {1} {2}\n'.format(vnum, fnum, 0))
70 | for i in range(0, vnum):
71 | f.write('{0} {1} {2}\n'.format(vertices[i][0], vertices[i][1], vertices[i][2]))
72 |
73 | fnum = len(faces)
74 | for i in range(0, fnum):
75 | f.write('3 {0} {1} {2}\n'.format(faces[i][0], faces[i][1], faces[i][2]))
76 |
77 | def read_off_file(fname):
78 | vertices = []
79 | faces = []
80 | try:
81 | f = open(fname)
82 | head = f.readline()
83 | strline = f.readline()
84 | strs = strline.split(' ')
85 | vnum = int(strs[0])
86 | fnum = int(strs[1])
87 | for i in range(0, vnum):
88 | strline = f.readline()
89 | strs = strline.split(' ')
90 | v0 = float(strs[0])
91 | v1 = float(strs[1])
92 | v2 = float(strs[2])
93 | vertex = [v0, v1, v2]
94 | vertices.append(vertex)
95 |
96 | for i in range(0, fnum):
97 | strline = f.readline()
98 | strs = strline.split(' ')
99 | f0 = int(strs[1])
100 | f1 = int(strs[2])
101 | f2 = int(strs[3])
102 | face = [f0, f1, f2]
103 | faces.append(face)
104 |
105 | f.close()
106 | except IOError:
107 | print(".off file not found.")
108 |
109 | vertices = np.array(vertices)
110 | faces = np.array(faces)
111 | return vertices, faces
112 |
113 |
114 | def write_obj_file(fname, vertices, faces):
115 | with open(fname, 'w') as f:
116 | vnum = len(vertices)
117 | for i in range(0, vnum):
118 | f.write('v {0} {1} {2}\n'.format(vertices[i][0], vertices[i][1], vertices[i][2]))
119 |
120 | fnum = len(faces)
121 | for i in range(0, fnum):
122 | f.write('f {0} {1} {2}\n'.format(faces[i][0]+1, faces[i][1]+1, faces[i][2]+1))
123 |
--------------------------------------------------------------------------------
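
A minimal conversion sketch using the readers and writers above; the paths are placeholders. Note that read_obj_file returns 0-based face indices, write_off_file keeps them 0-based, and write_obj_file shifts them back to 1-based:

    from utils.mesh_utils import read_obj_file, write_off_file, write_obj_file

    vertices, faces = read_obj_file('input.obj')
    print(vertices.shape, faces.shape)    # (V, 3) and (F, 3)

    write_off_file('mesh.off', vertices, faces)
    write_obj_file('roundtrip.obj', vertices, faces)
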
/utils/point_cloud_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 | # from pointnet2 import _ext
5 |
6 | def write_points_off(fname, points, colors=None):
7 |
8 | with open(fname, 'w') as f:
9 |
10 | num = points.shape[0]
11 | f.write('COFF\n')
12 | f.write('{0} 0 0\n'.format(num))
13 | for i in range(0, num):
14 | if colors is not None:
15 | f.write('{0} {1} {2} {3} {4} {5}\n'.format(points[i, 0], points[i, 1], points[i, 2], int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2])))
16 | else:
17 | f.write('{0} {1} {2}\n'.format(points[i, 0], points[i, 1], points[i, 2]))
18 |
19 |
20 | def write_points_obj(fname, points, colors=None):
21 |
22 | with open(fname, 'w') as f:
23 |
24 | num = points.shape[0]
25 | for i in range(0, num):
26 | if colors is not None:
27 | f.write('v {0} {1} {2} {3} {4} {5}\n'.format(points[i, 0], points[i, 1], points[i, 2], int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2])))
28 | else:
29 | f.write('v {0} {1} {2}\n'.format(points[i, 0], points[i, 1], points[i, 2]))
30 |
31 |
32 | def compute_pca(points):
33 | mean, eigvec = cv2.PCACompute(points, mean=None)
34 |     if np.dot(np.cross(eigvec[0], eigvec[1]), eigvec[2]) < 0:
35 | eigvec[2] = -eigvec[2]
36 |
37 | eigvec[0] = eigvec[0] / np.linalg.norm(eigvec[0])
38 | eigvec[1] = eigvec[1] / np.linalg.norm(eigvec[1])
39 | eigvec[2] = eigvec[2] / np.linalg.norm(eigvec[2])
40 |
41 | return eigvec
42 |
43 | def query_KNN(points, query_pts, k, return_dis=True):
44 | '''
45 |
46 | :param points: n x 3
47 | :param query_pts: m x 3
48 | :param k: num of neighbors
49 | :return: m x k ids, sorted_dis
50 | '''
51 |
52 | diff = query_pts[:, None, :] - points[None, :, :]
53 |     dis = np.sqrt(np.sum(diff * diff, axis=2))  # m x n
54 | sorted_idx = np.argsort(dis, axis=1)
55 | sorted_idx = sorted_idx[:, :k]
56 |
57 | if return_dis:
58 | sorted_dis = dis[None, 0, sorted_idx[0, :]]
59 | for i in range(1, query_pts.shape[0]):
60 | sorted_dis = np.concatenate((sorted_dis, dis[None, i, sorted_idx[i, :]]), axis=0)
61 |
62 | return sorted_idx, sorted_dis
63 | else:
64 | return sorted_idx
65 |
66 |
67 | def query_KNN_tensor(points, query_pts, k):
68 | '''
69 |
70 | :param points: n x 3
71 | :param query_pts: m x 3
72 | :param k: num of neighbors
73 | :return: m x k ids, sorted_dis
74 | '''
75 |
76 | diff = query_pts[:, None, :] - points[None, :, :]
77 |     dis = torch.sqrt(torch.sum(diff * diff, dim=2))  # m x n
78 | sorted_idx = torch.argsort(dis, dim=1)
79 | sorted_idx = sorted_idx[:, :k]
80 |
81 | sorted_dis = dis[None, 0, sorted_idx[0, :]]
82 | for i in range(1, query_pts.shape[0]):
83 | sorted_dis = torch.cat((sorted_dis, dis[None, i, sorted_idx[i, :]]), dim=0)
84 |
85 | return sorted_idx, sorted_dis
86 |
87 |
88 |
89 | def read_pointcloud_obj(fname):
90 | vertices = []
91 | try:
92 | f = open(fname)
93 |
94 | for line in f:
95 | if line[:2] == "v ":
96 | strs = line.split(' ')
97 | v0 = float(strs[1])
98 | v1 = float(strs[2])
99 | v2 = float(strs[3])
100 | vertex = [v0, v1, v2]
101 | vertices.append(vertex)
102 |
103 | f.close()
104 | except IOError:
105 | print(".obj file not found.")
106 |
107 | vertices = np.array(vertices)
108 |
109 |
110 | return vertices
111 |
112 |
113 | def read_points_off(fname, read_color=False):
114 | vertices = []
115 | colors = []
116 |
117 | try:
118 | f = open(fname)
119 | head = f.readline()
120 | strline = f.readline()
121 | strs = strline.split(' ')
122 | vnum = int(strs[0])
123 | fnum = int(strs[1])
124 | for i in range(0, vnum):
125 | strline = f.readline()
126 | strs = strline.split(' ')
127 | v0 = float(strs[0])
128 | v1 = float(strs[1])
129 | v2 = float(strs[2])
130 | vertex = [v0, v1, v2]
131 | vertices.append(vertex)
132 |
133 | if len(strs) > 3:
134 | c0 = float(strs[3])
135 | c1 = float(strs[4])
136 | c2 = float(strs[5])
137 | color = [c0, c1, c2]
138 | colors.append(color)
139 |
143 | f.close()
144 | except IOError:
145 | print(".off file not found.")
146 |
147 | pts = np.array(vertices).astype(np.float32)
148 |
149 |     if len(colors) > 0 and read_color:
150 | colors = np.array(colors).astype(np.float32)
151 | return pts, colors
152 | else:
153 | return pts
154 |
155 | def trans_pointcloud(rot_mat, trans_mat, points):
156 | '''
157 |
158 | :param rot_mat: 3 x 3
159 | :param trans_mat: 3
160 | :param points: n x 3
161 | :return: n x 3
162 | '''
163 | tmp_points = np.matmul(rot_mat, np.transpose(points, (1, 0)))
164 | tmp_points = tmp_points + trans_mat[:, None]
165 | tmp_points = np.transpose(tmp_points, (1, 0))
166 | return tmp_points
167 |
--------------------------------------------------------------------------------
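
A small sketch pinning down the shape conventions of query_KNN and trans_pointcloud on random data:

    import numpy as np
    from utils.point_cloud_utils import query_KNN, trans_pointcloud

    points = np.random.rand(100, 3).astype(np.float32)   # n x 3
    queries = np.random.rand(5, 3).astype(np.float32)    # m x 3

    # idx: m x k indices into points; dis: the matching sorted distances.
    idx, dis = query_KNN(points, queries, k=4)
    print(idx.shape, dis.shape)                          # (5, 4) (5, 4)

    # Rigid transform: rotate 90 degrees about z, then translate along x.
    rot = np.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]], dtype=np.float32)
    trans = np.array([1., 0., 0.], dtype=np.float32)
    moved = trans_pointcloud(rot, trans, points)         # still n x 3
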