├── LICENSE ├── README.md ├── data_utils ├── ModelNetDataLoader.py ├── bhcp_dataloader.py ├── data_flower.py ├── data_utils.py ├── dataset.py ├── keypointnet_dataloader.py └── shapenet_seg_dataloader.py ├── images ├── Pipeline.png └── VisualizationResults.png ├── models ├── __init__.py ├── chamfer_distance.py ├── model_weightchamfer.py ├── pointnetpp │ ├── pointnet.py │ ├── pointnet2_cls_msg.py │ ├── pointnet2_cls_ssg.py │ ├── pointnet2_part_seg_msg.py │ ├── pointnet2_part_seg_ssg.py │ ├── pointnet2_sem_seg.py │ ├── pointnet2_sem_seg_msg.py │ ├── pointnet_cls.py │ ├── pointnet_part_seg.py │ ├── pointnet_sem_seg.py │ └── pointnet_util.py └── torch_pointnet_utils.py ├── pointnet2_ops_lib ├── MANIFEST.in ├── pointnet2_ops.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── requires.txt │ └── top_level.txt ├── pointnet2_ops │ ├── __init__.py │ ├── _ext-src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── pointnet2_modules.py │ └── pointnet2_utils.py └── setup.py ├── provider.py ├── test_correspondence.py ├── train.py └── utils ├── __init__.py ├── check_points_utils.py ├── logutils.py ├── mesh_utils.py └── point_cloud_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds 2 | 3 | ## Description 4 | This repository contains the code for our paper: **Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds**. 
5 | > [**Unsupervised Distribution-aware Keypoints Generation from 3D Point Clouds**](https://doi.org/10.1016/j.neunet.2024.106158), 6 | > Yiqi Wu, Xingye Chen, Xuan Huang, Kelin Song, Dejun Zhang 7 | > [BibTeX](https://github.com/Chenguoz/Keypoints#citation) 8 | 9 | <div align=center>
10 | <img src="images/Pipeline.png">
11 | </div>
12 | 13 | <div align=center>
14 | <img src="images/VisualizationResults.png">
15 | </div>
16 | 17 | 18 | ## Environment setup 19 | ``` 20 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 21 | pip install pointnet2_ops_lib/. 22 | ``` 23 | ## Dataset 24 | The training and testing data for correspondence is provided by [KeypointNet](https://github.com/qq456cvb/KeypointNet) and [ShapeNet](https://github.com/antao97/PointCloudDatasets) 25 | 26 | 27 | ## Citation 28 | 29 | ```bibtex 30 | @article{wu2024unsupervised, 31 | title={Unsupervised distribution-aware keypoints generation from 3D point clouds}, 32 | author={Wu, Yiqi and Chen, Xingye and Huang, Xuan and Song, Kelin and Zhang, Dejun}, 33 | journal={Neural Networks}, 34 | pages={106158}, 35 | year={2024}, 36 | publisher={Elsevier} 37 | } 38 | ``` 39 | 40 | ## Acknowledgment 41 | 42 | Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful works. 43 | 44 | [3DStructurePoints](https://github.com/NolenChen/3DStructurePoints), 45 | [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) 46 | -------------------------------------------------------------------------------- /data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: Xu Yan 3 | @file: ModelNet.py 4 | @time: 2021/3/19 15:51 5 | ''' 6 | import argparse 7 | import os 8 | import numpy as np 9 | import warnings 10 | import pickle 11 | 12 | from tqdm import tqdm 13 | from torch.utils.data import Dataset 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | 18 | def pc_normalize(pc): 19 | centroid = np.mean(pc, axis=0) 20 | pc = pc - centroid 21 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 22 | pc = pc / m 23 | return pc 24 | 25 | 26 | def farthest_point_sample(point, npoint): 27 | """ 28 | Input: 29 | xyz: pointcloud data, [N, D] 30 | npoint: number of samples 31 | Return: 32 | centroids: sampled pointcloud index, [npoint, D] 33 | """ 34 | N, D = point.shape 35 | xyz = point[:, :3] 36 | centroids = np.zeros((npoint,)) 37 | distance = np.ones((N,)) * 1e10 38 | farthest = np.random.randint(0, N) 39 | for i in range(npoint): 40 | centroids[i] = farthest 41 | centroid = xyz[farthest, :] 42 | dist = np.sum((xyz - centroid) ** 2, -1) 43 | mask = dist < distance 44 | distance[mask] = dist[mask] 45 | farthest = np.argmax(distance, -1) 46 | point = point[centroids.astype(np.int32)] 47 | return point 48 | 49 | 50 | class ModelNetDataLoader(Dataset): 51 | def __init__(self, root, args, category='chair', split='train', process_data=False): 52 | self.root = root 53 | self.npoints = args.num_inputs 54 | self.process_data = process_data 55 | self.uniform = args.use_uniform_sample 56 | self.use_normals = args.use_normals 57 | self.num_category = args.num_category 58 | self.category = category 59 | if self.num_category == 10: 60 | self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt') 61 | else: 62 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') 63 | 64 | self.cat = [line.rstrip() for line in open(self.catfile)] 65 | self.classes = dict(zip(self.cat, range(len(self.cat)))) 66 | 67 | shape_ids = {} 68 | if self.num_category == 10: 69 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))] 70 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))] 71 | else: 72 | shape_ids['train'] = [line.rstrip() for line in 
open(os.path.join(self.root, 'modelnet40_train.txt'))] 73 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 74 | 75 | assert (split == 'train' or split == 'test') 76 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 77 | 78 | if not category == 'all': 79 | self.datapath = [ 80 | (shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 81 | in range(len(shape_ids[split])) if shape_names[i] == category] 82 | else: 83 | self.datapath = [ 84 | (shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 85 | in range(len(shape_ids[split]))] 86 | print('The size of %s data is %d' % (split, len(self.datapath))) 87 | 88 | if self.uniform: 89 | self.save_path = os.path.join(root, 90 | 'modelnet%d_%s_%d_%spts_fps.dat' % (self.num_category, split, self.npoints, self.category)) 91 | else: 92 | self.save_path = os.path.join(root, 'modelnet%d_%s_%d_%spts.dat' % (self.num_category, split, self.npoints, self.category)) 93 | 94 | if self.process_data: 95 | if not os.path.exists(self.save_path): 96 | print('Processing data %s (only running in the first time)...' % self.save_path) 97 | self.list_of_points = [None] * len(self.datapath) 98 | self.list_of_labels = [None] * len(self.datapath) 99 | 100 | for index in tqdm(range(len(self.datapath)), total=len(self.datapath)): 101 | fn = self.datapath[index] 102 | cls = self.classes[self.datapath[index][0]] 103 | cls = np.array([cls]).astype(np.int32) 104 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 105 | 106 | if self.uniform: 107 | point_set = farthest_point_sample(point_set, self.npoints) 108 | else: 109 | point_set = point_set[0:self.npoints, :] 110 | 111 | self.list_of_points[index] = point_set 112 | self.list_of_labels[index] = cls 113 | 114 | with open(self.save_path, 'wb') as f: 115 | pickle.dump([self.list_of_points, self.list_of_labels], f) 116 | else: 117 | print('Load processed data from %s...' 
% self.save_path) 118 | with open(self.save_path, 'rb') as f: 119 | self.list_of_points, self.list_of_labels = pickle.load(f) 120 | 121 | def __len__(self): 122 | return len(self.datapath) 123 | 124 | def _get_item(self, index): 125 | if self.process_data: 126 | point_set, label = self.list_of_points[index], self.list_of_labels[index] 127 | else: 128 | fn = self.datapath[index] 129 | cls = self.classes[self.datapath[index][0]] 130 | label = np.array([cls]).astype(np.int32) 131 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 132 | 133 | if self.uniform: 134 | point_set = farthest_point_sample(point_set, self.npoints) 135 | else: 136 | point_set = point_set[0:self.npoints, :] 137 | 138 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 139 | if not self.use_normals: 140 | point_set = point_set[:, 0:3] 141 | 142 | return point_set, label[0] 143 | 144 | def __getitem__(self, index): 145 | return self._get_item(index) 146 | 147 | 148 | def parse_args(): 149 | parser = argparse.ArgumentParser( 150 | description="Arguments", 151 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 152 | ) 153 | parser.add_argument("-batch_size", type=int, default=20, help="Batch size") 154 | parser.add_argument( 155 | "-weight_decay", type=float, default=1e-5, help="L2 regularization coeff" 156 | ) 157 | parser.add_argument("-lr", type=float, default=1e-2, help="Initial learning rate") 158 | parser.add_argument( 159 | "-lr_decay", type=float, default=0.7, help="Learning rate decay gamma" 160 | ) 161 | parser.add_argument( 162 | "-decay_batch", type=float, default=20, help="Learning rate decay batch" 163 | ) 164 | parser.add_argument( 165 | "-bn_momentum", type=float, default=0.5, help="Initial batch norm momentum" 166 | ) 167 | parser.add_argument( 168 | "-bnm_decay", type=float, default=0.5, help="Batch norm momentum decay gamma" 169 | ) 170 | 171 | parser.add_argument( 172 | "-checkpoint_save_step", type=int, default=50, help="Step for saving Checkpoint" 173 | ) 174 | 175 | parser.add_argument( 176 | "-checkpoint", type=str, default=None 177 | , help="Checkpoint to start from" 178 | ) 179 | parser.add_argument( 180 | "-num_of_transform", type=int, default=0, 181 | help="Number of transforms for rotation data augmentation. 
Useful when testing on shapes without alignment" 182 | ) 183 | 184 | parser.add_argument( 185 | "-num_inputs", type=int, default=1024, help="sample points from initial point cloud" 186 | ) 187 | 188 | parser.add_argument( 189 | "-num_structure_points", type=int, default=16 190 | , help="Number of structure points" 191 | ) 192 | parser.add_argument( 193 | "-category", type=str, default='chair', help="Category of the objects to train" 194 | ) 195 | parser.add_argument( 196 | "-data_dir", type=str, default="training_data/", help="Root of the training data" 197 | ) 198 | parser.add_argument( 199 | "-test_data_dir", type=str, default="demo_data/", help="Root of the test data" 200 | ) 201 | parser.add_argument('--use_normals', action='store_true', default=False, help='use normals') 202 | parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40') 203 | 204 | parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling') 205 | 206 | parser.add_argument( 207 | "-max_epochs", type=int, default=200, help="Number of epochs to train for" 208 | ) 209 | parser.add_argument( 210 | "-log_dir", type=str, default=None, help="Root of the log" 211 | ) 212 | parser.add_argument( 213 | "-multi_distribution", type=int, default=5, help="Multivariate normal distribution nums" 214 | ) 215 | parser.add_argument('--process_data', action='store_true', default=False, help='save data offline') 216 | parser.add_argument('-model', default='PointSPN', help='model name [default: PointSPN]') 217 | args = parser.parse_args() 218 | return args 219 | 220 | 221 | if __name__ == '__main__': 222 | import torch 223 | 224 | args = parse_args() 225 | data = ModelNetDataLoader('../modelnet40/', args, split='train') 226 | 227 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True, num_workers=10) 228 | for point, label in DataLoader: 229 | print(point.shape) 230 | print(label.shape) 231 | -------------------------------------------------------------------------------- /data_utils/bhcp_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import ntpath 4 | import pickle 5 | from .ModelNetDataLoader import pc_normalize 6 | 7 | 8 | class bhcp_dataloader(data.Dataset): 9 | def __init__(self, data_root, category, is_pts_aligned=False, preload_data=True, split='train'): 10 | super().__init__() 11 | if split == 'train': 12 | # self.data_path = os.path.join(data_root , category) 13 | 14 | self.data_path = os.path.join(data_root + '/training_data', category) 15 | else: 16 | self.data_path = os.path.join(data_root + '/test_data', category) 17 | 18 | self.sub_dirs = [ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()] 19 | self.data_num = len(self.sub_dirs) 20 | self.is_pts_aligned = is_pts_aligned 21 | 22 | self.meta_data_list = None 23 | if preload_data: 24 | self.meta_data_list = [] 25 | for i in range(len(self.sub_dirs)): 26 | meta_fname = os.path.join(self.data_path, self.sub_dirs[i], 'meta.pkl') 27 | with open(meta_fname, 'rb') as f: 28 | meta_data = pickle.load(f) 29 | self.meta_data_list.append(meta_data) 30 | 31 | def __getitem__(self, idx): 32 | if self.meta_data_list is None: 33 | meta_fname = os.path.join(self.data_path, self.sub_dirs[idx], 'meta.pkl') 34 | f = open(meta_fname, 'rb') 35 | meta_data = pickle.load(f) 36 | else: 37 | meta_data = self.meta_data_list[idx] 38 | 39 | if 
self.is_pts_aligned: 40 | if 'points_aligned' in meta_data: 41 | points = meta_data['points_aligned'] 42 | else: 43 | points = meta_data['points'] 44 | else: 45 | points = meta_data['points'] 46 | points[:, 0:3] = pc_normalize(points[:, 0:3]) 47 | 48 | res = {} 49 | res['points'] = points 50 | if 'feat_pts' in meta_data: # the labeled feature points on the bhcp dataset for computing correspondence accuracy 51 | if self.is_pts_aligned: 52 | res['feat_pts'] = meta_data['feat_pts_aligned'] 53 | else: 54 | res['feat_pts'] = meta_data['feat_pts'] 55 | 56 | res['data_id'] = idx 57 | return points,meta_data['feat_pts'] 58 | # return res 59 | 60 | def __len__(self): 61 | return self.data_num 62 | -------------------------------------------------------------------------------- /data_utils/data_flower.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Apr 24 22:59:57 2020 4 | 5 | @author: eliphat 6 | """ 7 | import os 8 | import random 9 | import numpy 10 | import h5py 11 | 12 | 13 | def load_h5(h5_filename, normalize=False, include_label=False): 14 | f = h5py.File(h5_filename, 'r') 15 | data = f['data'][:] # (n, 2048, 3) 16 | if normalize: 17 | # nmean = numpy.mean(data, axis=1, keepdims=True) 18 | # nstd = numpy.std(data, axis=1, keepdims=True) 19 | # nstd = numpy.mean(nstd, axis=-1, keepdims=True) 20 | dmin = data.min(axis=1, keepdims=True).min(axis=-1, keepdims=True) 21 | dmax = data.max(axis=1, keepdims=True).max(axis=-1, keepdims=True) 22 | data = (data - dmin) / (dmax - dmin) 23 | # data = (data - nmean) / nstd 24 | data = 2.0 * (data - 0.5) 25 | if include_label: 26 | label = f['label'][:] 27 | return data, label 28 | return data 29 | 30 | 31 | def all_h5(parent, normalize=False, include_label=False, 32 | subclasses=tuple(range(40)), sample=256): 33 | lazy = map(lambda x: load_h5(x, normalize, include_label), 34 | walk_files(parent)) 35 | if include_label: 36 | xy = tuple(lazy) 37 | x = [x for x, y in xy] 38 | y = [y for x, y in xy] 39 | x = numpy.concatenate(x) 40 | y = numpy.concatenate(y) 41 | xf = [] 42 | yf = [] 43 | for xp, yp in zip(x, y): 44 | if yp[0] in subclasses: 45 | if sample is None: 46 | xf.append(xp) 47 | else: 48 | xf.append(random.choices(xp, k=sample)) 49 | yf.append(numpy.eye(len(subclasses))[subclasses.index(yp[0])]) 50 | return numpy.array(xf), numpy.array(yf) 51 | return numpy.concatenate(tuple(lazy)) 52 | 53 | 54 | def walk_files(path): 55 | for r, ds, fs in os.walk(path): 56 | for f in fs: 57 | yield os.path.join(r, f) 58 | 59 | 60 | def last_dirname(file_path): 61 | return os.path.basename(os.path.dirname(file_path)) 62 | 63 | 64 | def dataset_split(path): 65 | flist = list(walk_files(path)) 66 | tr = filter(lambda p: 'train' in last_dirname(p), flist) 67 | te = filter(lambda p: 'test' in last_dirname(p), flist) 68 | return list(tr), list(te) 69 | -------------------------------------------------------------------------------- /data_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import sys 9 | sys.path.append("..") 10 | import torch 11 | import numpy as np 12 | from torchvision import transforms 13 | import utils.point_cloud_utils as point_cloud_utils 14 | 15 | 16 | def angle_axis_tensor(angle, axis): 17 | # type: (float, np.ndarray) -> float 18 | r"""Returns a 4x4 rotation matrix that 
performs a rotation around axis by angle 19 | 20 | Parameters 21 | ---------- 22 | angle : float 23 | Angle to rotate by 24 | axis: torch.Tensor 25 | Axis to rotate about 26 | 27 | Returns 28 | ------- 29 | torch.Tensor 30 | 3x3 rotation matrix 31 | """ 32 | u = axis / torch.norm(axis) 33 | cosval, sinval = torch.cos(angle), torch.sin(angle) 34 | 35 | # yapf: disable 36 | cross_prod_mat = torch.Tensor([[0.0, -u[2], u[1]], 37 | [u[2], 0.0, -u[0]], 38 | [-u[1], u[0], 0.0]]).type(torch.FloatTensor) 39 | R = cosval * torch.eye(3).type(torch.FloatTensor) + sinval * cross_prod_mat + (1.0 - cosval) * torch.ger(u, u) 40 | return R 41 | 42 | 43 | def angle_axis(angle, axis): 44 | # type: (float, np.ndarray) -> float 45 | r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle 46 | 47 | Parameters 48 | ---------- 49 | angle : float 50 | Angle to rotate by 51 | axis: np.ndarray 52 | Axis to rotate about 53 | 54 | Returns 55 | ------- 56 | torch.Tensor 57 | 3x3 rotation matrix 58 | """ 59 | u = axis / np.linalg.norm(axis) 60 | cosval, sinval = np.cos(angle), np.sin(angle) 61 | 62 | # yapf: disable 63 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 64 | [u[2], 0.0, -u[0]], 65 | [-u[1], u[0], 0.0]]) 66 | 67 | R = torch.from_numpy( 68 | cosval * np.eye(3) 69 | + sinval * cross_prod_mat 70 | + (1.0 - cosval) * np.outer(u, u) 71 | ) 72 | # yapf: enable 73 | return R.float() 74 | 75 | 76 | class PointcloudRandomScale(object): 77 | def __init__(self, lo=0.8, hi=1.25): 78 | self.lo, self.hi = lo, hi 79 | 80 | def __call__(self, points): 81 | scaler = np.random.uniform(self.lo, self.hi) 82 | points[:, 0:3] *= scaler 83 | return points 84 | 85 | 86 | class PointcloudRandomRotate(object): 87 | def __init__(self, axis=np.array([0.0, 1.0, 0.0])): 88 | self.axis = axis 89 | 90 | def __call__(self, points): 91 | rotation_angle = np.random.uniform() * 2 * np.pi 92 | axis = np.array([np.random.uniform(), np.random.uniform(), np.random.uniform()]) 93 | axis = axis / np.sqrt(np.sum(axis * axis)) 94 | rotation_matrix = angle_axis(rotation_angle, axis) 95 | 96 | normals = points.size(1) > 3 97 | if not normals: 98 | return torch.matmul(points, rotation_matrix.t()) 99 | else: 100 | pc_xyz = points[:, 0:3] 101 | pc_normals = points[:, 3:] 102 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 103 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 104 | 105 | return points 106 | 107 | 108 | class PointcloudRandomRotatePerturbation(object): 109 | def __init__(self, angle_sigma=0.06, angle_clip=0.18): 110 | self.angle_sigma, self.angle_clip = angle_sigma, angle_clip 111 | 112 | def _get_angles(self): 113 | angles = np.clip( 114 | self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip 115 | ) 116 | 117 | return angles 118 | 119 | def __call__(self, points): 120 | angles = self._get_angles() 121 | Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0])) 122 | Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0])) 123 | Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0])) 124 | 125 | rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx) 126 | 127 | normals = points.size(1) > 3 128 | if not normals: 129 | return torch.matmul(points, rotation_matrix.t()) 130 | else: 131 | pc_xyz = points[:, 0:3] 132 | pc_normals = points[:, 3:] 133 | points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t()) 134 | points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t()) 135 | 136 | return points 137 | 138 | 139 | class PointcloudJitter(object): 140 | def __init__(self, std=0.01, 
clip=0.05): 141 | self.std, self.clip = std, clip 142 | 143 | def __call__(self, points): 144 | jittered_data = ( 145 | points.new(points.size(0), 3) 146 | .normal_(mean=0.0, std=self.std) 147 | .clamp_(-self.clip, self.clip) 148 | ) 149 | points[:, 0:3] += jittered_data 150 | return points 151 | 152 | 153 | class PointcloudNormalize(object): 154 | def __init__(self, max_size=1.0): 155 | 156 | self.max_size = max_size 157 | 158 | def __call__(self, points): 159 | 160 | points_max, _ = torch.max(points, dim=0) 161 | points_min, _ = torch.min(points, dim=0) 162 | points_center = (points_max + points_min) / 2 163 | points = points - points_center[None, :] 164 | max_radius = torch.max(torch.sqrt(torch.sum(points * points, dim=1))) 165 | points = points / max_radius * self.max_size / 2.0 166 | return points 167 | 168 | 169 | class PointcloudRandomPermutation(object): 170 | def __call__(self, points): 171 | num = points.shape[0] 172 | idxs = torch.randperm(num).type(torch.LongTensor) 173 | points = torch.index_select(points, 0, idxs).clone() 174 | return points 175 | 176 | 177 | class PointcloudRandomTranslate(object): 178 | def __init__(self, translate_range=0.1): 179 | self.translate_range = translate_range 180 | 181 | def __call__(self, points): 182 | translation = np.random.uniform(-self.translate_range, self.translate_range) 183 | points[:, 0:3] += translation 184 | return points 185 | 186 | 187 | class PointcloudToTensor(object): 188 | def __call__(self, points): 189 | return torch.from_numpy(points).float() 190 | 191 | 192 | class PointcloudRandomInputDropout(object): 193 | def __init__(self, max_dropout_ratio=0.875): 194 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 195 | self.max_dropout_ratio = max_dropout_ratio 196 | 197 | def __call__(self, points): 198 | pc = points.numpy() 199 | 200 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 201 | drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0] 202 | if len(drop_idx) > 0: 203 | pc[drop_idx] = pc[0] # set to the first point 204 | 205 | return torch.from_numpy(pc).float() 206 | 207 | 208 | class PointcloudTranslate(object): 209 | def __init__(self, translation=np.array([0.0, 0.1, 0.0])): 210 | ''' 211 | :param translation: pytorch tensor, translation vector(x,y,z) 212 | ''' 213 | self.translation = torch.from_numpy(translation) 214 | 215 | 216 | def __call__(self, points): 217 | ''' 218 | 219 | :param points: ... , num_of_points, 3 220 | :return: points after trans 221 | ''' 222 | translation = self.translation 223 | 224 | if points.is_cuda is True: 225 | translation = translation.cuda() 226 | translation.requires_grad = False 227 | 228 | respoints = points[..., 0:3] + translation 229 | return respoints 230 | 231 | 232 | class PointcloudScale(object): 233 | def __init__(self, scaler): 234 | self.scaler = scaler 235 | 236 | def __call__(self, points): 237 | 238 | respoints = points * self.scaler 239 | return respoints 240 | 241 | 242 | class PointcloudRotate(object): 243 | def __init__(self, angle_in_degree=np.pi, axis=np.array([0.0, 1.0, 0.0]), is_cuda=True): 244 | self.axis = axis 245 | self.angle_in_degree = angle_in_degree 246 | self.rotation_matrix_t = angle_axis(self.angle_in_degree, self.axis).t() 247 | 248 | def __call__(self, points): 249 | 250 | ''' 251 | :param points: ... 
, num_of_points, 3 252 | :return: points after rotate 253 | ''' 254 | rotation_matrix_t = self.rotation_matrix_t.clone() 255 | if points.is_cuda is True: 256 | rotation_matrix_t = rotation_matrix_t.cuda() 257 | tpoints = torch.matmul(points, rotation_matrix_t) 258 | 259 | return tpoints 260 | 261 | 262 | def GenPointcloudRandomTransformFunction(max_rot_angle=2*np.pi): 263 | scale_lo = 0.8 264 | scale_hi = 1.25 265 | scaler = np.random.uniform(scale_lo, scale_hi) 266 | scale_func = PointcloudScale(scaler) 267 | 268 | rotation_angle = np.random.uniform() * max_rot_angle 269 | rotation_axis = np.array([np.random.uniform(), np.random.uniform(), np.random.uniform()]) 270 | rotation_axis = rotation_axis / np.linalg.norm(rotation_axis) 271 | rotation_func = PointcloudRotate(rotation_angle, rotation_axis) 272 | 273 | trans_func = transforms.Compose([scale_func, rotation_func]) 274 | 275 | return trans_func 276 | 277 | 278 | def AddTransformsToBatchPoints(points, num_of_trans, max_rot_angle=2*np.pi): 279 | ''' 280 | 281 | :param points:bn, num_of_points, 3 282 | :return: points: (num_of_trans, bn, num_of_points, 3) 283 | transform 284 | ''' 285 | transfunc_list = [] 286 | res_points = None 287 | for trans_i in range(0, num_of_trans): 288 | transf = GenPointcloudRandomTransformFunction(max_rot_angle) 289 | transfunc_list.append(transf) 290 | tpoints = transf(points) 291 | if res_points is None: 292 | res_points = tpoints[None, :, :, :] 293 | else: 294 | res_points = torch.cat((res_points, tpoints[None, :, :, :]), dim=0) 295 | 296 | return res_points, transfunc_list 297 | 298 | 299 | class PointcloudRotateFuns(object): 300 | def __init__(self, rot_mats): 301 | ''' 302 | :param rot_mats: bn, 3, 3 303 | ''' 304 | 305 | self.rot_mats = rot_mats 306 | 307 | def __call__(self, points): 308 | ''' 309 | 310 | :param points: bn, n , 3 311 | :return: 312 | ''' 313 | if points.is_cuda is True: 314 | tmp_rot = self.rot_mats.cuda() 315 | else: 316 | tmp_rot = self.rot_mats 317 | transed_poitns = torch.transpose(torch.matmul(tmp_rot, torch.transpose(points, 1, 2)), 1, 2) 318 | return transed_poitns 319 | 320 | 321 | def AddPCATransformsToBatchPoints(points, num_of_trans): 322 | trans_points_all = None 323 | rot_mats_all = None 324 | 325 | transfunc_list = [] 326 | 327 | for bi in range(points.shape[0]): 328 | 329 | np_points = points[bi].cpu().numpy() 330 | pca_axis_raw = point_cloud_utils.compute_pca(np_points) 331 | rot_mats = None 332 | trans_points = None 333 | for ti in range(num_of_trans): 334 | tmp_idx = np.array([0, 1, 2]) 335 | pca_axis = pca_axis_raw[tmp_idx, :] 336 | tmp_sign = np.random.randint(2, size=2) 337 | tmp_sign[tmp_sign == 0] = -1 338 | pca_axis[0, :] = pca_axis[0, :] * tmp_sign[0] 339 | pca_axis[1, :] = pca_axis[1, :] * tmp_sign[1] 340 | pca_axis[2, :] = np.cross(pca_axis[0, :], pca_axis[1, :]) 341 | 342 | rot_mat = torch.from_numpy(pca_axis) 343 | if points.is_cuda: 344 | rot_mat = rot_mat.cuda() 345 | 346 | if rot_mats is None: 347 | rot_mats = rot_mat[None, :, :] 348 | else: 349 | rot_mats = torch.cat((rot_mats, rot_mat[None, :, :]), dim=0) 350 | 351 | tmp_trans_points = torch.transpose(torch.matmul(rot_mat, torch.transpose(points[bi], 0, 1)), 0, 1) 352 | 353 | if trans_points is None: 354 | trans_points = tmp_trans_points[None, :, :] 355 | else: 356 | trans_points = torch.cat((trans_points, tmp_trans_points[None, :, :]), dim=0) 357 | 358 | if trans_points_all is None: 359 | trans_points_all = trans_points[:, None, :, :] 360 | else: 361 | trans_points_all = torch.cat((trans_points_all, 
trans_points[:, None, :, :]), dim=1) 362 | 363 | if rot_mats_all is None: 364 | rot_mats_all = rot_mats[:, None, :, :] 365 | else: 366 | rot_mats_all = torch.cat((rot_mats_all, rot_mats[:, None, :, :]), dim=1) 367 | 368 | for ti in range(num_of_trans): 369 | trans_func = PointcloudRotateFuns(rot_mats_all[ti, :, :, :]) 370 | transfunc_list.append(trans_func) 371 | 372 | return trans_points_all, rot_mats_all, transfunc_list 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | -------------------------------------------------------------------------------- /data_utils/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Author: An Tao 5 | @Contact: ta19@mails.tsinghua.edu.cn 6 | @File: dataset.py 7 | @Time: 2020/1/2 10:26 AM 8 | """ 9 | 10 | import os 11 | import torch 12 | import json 13 | import h5py 14 | from glob import glob 15 | import numpy as np 16 | import torch.utils.data as data 17 | 18 | shapenetpart_cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, 19 | 'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, 20 | 'motorbike': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15} 21 | shapenetpart_seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] 22 | shapenetpart_seg_start_index = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47] 23 | 24 | 25 | def translate_pointcloud(pointcloud): 26 | xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3]) 27 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 28 | 29 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 30 | return translated_pointcloud 31 | 32 | 33 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 34 | N, C = pointcloud.shape 35 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip) 36 | return pointcloud 37 | 38 | 39 | def rotate_pointcloud(pointcloud): 40 | theta = np.pi * 2 * np.random.rand() 41 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 42 | pointcloud[:, [0, 2]] = pointcloud[:, [0, 2]].dot(rotation_matrix) # random rotation (x,z) 43 | return pointcloud 44 | 45 | 46 | class Dataset(data.Dataset): 47 | def __init__(self, root, dataset_name='modelnet40', class_choice=None, 48 | num_points=2048, split='train', load_name=True, load_file=True, 49 | segmentation=False, random_rotate=False, random_jitter=False, 50 | random_translate=False): 51 | 52 | assert dataset_name.lower() in ['shapenetcorev2', 'shapenetpart', 53 | 'modelnet10', 'modelnet40', 'shapenetpartpart'] 54 | assert num_points <= 2048 55 | 56 | if dataset_name in ['shapenetcorev2', 'shapenetpart', 'shapenetpartpart']: 57 | assert split.lower() in ['train', 'test', 'val', 'trainval', 'all'] 58 | else: 59 | assert split.lower() in ['train', 'test', 'all'] 60 | 61 | if dataset_name not in ['shapenetpart'] and segmentation == True: 62 | raise AssertionError 63 | 64 | self.root = os.path.join(root, dataset_name + '_' + '*hdf5_2048') 65 | self.dataset_name = dataset_name 66 | self.class_choice = class_choice 67 | self.num_points = num_points 68 | self.split = split 69 | self.load_name = load_name 70 | self.load_file = load_file 71 | self.segmentation = segmentation 72 | self.random_rotate = random_rotate 73 | self.random_jitter = random_jitter 74 | self.random_translate = random_translate 75 | 76 | self.path_h5py_all = [] 77 | self.path_name_all = [] 78 | self.path_file_all = [] 
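# Gather every *train*/*val*/*test .h5 shard that matches the requested
# split (plus the corresponding *_id2name.json / *_id2file.json lists),
# sort the paths, and concatenate the shards into single data/label/seg
# arrays below.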
79 | 80 | if self.split in ['train', 'trainval', 'all']: 81 | self.get_path('train') 82 | if self.dataset_name in ['shapenetcorev2', 'shapenetpart', 'shapenetpartpart']: 83 | if self.split in ['val', 'trainval', 'all']: 84 | self.get_path('val') 85 | if self.split in ['test', 'all']: 86 | self.get_path('test') 87 | 88 | self.path_h5py_all.sort() 89 | data, label, seg = self.load_h5py(self.path_h5py_all) 90 | 91 | if self.load_name or self.class_choice != None: 92 | self.path_name_all.sort() 93 | self.name = np.array(self.load_json(self.path_name_all)) # load label name 94 | 95 | if self.load_file: 96 | self.path_file_all.sort() 97 | self.file = np.array(self.load_json(self.path_file_all)) # load file name 98 | 99 | self.data = np.concatenate(data, axis=0) 100 | self.label = np.concatenate(label, axis=0) 101 | if self.segmentation: 102 | self.seg = np.concatenate(seg, axis=0) 103 | 104 | if self.class_choice != None: 105 | indices = (self.name == class_choice) 106 | self.data = self.data[indices] 107 | self.label = self.label[indices] 108 | self.name = self.name[indices] 109 | if self.segmentation: 110 | self.seg = self.seg[indices] 111 | id_choice = shapenetpart_cat2id[class_choice] 112 | self.seg_num_all = shapenetpart_seg_num[id_choice] 113 | self.seg_start_index = shapenetpart_seg_start_index[id_choice] 114 | if self.load_file: 115 | self.file = self.file[indices] 116 | elif self.segmentation: 117 | self.seg_num_all = 50 118 | self.seg_start_index = 0 119 | 120 | def get_path(self, type): 121 | path_h5py = os.path.join(self.root, '*%s*.h5' % type) 122 | self.path_h5py_all += glob(path_h5py) 123 | if self.load_name: 124 | path_json = os.path.join(self.root, '%s*_id2name.json' % type) 125 | self.path_name_all += glob(path_json) 126 | if self.load_file: 127 | path_json = os.path.join(self.root, '%s*_id2file.json' % type) 128 | self.path_file_all += glob(path_json) 129 | return 130 | 131 | def load_h5py(self, path): 132 | all_data = [] 133 | all_label = [] 134 | all_seg = [] 135 | for h5_name in path: 136 | f = h5py.File(h5_name, 'r+') 137 | data = f['data'][:].astype('float32') 138 | label = f['label'][:].astype('int64') 139 | if self.segmentation: 140 | seg = f['seg'][:].astype('int64') 141 | f.close() 142 | all_data.append(data) 143 | all_label.append(label) 144 | if self.segmentation: 145 | all_seg.append(seg) 146 | return all_data, all_label, all_seg 147 | 148 | def load_json(self, path): 149 | all_data = [] 150 | for json_name in path: 151 | j = open(json_name, 'r+') 152 | data = json.load(j) 153 | all_data += data 154 | return all_data 155 | 156 | def __getitem__(self, item): 157 | point_set = self.data[item][:self.num_points] 158 | label = self.label[item] 159 | if self.load_name: 160 | name = self.name[item] # get label name 161 | if self.load_file: 162 | file = self.file[item] # get file name 163 | 164 | if self.random_rotate: 165 | point_set = rotate_pointcloud(point_set) 166 | if self.random_jitter: 167 | point_set = jitter_pointcloud(point_set) 168 | if self.random_translate: 169 | point_set = translate_pointcloud(point_set) 170 | 171 | # convert numpy array to pytorch Tensor 172 | point_set = torch.from_numpy(point_set) 173 | label = torch.from_numpy(np.array([label]).astype(np.int64)) 174 | label = label.squeeze(0) 175 | 176 | if self.segmentation: 177 | seg = self.seg[item][:self.num_points] 178 | seg = torch.from_numpy(seg) 179 | return point_set, label, seg, name, file 180 | else: 181 | return point_set, label, name, file 182 | 183 | def __len__(self): 184 | return 
self.data.shape[0] 185 | 186 | 187 | if __name__ == '__main__': 188 | # root = os.getcwd() 189 | 190 | # choose dataset name from 'shapenetcorev2', 'shapenetpart', 'modelnet40' and 'modelnet10' 191 | dataset_name = 'shapenetpart' 192 | 193 | # choose split type from 'train', 'test', 'all', 'trainval' and 'val' 194 | # only shapenetcorev2 and shapenetpart dataset support 'trainval' and 'val' 195 | split = 'train' 196 | segmentation = True 197 | d = Dataset(root='../../', dataset_name=dataset_name, class_choice='motorbike', num_points=2048, split=split, segmentation=segmentation) 198 | print("datasize:", d.__len__()) 199 | 200 | item = 1 201 | if segmentation: 202 | ps, lb, seg, n, f = d[item] 203 | print(ps.size(), ps.type(), seg.size(), seg.type, lb.size(), lb.type(), n, f) 204 | print(seg) 205 | else: 206 | ps, lb, n, f = d[item] 207 | print(ps.size(), ps.type(), lb.size(), lb.type(), n, f) 208 | -------------------------------------------------------------------------------- /data_utils/keypointnet_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from .ModelNetDataLoader import pc_normalize 3 | import json 4 | import numpy as np 5 | 6 | 7 | def naive_read_pcd(path): 8 | lines = open(path, 'r').readlines() 9 | idx = -1 10 | for i, line in enumerate(lines): 11 | if line.startswith('DATA ascii'): 12 | idx = i + 1 13 | break 14 | lines = lines[idx:] 15 | lines = [line.rstrip().split(' ') for line in lines] 16 | data = np.asarray(lines) 17 | pc = np.array(data[:, :3], dtype=float) 18 | return pc 19 | 20 | 21 | class KeyPointNetDataLoader(data.Dataset): 22 | def __init__(self, num_points, json_path, pcd_path, split='train'): 23 | super().__init__() 24 | self.num_points = num_points 25 | self.data_path = json.load(open(json_path)) 26 | self.pcd_path = pcd_path 27 | if split == 'train': 28 | self.data_path = self.data_path[:-200] 29 | else: 30 | self.data_path = self.data_path[-200:] 31 | 32 | def __getitem__(self, idx): 33 | # res = {} 34 | class_id = self.data_path[idx]['class_id'] 35 | model_id = self.data_path[idx]['model_id'] 36 | points = naive_read_pcd(r'{}/{}/{}.pcd'.format(self.pcd_path, class_id, model_id)).astype(np.float32) 37 | points = pc_normalize(points) 38 | ground_truths = np.array( 39 | [points[point['pcd_info']['point_index']] for point in self.data_path[idx]['keypoints']]) 40 | ground_truths_num = ground_truths.shape[0] 41 | ground_truths = np.pad(ground_truths, ((0, 18 - ground_truths_num), (0, 0,)), 'constant', 42 | constant_values=(0, 0)) 43 | return points[:self.num_points, :], ground_truths, ground_truths_num, ground_truths_num 44 | 45 | def __len__(self): 46 | return len(self.data_path) 47 | 48 | 49 | def paint(points_xyz, origin_xyz): 50 | import matplotlib.pyplot as plt 51 | 52 | x1 = points_xyz[:, 0] 53 | y1 = points_xyz[:, 1] 54 | z1 = points_xyz[:, 2] 55 | 56 | x2 = origin_xyz[:, 0] 57 | y2 = origin_xyz[:, 1] 58 | z2 = origin_xyz[:, 2] 59 | 60 | ax1 = plt.subplot(111, projection='3d') 61 | ax1.scatter(x1, y1, z1, c=COLOR_LIST[:points_xyz.shape[0], :] / 255, s=48) 62 | ax1.scatter(x2, y2, z2, c='#A9A9A9', s=1) 63 | ax1.axis('off') 64 | plt.show() 65 | 66 | 67 | def create_color_list(num): 68 | import random 69 | colors = np.ndarray(shape=(num, 3)) 70 | for i in range(0, num): 71 | colors[i, 0] = random.randint(0, 255) 72 | colors[i, 1] = random.randint(0, 255) 73 | colors[i, 2] = random.randint(100, 255) 74 | 75 | colors[0, :] = np.array([0, 0, 0]).astype(int) 76 | colors[1, :] = 
np.array([146, 61, 10]).astype(int) 77 | colors[2, :] = np.array([102, 97, 0]).astype(int) 78 | colors[3, :] = np.array([255, 0, 0]).astype(int) 79 | colors[4, :] = np.array([113, 0, 17]).astype(int) 80 | colors[5, :] = np.array([255, 127, 39]).astype(int) 81 | colors[6, :] = np.array([255, 242, 0]).astype(int) 82 | colors[7, :] = np.array([0, 255, 0]).astype(int) 83 | colors[8, :] = np.array([0, 0, 255]).astype(int) 84 | colors[9, :] = np.array([15, 77, 33]).astype(int) 85 | colors[10, :] = np.array([163, 73, 164]).astype(int) 86 | colors[11, :] = np.array([255, 174, 201]).astype(int) 87 | colors[12, :] = np.array([255, 220, 14]).astype(int) 88 | colors[13, :] = np.array([181, 230, 29]).astype(int) 89 | colors[14, :] = np.array([153, 217, 234]).astype(int) 90 | colors[15, :] = np.array([112, 146, 190]).astype(int) 91 | 92 | return colors 93 | 94 | 95 | COLOR_LIST = create_color_list(200) 96 | 97 | if __name__ == '__main__': 98 | import torch 99 | 100 | data = KeyPointNetDataLoader(json_path='./keypointnet/annotations/mug.json', pcd_path='../keypointnet/pcds', 101 | split='val',num_points=1024) 102 | 103 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True, num_workers=10, drop_last=True) 104 | print(len(DataLoader)) 105 | for point, label, ground_truths_num in DataLoader: 106 | paint(label[0], point[0]) 107 | # print(point.shape) 108 | # print(label.shape) 109 | # print(ground_truths_num) 110 | -------------------------------------------------------------------------------- /data_utils/shapenet_seg_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch.utils.data as data 9 | import os 10 | import ntpath 11 | import pickle 12 | import re 13 | 14 | 15 | def natural_sort_key(string_): 16 | """See http://www.codinghorror.com/blog/archives/001018.html""" 17 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)] 18 | 19 | 20 | def sort_strs_by_number(strs): 21 | return sorted(strs, key=natural_sort_key) 22 | 23 | 24 | class ShapenetSegDataloader(data.Dataset): 25 | def __init__(self, data_root, category, preload_data=True, train=True): 26 | super().__init__() 27 | self.category = category 28 | self.data_root = data_root 29 | self.data_path = os.path.join(data_root, category) 30 | self.sub_dirs = [ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()] 31 | #self.sub_dirs = sort_strs_by_number([ntpath.basename(f.path) for f in os.scandir(self.data_path) if f.is_dir()]) 32 | self.data_num = len(self.sub_dirs) 33 | self.train = train 34 | 35 | self.meta_data_list = None 36 | if preload_data: 37 | self.meta_data_list = [] 38 | for i in range(self.data_num): 39 | meta_fname = os.path.join(self.data_path, self.sub_dirs[i], 'meta.pkl') 40 | with open(meta_fname, 'rb') as f: 41 | meta_data = pickle.load(f) 42 | self.meta_data_list.append(meta_data) 43 | 44 | def __getitem__(self, idx): 45 | if self.meta_data_list is None: 46 | meta_fname = os.path.join(self.data_path, self.sub_dirs[idx], 'meta.pkl') 47 | f = open(meta_fname, 'rb') 48 | meta_data = pickle.load(f) 49 | else: 50 | meta_data = self.meta_data_list[idx] 51 | 52 | points = meta_data['points'] 53 | gt_points = meta_data['gt_points'] 54 | gt_points_labels = meta_data['gt_points_labels'] - 1 55 | meta_data = {'points': points} 56 | 57 | if self.train is not True: 58 | meta_data['gt_points'] = 
gt_points 59 | meta_data['gt_points_labels'] = gt_points_labels 60 | return meta_data 61 | 62 | def __len__(self): 63 | return self.data_num -------------------------------------------------------------------------------- /images/Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/images/Pipeline.png -------------------------------------------------------------------------------- /images/VisualizationResults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/images/VisualizationResults.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenguoz/Keypoints/32ff83db4db5b44432eec575afd65f9bee52fab2/models/__init__.py -------------------------------------------------------------------------------- /models/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | import torch.nn as nn 10 | from torch_pointnet_utils import query_ball_point, index_points 11 | import torch.nn.functional as F 12 | 13 | 14 | def query_KNN_tensor(points, query_pts, k): 15 | ''' 16 | 17 | :param points: bn x n x 3 18 | :param query_pts: bn x m x 3 19 | :param k: num of neighbors 20 | :return: nb x m x k ids, sorted_squared_dis 21 | ''' 22 | 23 | diff = query_pts[:, :, None, :] - points[:, None, :, :] 24 | 25 | squared_dis = torch.sum(diff * diff, dim=3) # bn x m x n 26 | sorted_squared_dis, sorted_idxs = torch.sort(squared_dis, dim=2) 27 | sorted_idxs = sorted_idxs[:, :, :k] 28 | sorted_squared_dis = sorted_squared_dis[:, :, :k] 29 | 30 | return sorted_idxs, sorted_squared_dis 31 | 32 | 33 | def compute_chamfer_distance(p1, p2): 34 | ''' 35 | Calculate Chamfer Distance between two point sets 36 | :param p1: size[bn, N, D] 37 | :param p2: size[bn, M, D] 38 | :return: sum of Chamfer Distance of two point sets 39 | ''' 40 | 41 | diff = p1[:, :, None, :] - p2[:, None, :, :] 42 | dist = torch.sum(diff * diff, dim=3) 43 | dist1 = dist 44 | dist2 = torch.transpose(dist, 1, 2) 45 | 46 | dist_min1, _ = torch.min(dist1, dim=2) 47 | dist_min2, _ = torch.min(dist2, dim=2) 48 | 49 | return dist_min1, dist_min2 50 | 51 | 52 | # def compute_vector_distance(p1, p2): 53 | # ''' 54 | # Calculate Chamfer Distance between two point sets 55 | # :param p1: size[bn, N, D] 56 | # 57 | # :return: sum of Chamfer Distance of two point sets 58 | # ''' 59 | # # device = p1.device 60 | # 61 | # vectors = p1[:, :, None, :] - p1[:, None, :, :] 62 | # # vectors=vectors.float() 63 | # vectors = torch.flatten(vectors, start_dim=1, end_dim=2) 64 | # vectors_dot = torch.sum(vectors[:, :, None, :] * vectors[:, None, :, :], dim=3) 65 | # vectors_dot = torch.abs(vectors_dot) 66 | # 67 | # 68 | # # edge_dict = compute_edge_distance(p1, p2, maxdis=0.001) 69 | # # idx = torch.zeros(vectors_dot.shape).cuda() 70 | # # idx_num = 0 71 | # # for key, values in edge_dict.items(): 72 | # # for value in values: 73 | # # idx[key, value[0], value[1]] = 1 74 | # # idx_num += 1 75 | # # 76 | # # vectors_dot = vectors_dot * idx 77 | # 78 | # 79 | # # 
print(vectors_dot[vectors_dot>0]) 80 | # vectors_abs = torch.sqrt(torch.sum(vectors * vectors, dim=2, keepdim=True)) 81 | # vectors_abs_dot = torch.sum(vectors_abs[:, :, None, :] * vectors_abs[:, None, :, :], dim=3) 82 | # vectors_abs_dot[vectors_abs_dot == 0] = 1 83 | # vectors_cos = vectors_dot / vectors_abs_dot 84 | # 85 | # return torch.mean(vectors_cos) 86 | # 87 | # # return torch.sum(vectors_cos) / idx_num 88 | 89 | 90 | def compute_vector_distance(p1, p2=None): 91 | ''' 92 | Calculate Chamfer Distance between two point sets 93 | :param p1: size[bn, N, D] 94 | :param p2: size[bn, M, D] 95 | :return: sum of Chamfer Distance of two point sets 96 | ''' 97 | 98 | vectors = p1[:, :, None, :] - p1[:, None, :, :] 99 | 100 | vectors = torch.flatten(vectors, start_dim=1, end_dim=2) 101 | vectors_dot = torch.sum(vectors[:, :, None, :] * vectors[:, None, :, :], dim=3) 102 | vectors_dot = torch.abs(vectors_dot) 103 | 104 | edge_dict = compute_edge_distance(p1, p2, maxdis=0.001) 105 | idx = torch.zeros(vectors_dot.shape).cuda() 106 | idx_num = 1 107 | for key, values in edge_dict.items(): 108 | for value in values: 109 | idx[key, value[0], value[1]] = 1 110 | idx_num += 1 111 | vectors_abs = torch.sqrt(torch.sum(vectors * vectors + 1e-5, dim=2, keepdim=True)) 112 | vectors_dot = vectors_dot * idx 113 | # vectors_abs = torch.sum(vectors * vectors, dim=2, keepdim=True) 114 | vectors_abs_dot = torch.sum(vectors_abs[:, :, None, :] * vectors_abs[:, None, :, :], dim=3) 115 | vectors_abs_dot[vectors_abs_dot == 0] = 1 116 | vectors_cos = vectors_dot / vectors_abs_dot 117 | return torch.sum(vectors_cos) / idx_num 118 | 119 | 120 | def compute_edge_distance(p1, p2, maxdis=0.001): 121 | point_num = p1.shape[1] 122 | edge_list = [] 123 | batch_dict = {} 124 | # for batch in range(p1.shape[0]): 125 | for i in range(point_num): 126 | for j in range(i): 127 | start = p1[:, i, :] 128 | end = p1[:, j, :] 129 | dist = torch.mean(torch.sqrt(1e-3 + (torch.sum(torch.square(start - end), dim=-1)))) 130 | count = 5 131 | device = dist.device 132 | f_interp = torch.linspace(0.0, 1.0, count).unsqueeze(0).unsqueeze(-1).to(device) 133 | b_interp = 1.0 - f_interp 134 | K = start.unsqueeze(-2) * f_interp + end.unsqueeze(-2) * b_interp 135 | dist1, dist2 = compute_chamfer_distance(K, p2) 136 | cdis = (torch.mean(dist1, dim=1)) 137 | for x in torch.where(cdis < maxdis)[0].cpu().numpy(): 138 | edge_list.append([x, i, j]) 139 | 140 | for x in edge_list: 141 | if x[0] in batch_dict: 142 | batch_dict[x[0]].append(x[1:]) 143 | else: 144 | batch_dict[x[0]] = [x[1:]] 145 | 146 | return batch_dict 147 | 148 | 149 | def compute_offset_distance(p1): 150 | return torch.sum(p1 * p1) 151 | 152 | 153 | def compute_vector_similarity_distance(p1, similarity_map, threshold): 154 | ''' 155 | Calculate Chamfer Distance between two point sets 156 | :param p1: size[bn, N, D] 157 | :param p2: size[bn, M, D] 158 | :return: sum of Chamfer Distance of two point sets 159 | ''' 160 | 161 | B, N, D = p1.shape 162 | device = p1.device 163 | vectors = p1[:, :, None, :] - p1[:, None, :, :] 164 | similarity_map[similarity_map < threshold] = 0 165 | vectors = vectors * similarity_map.view(B, N, N, 1).to(device) 166 | vectors = torch.flatten(vectors, start_dim=1, end_dim=2) 167 | vectors_dot = torch.sum(vectors[:, :, None, :] * vectors[:, None, :, :], dim=3) 168 | idx_num = torch.sum(vectors_dot != 0) 169 | vectors_dot = torch.abs(vectors_dot) 170 | vectors_abs = torch.sqrt(torch.sum(vectors * vectors + 1e-9, dim=2, keepdim=True)) 171 | vectors_abs_dot = 
torch.sum(vectors_abs[:, :, None, :] * vectors_abs[:, None, :, :], dim=3) 172 | vectors_abs_dot[vectors_abs_dot == 0] = 1 173 | vectors_cos = vectors_dot / vectors_abs_dot 174 | return torch.sum(vectors_cos) / idx_num 175 | 176 | 177 | def compute_end_distance(p1, p2, radius=0.1, nsample=16): 178 | ''' 179 | Measure how unbalanced each p1 point's local neighborhood in p2 is 180 | :param p1: size[bn, N, D], query points 181 | :param p2: size[bn, M, D], searched within radius for up to nsample neighbors per query 182 | :return: scalar, mean norm of the summed neighbor offsets (small when neighbors surround the query point symmetrically) 183 | ''' 184 | B, S, C = p1.shape 185 | idx = query_ball_point(radius, nsample, p2, p1) 186 | p2 = index_points(p2, idx) # [B, npoint, nsample, C] 187 | p2_norm = p2 - p1.view(B, S, 1, C) 188 | # p2_norm = p2_norm / torch.norm(p2_norm, p=2, dim=3, keepdim=True) 189 | dist = torch.sum(p2_norm, dim=2) 190 | dist = torch.sqrt(torch.sum(dist * dist, dim=2)) 191 | return torch.mean(dist) 192 | 193 | 194 | class ComputeCDLoss(nn.Module): 195 | def __init__(self): 196 | super(ComputeCDLoss, self).__init__() 197 | 198 | def forward(self, recon_points, gt_points): 199 | dist1, dist2 = compute_chamfer_distance(recon_points, gt_points) 200 | loss = (torch.sum(dist1) + torch.sum(dist2)) / ((recon_points.shape[0]) * gt_points.shape[1]) * 1024 201 | # print(torch.sum(dist1), torch.sum(dist2)) 202 | # loss = (torch.mean(dist1) + torch.mean(dist2) * 32) 203 | 204 | return loss 205 | 206 | 207 | class ComputeVecLoss(nn.Module): 208 | def __init__(self): 209 | super(ComputeVecLoss, self).__init__() 210 | 211 | def forward(self, recon_points, gt_points): 212 | dist3 = compute_vector_distance(recon_points, gt_points) 213 | # loss_align = torch.sum(dist3) / (recon_points.shape[0]) 214 | return dist3 215 | 216 | 217 | class ComputeEdgeLoss(nn.Module): 218 | def __init__(self): 219 | super(ComputeEdgeLoss, self).__init__() 220 | 221 | def forward(self, recon_points, gt_points): 222 | dist3 = compute_edge_distance(recon_points, gt_points) 223 | # loss_align = torch.sum(dist3) / (recon_points.shape[0]) 224 | return dist3 225 | 226 | 227 | class ComputeOffsetLoss(nn.Module): 228 | def __init__(self): 229 | super(ComputeOffsetLoss, self).__init__() 230 | 231 | def forward(self, fps_points_offset): 232 | return compute_offset_distance(fps_points_offset) / (fps_points_offset.shape[0]) 233 | 234 | 235 | class ComputeVecSimilarityLoss(nn.Module): 236 | def __init__(self): 237 | super(ComputeVecSimilarityLoss, self).__init__() 238 | 239 | def forward(self, gt_points, cos_similarity, threshold): 240 | return compute_vector_similarity_distance(gt_points, cos_similarity, threshold) 241 | 242 | 243 | class ComputeEndLoss(nn.Module): 244 | def __init__(self): 245 | super(ComputeEndLoss, self).__init__() 246 | 247 | def forward(self, recon_points, gt_points): 248 | dist1 = compute_end_distance(recon_points, gt_points, radius=0.1, nsample=16) 249 | loss = dist1 / recon_points.shape[1] * 24 250 | 251 | return loss 252 | 253 | 254 | if __name__ == '__main__': 255 | p1 = torch.rand((2, 16, 3)) 256 | p2 = torch.rand((2, 128, 3)) 257 | p3 = torch.rand(2, 16, 16) 258 | a = ComputeVecSimilarityLoss() 259 | 260 | print(a(p1, p3, 0.95)) 261 | -------------------------------------------------------------------------------- /models/model_weightchamfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_pointnet_utils import PointNetSetAbstractionMsg 4 | from models import chamfer_distance 5 | 6 | 7 | class ComputeLoss3d(nn.Module): 8 | def __init__(self): 9 |
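# ComputeLoss3d is a symmetric Chamfer loss: nearest-neighbor squared
# distances are summed in both directions and normalized by the batch size
# only, not by the point count. A hedged usage sketch with illustrative
# shapes (not taken from the original file):
#   loss_fn = ComputeLoss3d()
#   gt = torch.rand(4, 1024, 3)      # input point cloud
#   stpts = torch.rand(4, 16, 3)     # predicted structure points
#   loss = loss_fn(gt, stpts)        # scalar tensor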
super(ComputeLoss3d, self).__init__() 10 | 11 | def compute_chamfer_distance(self, p1, p2): 12 | ''' 13 | Calculate Chamfer Distance between two point sets 14 | :param p1: size[bn, N, D] 15 | :param p2: size[bn, M, D] 16 | :return: sum of Chamfer Distance of two point sets 17 | ''' 18 | 19 | diff = p1[:, :, None, :] - p2[:, None, :, :] 20 | dist = torch.sum(diff * diff, dim=3) 21 | dist_min1, _ = torch.min(dist, dim=2) 22 | dist_min2, _ = torch.min(dist, dim=1) 23 | 24 | return (torch.sum(dist_min1) + torch.sum(dist_min2)) / (p1.shape[0]) 25 | 26 | def forward(self, gt_points, structure_points, origin_points=None): 27 | if origin_points is None: 28 | return self.compute_chamfer_distance(gt_points, structure_points) 29 | else: 30 | return self.compute_chamfer_distance(gt_points, structure_points) + self.compute_chamfer_distance(origin_points, structure_points) 31 | 32 | 33 | class VecLoss(nn.Module): 34 | def __init__(self): 35 | super(VecLoss, self).__init__() 36 | self.vec_loss_fun = chamfer_distance.ComputeVecSimilarityLoss() 37 | 38 | def forward(self, structure_points, similarity_map, threshold=0.95): 39 | structure_points = structure_points.cuda() 40 | self.vec_loss = self.vec_loss_fun(structure_points, similarity_map, threshold) * 10 41 | return self.vec_loss 42 | 43 | 44 | class WeightedChamferLoss(nn.Module): 45 | def __init__(self): 46 | super(WeightedChamferLoss, self).__init__() 47 | 48 | def compute_chamfer_distance(self, p1, p2): 49 | ''' 50 | Calculate Chamfer Distance between two point sets 51 | :param p1: size[bn, N, D] 52 | :param p2: size[bn, M, D] 53 | :return: sum of Chamfer Distance of two point sets 54 | ''' 55 | 56 | diff = p1[:, :, None, :] - p2[:, None, :, :] 57 | dist = torch.sum(diff * diff, dim=3) 58 | dist_min1, _ = torch.min(dist, dim=2) 59 | dist_min2, _ = torch.min(dist, dim=1) 60 | 61 | return (torch.mean(dist_min1) + torch.mean(dist_min2)) 62 | 63 | def compute_end_distance(self, fps_points, structure_points, weight_map): 64 | ''' 65 | Calculate Chamfer Distance between two point sets 66 | :param p1: size[bn, N, D] 67 | :param p2: size[bn, M, D] 68 | :return: sum of Chamfer Distance of two point sets 69 | ''' 70 | # weight_map = torch.sigmoid(weight_map) 71 | weight_map = torch.sum(weight_map, dim=1) 72 | weight_map = torch.sigmoid(weight_map) + 1 73 | diff = structure_points[:, :, None, :] - fps_points[:, None, :, :] 74 | dist = torch.sum(diff * diff, dim=3) 75 | dist_min1, _ = torch.min(dist, dim=2) 76 | dist_min2, _ = torch.min(dist, dim=1) 77 | dist_min2 = dist_min2 * weight_map 78 | return (torch.mean(dist_min1) + torch.mean(dist_min2)) 79 | 80 | def forward(self, fps_points, structure_points, weight_map, origin_points=None): 81 | if origin_points is None: 82 | return self.compute_end_distance(fps_points, structure_points, weight_map) 83 | else: 84 | return self.compute_end_distance(fps_points, structure_points, weight_map) + self.compute_chamfer_distance(origin_points, structure_points) 85 | 86 | 87 | class Conv1dProbLayer(nn.Module): 88 | def __init__(self, in_channels, out_channels, out=False, kernel_size=1, dropout=0.2, normalize=False): 89 | super(Conv1dProbLayer, self).__init__() 90 | self.out = out 91 | self.dropout_conv_bn_layer = nn.Sequential( 92 | nn.Dropout(dropout), 93 | nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size), 94 | nn.BatchNorm1d(num_features=out_channels), 95 | ) 96 | self.relu = nn.ReLU() 97 | self.softmax = nn.Softmax(dim=2) 98 | self.normalize = normalize 99 | self.normalize_layer = 
Normalize(dim=1) 100 | 101 | def forward(self, x): 102 | x = self.dropout_conv_bn_layer(x) 103 | if self.normalize: 104 | x = self.normalize_layer(x) 105 | if self.out: 106 | x = self.softmax(x) 107 | else: 108 | x = self.relu(x) 109 | return x 110 | 111 | 112 | class RefineNet(nn.Module): 113 | def __init__(self, num_structure_points, in_channel=128 + 256 + 256, out=True, normalize=False): 114 | super(RefineNet, self).__init__() 115 | 116 | conv1d_stpts_prob_modules = [] 117 | if num_structure_points <= in_channel: 118 | conv1d_stpts_prob_modules.append( 119 | Conv1dProbLayer(in_channels=in_channel, out_channels=512, kernel_size=1, out=False)) 120 | in_channels = 512 121 | while in_channels >= num_structure_points * 2: 122 | out_channels = int(in_channels / 2) 123 | conv1d_stpts_prob_modules.append( 124 | Conv1dProbLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, out=False)) 125 | in_channels = out_channels 126 | conv1d_stpts_prob_modules.append( 127 | Conv1dProbLayer(in_channels=in_channels, out_channels=num_structure_points, kernel_size=1, out=out, 128 | normalize=normalize)) 129 | else: 130 | conv1d_stpts_prob_modules.append( 131 | Conv1dProbLayer(in_channels=in_channel, out_channels=1024, kernel_size=1, out=False)) 132 | in_channels = 1024 133 | while in_channels <= num_structure_points / 2: 134 | out_channels = int(in_channels * 2) 135 | conv1d_stpts_prob_modules.append( 136 | Conv1dProbLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, out=False)) 137 | in_channels = out_channels 138 | conv1d_stpts_prob_modules.append( 139 | Conv1dProbLayer(in_channels=in_channels, out_channels=num_structure_points, kernel_size=1, out=out, 140 | normalize=normalize)) 141 | 142 | self.conv1d_stpts_prob = nn.Sequential(*conv1d_stpts_prob_modules) 143 | 144 | def forward(self, features): 145 | return self.conv1d_stpts_prob(features) 146 | 147 | 148 | 149 | 150 | class FeatureMergeBlock(nn.Module): 151 | def __init__(self, num_structure_points, boom_rate=2): 152 | super(FeatureMergeBlock, self).__init__() 153 | self.relu = nn.ReLU(inplace=True) 154 | self.merge_layer = nn.Sequential( 155 | nn.Conv2d(num_structure_points, boom_rate * num_structure_points, kernel_size=1), 156 | nn.BatchNorm2d(boom_rate * num_structure_points), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(boom_rate * num_structure_points, num_structure_points, kernel_size=1), 159 | nn.BatchNorm2d(num_structure_points), 160 | ) 161 | 162 | def forward(self, features): 163 | return self.relu(self.merge_layer(features) + features) 164 | # return self.merge_layer(features) + features 165 | 166 | 167 | class Normalize(nn.Module): 168 | def __init__(self, dim): 169 | super().__init__() 170 | self.dim = dim 171 | 172 | def forward(self, x): 173 | norm = torch.norm(x, p=2, dim=self.dim, keepdim=True) 174 | return x / norm 175 | 176 | 177 | class Pointnet2StructurePointNet(nn.Module): 178 | 179 | def __init__(self, num_structure_points, input_channels=3, multi_distribution_num=1, offset=False, 180 | merge_block_num=1): 181 | super(Pointnet2StructurePointNet, self).__init__() 182 | self.point_dim = 3 183 | self.num_structure_points = num_structure_points 184 | self.offset = offset 185 | self.input_channels = input_channels 186 | self.num_structure_points = num_structure_points 187 | 188 | self.sa1 = PointNetSetAbstractionMsg(npoint=512, radius_list=[0.1, 0.2, 0.4], nsample_list=[16, 32, 128], 189 | in_channel=0, mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 190 | self.sa2 = 
PointNetSetAbstractionMsg(npoint=128, radius_list=[0.2, 0.4, 0.8], nsample_list=[32, 64, 128], 191 | in_channel=64 + 128 + 128, 192 | mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]]) 193 | 194 | self.multi_distri_layers = nn.ModuleList() 195 | for i in range(multi_distribution_num): 196 | self.multi_distri_layers.append(RefineNet(num_structure_points, in_channel=128 + 256 + 256, out=False)) 197 | 198 | feature_merge = [] 199 | for _ in range(merge_block_num): 200 | feature_merge.append(FeatureMergeBlock(num_structure_points=num_structure_points)) 201 | self.feature_merge = nn.Sequential(*feature_merge) 202 | self.softmax = nn.Softmax(dim=2) 203 | 204 | def forward(self, pointcloud): 205 | ''' 206 | :param pointcloud: input point cloud with shape (bn, num_of_pts, 3) 207 | :param return_weighted_feature: whether return features for the structure points or not 208 | :return: 209 | ''' 210 | _ = None 211 | B = pointcloud.shape[0] 212 | if pointcloud.shape[2] == 3: 213 | pointcloud = pointcloud.permute(0, 2, 1) 214 | 215 | if pointcloud.shape[1] > 3: 216 | xyz = pointcloud[:, :3, :] 217 | features = pointcloud[:, 3:, :] 218 | else: 219 | xyz = pointcloud 220 | features = None 221 | xyz, features = self.sa1(xyz, features) 222 | xyz, features = self.sa2(xyz, features) 223 | 224 | stpts_prob_map = [] 225 | for i in range(len(self.multi_distri_layers)): 226 | stpts_prob_map.append(self.multi_distri_layers[i](features)) 227 | 228 | stpts_prob_map = torch.stack(stpts_prob_map, dim=3) 229 | stpts_prob_map = torch.max(self.feature_merge(stpts_prob_map), dim=3)[0] 230 | stpts_prob_map = self.softmax(stpts_prob_map) 231 | 232 | # (4,16,128) *(4,3,128) 233 | if not xyz.shape[2] == 3: 234 | xyz = xyz.permute(0, 2, 1) 235 | 236 | structure_points = torch.sum(stpts_prob_map[:, :, :, None] * xyz[:, None, :, :], dim=2) 237 | # features = features * torch.sum(stpts_prob_map, dim=1, keepdim=True) 238 | # weighted_features = torch.sum(stpts_prob_map[:, None, :, :] * features[:, :, None, :], dim=3) 239 | # cos_similarity = SimilarityPoints(weighted_features) 240 | cos_similarity = None 241 | return structure_points, xyz, cos_similarity, stpts_prob_map 242 | 243 | 244 | if __name__ == '__main__': 245 | data = torch.rand(2, 3, 1024) 246 | print("===> testing pointSPN ...") 247 | model = Pointnet2StructurePointNet(num_structure_points=16, offset=False) 248 | data = data.cuda() 249 | model = model.cuda() 250 | structure_points, xyz, cos_similarity, stpts_prob_map = model(data) 251 | loss = ComputeLoss3d() 252 | loss = loss.cuda() 253 | print(loss(xyz, structure_points)) 254 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | class STN3d(nn.Module): 11 | def __init__(self, channel): 12 | super(STN3d, self).__init__() 13 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 14 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 15 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 16 | self.fc1 = nn.Linear(1024, 512) 17 | self.fc2 = nn.Linear(512, 256) 18 | self.fc3 = nn.Linear(256, 9) 19 | self.relu = nn.ReLU() 20 | 21 | self.bn1 = nn.BatchNorm1d(64) 22 | self.bn2 = nn.BatchNorm1d(128) 23 | self.bn3 = nn.BatchNorm1d(1024) 24 | self.bn4 = nn.BatchNorm1d(512) 25 | self.bn5 = 
nn.BatchNorm1d(256) 26 | 27 | def forward(self, x): 28 | batchsize = x.size()[0] 29 | x = F.relu(self.bn1(self.conv1(x))) 30 | x = F.relu(self.bn2(self.conv2(x))) 31 | x = F.relu(self.bn3(self.conv3(x))) 32 | x = torch.max(x, 2, keepdim=True)[0] 33 | x = x.view(-1, 1024) 34 | 35 | x = F.relu(self.bn4(self.fc1(x))) 36 | x = F.relu(self.bn5(self.fc2(x))) 37 | x = self.fc3(x) 38 | 39 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 40 | batchsize, 1) 41 | if x.is_cuda: 42 | iden = iden.cuda() 43 | x = x + iden 44 | x = x.view(-1, 3, 3) 45 | return x 46 | 47 | 48 | class STNkd(nn.Module): 49 | def __init__(self, k=64): 50 | super(STNkd, self).__init__() 51 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 52 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 53 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 54 | self.fc1 = nn.Linear(1024, 512) 55 | self.fc2 = nn.Linear(512, 256) 56 | self.fc3 = nn.Linear(256, k * k) 57 | self.relu = nn.ReLU() 58 | 59 | self.bn1 = nn.BatchNorm1d(64) 60 | self.bn2 = nn.BatchNorm1d(128) 61 | self.bn3 = nn.BatchNorm1d(1024) 62 | self.bn4 = nn.BatchNorm1d(512) 63 | self.bn5 = nn.BatchNorm1d(256) 64 | 65 | self.k = k 66 | 67 | def forward(self, x): 68 | batchsize = x.size()[0] 69 | x = F.relu(self.bn1(self.conv1(x))) 70 | x = F.relu(self.bn2(self.conv2(x))) 71 | x = F.relu(self.bn3(self.conv3(x))) 72 | x = torch.max(x, 2, keepdim=True)[0] 73 | x = x.view(-1, 1024) 74 | 75 | x = F.relu(self.bn4(self.fc1(x))) 76 | x = F.relu(self.bn5(self.fc2(x))) 77 | x = self.fc3(x) 78 | 79 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat( 80 | batchsize, 1) 81 | if x.is_cuda: 82 | iden = iden.cuda() 83 | x = x + iden 84 | x = x.view(-1, self.k, self.k) 85 | return x 86 | 87 | 88 | class PointNetEncoder(nn.Module): 89 | def __init__(self, global_feat=True, feature_transform=False, channel=3): 90 | super(PointNetEncoder, self).__init__() 91 | self.stn = STN3d(channel) 92 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 93 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 94 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 95 | self.bn1 = nn.BatchNorm1d(64) 96 | self.bn2 = nn.BatchNorm1d(128) 97 | self.bn3 = nn.BatchNorm1d(1024) 98 | self.global_feat = global_feat 99 | self.feature_transform = feature_transform 100 | if self.feature_transform: 101 | self.fstn = STNkd(k=64) 102 | 103 | def forward(self, x): 104 | B, D, N = x.size() 105 | trans = self.stn(x) 106 | x = x.transpose(2, 1) 107 | if D >3 : 108 | x, feature = x.split(3,dim=2) 109 | x = torch.bmm(x, trans) 110 | if D > 3: 111 | x = torch.cat([x,feature],dim=2) 112 | x = x.transpose(2, 1) 113 | x = F.relu(self.bn1(self.conv1(x))) 114 | 115 | if self.feature_transform: 116 | trans_feat = self.fstn(x) 117 | x = x.transpose(2, 1) 118 | x = torch.bmm(x, trans_feat) 119 | x = x.transpose(2, 1) 120 | else: 121 | trans_feat = None 122 | 123 | pointfeat = x 124 | x = F.relu(self.bn2(self.conv2(x))) 125 | x = self.bn3(self.conv3(x)) 126 | x = torch.max(x, 2, keepdim=True)[0] 127 | x = x.view(-1, 1024) 128 | if self.global_feat: 129 | return x, trans, trans_feat 130 | else: 131 | x = x.view(-1, 1024, 1).repeat(1, 1, N) 132 | return torch.cat([x, pointfeat], 1), trans, trans_feat 133 | 134 | 135 | def feature_transform_reguliarzer(trans): 136 | d = trans.size()[1] 137 | I = torch.eye(d)[None, :, :] 138 | if trans.is_cuda: 139 | I = I.cuda() 140 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2))) 141 | return 
loss 142 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_cls_msg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetSetAbstraction 4 | 5 | 6 | class get_model(nn.Module): 7 | def __init__(self,num_class,normal_channel=True): 8 | super(get_model, self).__init__() 9 | in_channel = 3 if normal_channel else 0 10 | self.normal_channel = normal_channel 11 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 12 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]]) 13 | self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True) 14 | self.fc1 = nn.Linear(1024, 512) 15 | self.bn1 = nn.BatchNorm1d(512) 16 | self.drop1 = nn.Dropout(0.4) 17 | self.fc2 = nn.Linear(512, 256) 18 | self.bn2 = nn.BatchNorm1d(256) 19 | self.drop2 = nn.Dropout(0.5) 20 | self.fc3 = nn.Linear(256, num_class) 21 | 22 | def forward(self, xyz): 23 | B, _, _ = xyz.shape 24 | if self.normal_channel: 25 | norm = xyz[:, 3:, :] 26 | xyz = xyz[:, :3, :] 27 | else: 28 | norm = None 29 | l1_xyz, l1_points = self.sa1(xyz, norm) 30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 32 | x = l3_points.view(B, 1024) 33 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 34 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 35 | x = self.fc3(x) 36 | x = F.log_softmax(x, -1) 37 | 38 | 39 | return x,l3_points 40 | 41 | 42 | class get_loss(nn.Module): 43 | def __init__(self): 44 | super(get_loss, self).__init__() 45 | 46 | def forward(self, pred, target, trans_feat): 47 | total_loss = F.nll_loss(pred, target) 48 | 49 | return total_loss 50 | 51 | 52 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_cls_ssg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .pointnet_util import PointNetSetAbstraction 4 | 5 | 6 | class get_model(nn.Module): 7 | def __init__(self,num_class,normal_channel=True): 8 | super(get_model, self).__init__() 9 | in_channel = 6 if normal_channel else 3 10 | self.normal_channel = normal_channel 11 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False) 12 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 13 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 14 | self.fc1 = nn.Linear(1024, 512) 15 | self.bn1 = nn.BatchNorm1d(512) 16 | self.drop1 = nn.Dropout(0.4) 17 | self.fc2 = nn.Linear(512, 256) 18 | self.bn2 = nn.BatchNorm1d(256) 19 | self.drop2 = nn.Dropout(0.4) 20 | self.fc3 = nn.Linear(256, num_class) 21 | 22 | def forward(self, xyz): 23 | B, _, _ = xyz.shape 24 | if self.normal_channel: 25 | norm = xyz[:, 3:, :] 26 | xyz = xyz[:, :3, :] 27 | else: 28 | norm = None 29 | l1_xyz, l1_points = self.sa1(xyz, norm) 30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 32 | x = l3_points.view(B, 1024) 33 | x = 
self.drop1(F.relu(self.bn1(self.fc1(x)))) 34 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 35 | x = self.fc3(x) 36 | x = F.log_softmax(x, -1) 37 | 38 | 39 | return x, l3_points 40 | 41 | 42 | 43 | class get_loss(nn.Module): 44 | def __init__(self): 45 | super(get_loss, self).__init__() 46 | 47 | def forward(self, pred, target, trans_feat): 48 | total_loss = F.nll_loss(pred, target) 49 | 50 | return total_loss 51 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_part_seg_msg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .pointnet_util import PointNetSetAbstractionMsg,PointNetSetAbstraction,PointNetFeaturePropagation 5 | 6 | 7 | class get_model(nn.Module): 8 | def __init__(self, num_classes, normal_channel=False): 9 | super(get_model, self).__init__() 10 | if normal_channel: 11 | additional_channel = 3 12 | else: 13 | additional_channel = 0 14 | self.normal_channel = normal_channel 15 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 16 | self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]]) 17 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True) 18 | self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256]) 19 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128]) 20 | self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128]) 21 | self.conv1 = nn.Conv1d(128, 128, 1) 22 | self.bn1 = nn.BatchNorm1d(128) 23 | self.drop1 = nn.Dropout(0.5) 24 | self.conv2 = nn.Conv1d(128, num_classes, 1) 25 | 26 | def forward(self, xyz, cls_label): 27 | # Set Abstraction layers 28 | B,C,N = xyz.shape 29 | if self.normal_channel: 30 | l0_points = xyz 31 | l0_xyz = xyz[:,:3,:] 32 | else: 33 | l0_points = xyz 34 | l0_xyz = xyz 35 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 37 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 38 | # Feature Propagation layers 39 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 40 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 41 | cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N) 42 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points) 43 | # FC layers 44 | feat = F.relu(self.bn1(self.conv1(l0_points))) 45 | x = self.drop1(feat) 46 | x = self.conv2(x) 47 | x = F.log_softmax(x, dim=1) 48 | x = x.permute(0, 2, 1) 49 | return x, l3_points 50 | 51 | 52 | class get_loss(nn.Module): 53 | def __init__(self): 54 | super(get_loss, self).__init__() 55 | 56 | def forward(self, pred, target, trans_feat): 57 | total_loss = F.nll_loss(pred, target) 58 | 59 | return total_loss -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_part_seg_ssg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .pointnet_util import PointNetSetAbstraction,PointNetFeaturePropagation 5 | 6 | 7 | class get_model(nn.Module): 8 | def __init__(self, num_classes, normal_channel=False): 9 | super(get_model, self).__init__() 10 | if 
normal_channel: 11 | additional_channel = 3 12 | else: 13 | additional_channel = 0 14 | self.normal_channel = normal_channel 15 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=6+additional_channel, mlp=[64, 64, 128], group_all=False) 16 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 17 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 18 | self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256]) 19 | self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128]) 20 | self.fp1 = PointNetFeaturePropagation(in_channel=128+16+6+additional_channel, mlp=[128, 128, 128]) 21 | self.conv1 = nn.Conv1d(128, 128, 1) 22 | self.bn1 = nn.BatchNorm1d(128) 23 | self.drop1 = nn.Dropout(0.5) 24 | self.conv2 = nn.Conv1d(128, num_classes, 1) 25 | 26 | def forward(self, xyz, cls_label): 27 | # Set Abstraction layers 28 | B,C,N = xyz.shape 29 | if self.normal_channel: 30 | l0_points = xyz 31 | l0_xyz = xyz[:,:3,:] 32 | else: 33 | l0_points = xyz 34 | l0_xyz = xyz 35 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 37 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 38 | # Feature Propagation layers 39 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 40 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 41 | cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N) 42 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points) 43 | # FC layers 44 | feat = F.relu(self.bn1(self.conv1(l0_points))) 45 | x = self.drop1(feat) 46 | x = self.conv2(x) 47 | x = F.log_softmax(x, dim=1) 48 | x = x.permute(0, 2, 1) 49 | return x, l3_points 50 | 51 | 52 | class get_loss(nn.Module): 53 | def __init__(self): 54 | super(get_loss, self).__init__() 55 | 56 | def forward(self, pred, target, trans_feat): 57 | total_loss = F.nll_loss(pred, target) 58 | 59 | return total_loss -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_sem_seg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .pointnet_util import PointNetSetAbstraction,PointNetFeaturePropagation 4 | 5 | 6 | class get_model(nn.Module): 7 | def __init__(self, num_classes): 8 | super(get_model, self).__init__() 9 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False) 10 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 11 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 12 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 13 | self.fp4 = PointNetFeaturePropagation(768, [256, 256]) 14 | self.fp3 = PointNetFeaturePropagation(384, [256, 256]) 15 | self.fp2 = PointNetFeaturePropagation(320, [256, 128]) 16 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 17 | self.conv1 = nn.Conv1d(128, 128, 1) 18 | self.bn1 = nn.BatchNorm1d(128) 19 | self.drop1 = nn.Dropout(0.5) 20 | self.conv2 = nn.Conv1d(128, num_classes, 1) 21 | 22 | def forward(self, xyz): 23 | l0_points = xyz 24 | l0_xyz = xyz[:,:3,:] 25 | 26 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 27 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 28 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 29 | 
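# Encoder-decoder structure: the four set-abstraction (SA) layers downsample
# the cloud (1024 -> 256 -> 64 -> 16 centroids) while widening the feature
# channels, and the four feature-propagation (FP) layers below interpolate
# features back up to full resolution, U-Net style.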
l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 30 | 31 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 32 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 33 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 34 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points) 35 | 36 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 37 | x = self.conv2(x) 38 | x = F.log_softmax(x, dim=1) 39 | x = x.permute(0, 2, 1) 40 | return x, l4_points 41 | 42 | 43 | class get_loss(nn.Module): 44 | def __init__(self): 45 | super(get_loss, self).__init__() 46 | def forward(self, pred, target, trans_feat, weight): 47 | total_loss = F.nll_loss(pred, target, weight=weight) 48 | 49 | return total_loss 50 | 51 | if __name__ == '__main__': 52 | import torch 53 | model = get_model(13) 54 | xyz = torch.rand(6, 9, 2048) 55 | (model(xyz)) -------------------------------------------------------------------------------- /models/pointnetpp/pointnet2_sem_seg_msg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .pointnet_util import PointNetSetAbstractionMsg, PointNetFeaturePropagation 4 | 5 | 6 | class get_model(nn.Module): 7 | def __init__(self, num_classes): 8 | super(get_model, self).__init__() 9 | 10 | self.sa1 = PointNetSetAbstractionMsg(1024, [0.05, 0.1], [16, 32], 9, [[16, 16, 32], [32, 32, 64]]) 11 | self.sa2 = PointNetSetAbstractionMsg(256, [0.1, 0.2], [16, 32], 32 + 64, [[64, 64, 128], [64, 96, 128]]) 12 | self.sa3 = PointNetSetAbstractionMsg(64, [0.2, 0.4], [16, 32], 128 + 128, [[128, 196, 256], [128, 196, 256]]) 13 | self.sa4 = PointNetSetAbstractionMsg(16, [0.4, 0.8], [16, 32], 256 + 256, [[256, 256, 512], [256, 384, 512]]) 14 | self.fp4 = PointNetFeaturePropagation(512 + 512 + 256 + 256, [256, 256]) 15 | self.fp3 = PointNetFeaturePropagation(128 + 128 + 256, [256, 256]) 16 | self.fp2 = PointNetFeaturePropagation(32 + 64 + 256, [256, 128]) 17 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 18 | self.conv1 = nn.Conv1d(128, 128, 1) 19 | self.bn1 = nn.BatchNorm1d(128) 20 | self.drop1 = nn.Dropout(0.5) 21 | self.conv2 = nn.Conv1d(128, num_classes, 1) 22 | 23 | def forward(self, xyz): 24 | l0_points = xyz 25 | l0_xyz = xyz[:, :3, :] 26 | 27 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) 28 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 29 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 30 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 31 | 32 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 33 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 34 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 35 | l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points) 36 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 37 | x = self.conv2(x) 38 | x = x.float() 39 | x = F.log_softmax(x, dim=1) 40 | x = x.permute(0, 2, 1) 41 | 42 | return x, l4_points 43 | 44 | 45 | class get_loss(nn.Module): 46 | def __init__(self): 47 | super(get_loss, self).__init__() 48 | 49 | def forward(self, pred, target, trans_feat, weight): 50 | total_loss = F.nll_loss(pred, target, weight=weight) 51 | 52 | return total_loss 53 | 54 | 55 | if __name__ == '__main__': 56 | import torch 57 | 58 | model = get_model(13) 59 | xyz = torch.rand(6, 9, 2048) 60 | print(model(xyz)) 61 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet_cls.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.data 3 | import torch.nn.functional as F 4 | from .pointnet import PointNetEncoder, feature_transform_reguliarzer 5 | 6 | class get_model(nn.Module): 7 | def __init__(self, k=40, normal_channel=True): 8 | super(get_model, self).__init__() 9 | if normal_channel: 10 | channel = 6 11 | else: 12 | channel = 3 13 | self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel) 14 | self.fc1 = nn.Linear(1024, 512) 15 | self.fc2 = nn.Linear(512, 256) 16 | self.fc3 = nn.Linear(256, k) 17 | self.dropout = nn.Dropout(p=0.4) 18 | self.bn1 = nn.BatchNorm1d(512) 19 | self.bn2 = nn.BatchNorm1d(256) 20 | self.relu = nn.ReLU() 21 | 22 | def forward(self, x): 23 | x, trans, trans_feat = self.feat(x) 24 | x = F.relu(self.bn1(self.fc1(x))) 25 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 26 | x = self.fc3(x) 27 | x = F.log_softmax(x, dim=1) 28 | return x, trans_feat 29 | 30 | class get_loss(torch.nn.Module): 31 | def __init__(self, mat_diff_loss_scale=0.001): 32 | super(get_loss, self).__init__() 33 | self.mat_diff_loss_scale = mat_diff_loss_scale 34 | 35 | def forward(self, pred, target, trans_feat): 36 | loss = F.nll_loss(pred, target) 37 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 38 | 39 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale 40 | return total_loss 41 | -------------------------------------------------------------------------------- /models/pointnetpp/pointnet_part_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | from .pointnet import STN3d, STNkd, feature_transform_reguliarzer 7 | 8 | 9 | class get_model(nn.Module): 10 | def __init__(self, part_num=50, normal_channel=True): 11 | super(get_model, self).__init__() 12 | if normal_channel: 13 | channel = 6 14 | else: 15 | channel = 3 16 | self.part_num = part_num 17 | self.stn = STN3d(channel) 18 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 19 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 20 | self.conv3 = torch.nn.Conv1d(128, 128, 1) 21 | self.conv4 = torch.nn.Conv1d(128, 512, 1) 22 | self.conv5 = torch.nn.Conv1d(512, 2048, 1) 23 | self.bn1 = nn.BatchNorm1d(64) 24 | self.bn2 = nn.BatchNorm1d(128) 25 | self.bn3 = nn.BatchNorm1d(128) 26 | self.bn4 = nn.BatchNorm1d(512) 27 | self.bn5 = nn.BatchNorm1d(2048) 28 | self.fstn = STNkd(k=128) 29 | self.convs1 = torch.nn.Conv1d(4944, 256, 1) 30 | self.convs2 = torch.nn.Conv1d(256, 256, 1) 31 | self.convs3 = torch.nn.Conv1d(256, 128, 1) 32 | self.convs4 = torch.nn.Conv1d(128, part_num, 1) 33 | self.bns1 = nn.BatchNorm1d(256) 34 | self.bns2 = nn.BatchNorm1d(256) 35 | self.bns3 = nn.BatchNorm1d(128) 36 | 37 | def forward(self, point_cloud, label): 38 | B, D, N = point_cloud.size() 39 | trans = self.stn(point_cloud) 40 | point_cloud = point_cloud.transpose(2, 1) 41 | if D > 3: 42 | point_cloud, feature = point_cloud.split(3, dim=2) 43 | point_cloud = torch.bmm(point_cloud, trans) 44 | if D > 3: 45 | point_cloud = torch.cat([point_cloud, feature], dim=2) 46 | 47 | point_cloud = point_cloud.transpose(2, 1) 48 | 49 | out1 = F.relu(self.bn1(self.conv1(point_cloud))) 50 | out2 = F.relu(self.bn2(self.conv2(out1))) 51 | out3 = F.relu(self.bn3(self.conv3(out2))) 52 | 53 | trans_feat = self.fstn(out3) 54 | x = out3.transpose(2, 1) 55 | net_transformed = 
torch.bmm(x, trans_feat) 56 | net_transformed = net_transformed.transpose(2, 1) 57 | 58 | out4 = F.relu(self.bn4(self.conv4(net_transformed))) 59 | out5 = self.bn5(self.conv5(out4)) 60 | out_max = torch.max(out5, 2, keepdim=True)[0] 61 | out_max = out_max.view(-1, 2048) 62 | 63 | out_max = torch.cat([out_max,label.squeeze(1)],1) 64 | expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N) 65 | concat = torch.cat([expand, out1, out2, out3, out4, out5], 1) 66 | net = F.relu(self.bns1(self.convs1(concat))) 67 | net = F.relu(self.bns2(self.convs2(net))) 68 | net = F.relu(self.bns3(self.convs3(net))) 69 | net = self.convs4(net) 70 | net = net.transpose(2, 1).contiguous() 71 | net = F.log_softmax(net.view(-1, self.part_num), dim=-1) 72 | net = net.view(B, N, self.part_num) # [B, N, 50] 73 | 74 | return net, trans_feat 75 | 76 | 77 | class get_loss(torch.nn.Module): 78 | def __init__(self, mat_diff_loss_scale=0.001): 79 | super(get_loss, self).__init__() 80 | self.mat_diff_loss_scale = mat_diff_loss_scale 81 | 82 | def forward(self, pred, target, trans_feat): 83 | loss = F.nll_loss(pred, target) 84 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 85 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale 86 | return total_loss -------------------------------------------------------------------------------- /models/pointnetpp/pointnet_sem_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import torch.nn.functional as F 6 | from .pointnet import PointNetEncoder, feature_transform_reguliarzer 7 | 8 | 9 | class get_model(nn.Module): 10 | def __init__(self, num_class, with_rgb=True): 11 | super(get_model, self).__init__() 12 | if with_rgb: 13 | channel = 6 14 | else: 15 | channel = 3 16 | self.k = num_class 17 | self.feat = PointNetEncoder(global_feat=False, feature_transform=True, channel=channel) 18 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 19 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 20 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 21 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 22 | self.bn1 = nn.BatchNorm1d(512) 23 | self.bn2 = nn.BatchNorm1d(256) 24 | self.bn3 = nn.BatchNorm1d(128) 25 | 26 | def forward(self, x): 27 | batchsize = x.size()[0] 28 | n_pts = x.size()[2] 29 | x, trans, trans_feat = self.feat(x) 30 | x = F.relu(self.bn1(self.conv1(x))) 31 | x = F.relu(self.bn2(self.conv2(x))) 32 | x = F.relu(self.bn3(self.conv3(x))) 33 | x = self.conv4(x) 34 | x = x.transpose(2,1).contiguous() 35 | x = F.log_softmax(x.view(-1,self.k), dim=-1) 36 | x = x.view(batchsize, n_pts, self.k) 37 | return x, trans_feat 38 | 39 | class get_loss(torch.nn.Module): 40 | def __init__(self, mat_diff_loss_scale=0.001): 41 | super(get_loss, self).__init__() 42 | self.mat_diff_loss_scale = mat_diff_loss_scale 43 | 44 | def forward(self, pred, target, trans_feat, weight): 45 | loss = F.nll_loss(pred, target, weight = weight) 46 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 47 | total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale 48 | return total_loss 49 | 50 | 51 | if __name__ == '__main__': 52 | model = get_model(13, with_rgb=False) 53 | xyz = torch.rand(12, 3, 2048) 54 | (model(xyz)) -------------------------------------------------------------------------------- /models/pointnetpp/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | def square_distance(src, dst): 20 | """ 21 | Calculate Euclid distance between each two points. 22 | 23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | 29 | Input: 30 | src: source points, [B, N, C] 31 | dst: target points, [B, M, C] 32 | Output: 33 | dist: per-point square distance, [B, N, M] 34 | """ 35 | B, N, _ = src.shape 36 | _, M, _ = dst.shape 37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 38 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 40 | return dist 41 | 42 | 43 | def index_points(points, idx): 44 | """ 45 | 46 | Input: 47 | points: input points data, [B, N, C] 48 | idx: sample index data, [B, S] 49 | Return: 50 | new_points:, indexed points data, [B, S, C] 51 | """ 52 | device = points.device 53 | B = points.shape[0] 54 | view_shape = list(idx.shape) 55 | view_shape[1:] = [1] * (len(view_shape) - 1) 56 | repeat_shape = list(idx.shape) 57 | repeat_shape[0] = 1 58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, 3] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 | farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max sample number in local region 92 | xyz: all points, [B, N, 3] 93 | new_xyz: query points, [B, S, 3] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: 114 | radius: 115 | nsample: 116 | xyz: input points position data, [B, N, 3] 117 | 
points: input points data, [B, N, D] 118 | Return: 119 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 120 | new_points: sampled points data, [B, npoint, nsample, 3+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 125 | torch.cuda.empty_cache() 126 | new_xyz = index_points(xyz, fps_idx) 127 | torch.cuda.empty_cache() 128 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 129 | torch.cuda.empty_cache() 130 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 131 | torch.cuda.empty_cache() 132 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 133 | torch.cuda.empty_cache() 134 | 135 | if points is not None: 136 | grouped_points = index_points(points, idx) 137 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 138 | else: 139 | new_points = grouped_xyz_norm 140 | if returnfps: 141 | return new_xyz, new_points, grouped_xyz, fps_idx 142 | else: 143 | return new_xyz, new_points 144 | 145 | 146 | def sample_and_group_all(xyz, points): 147 | """ 148 | Input: 149 | xyz: input points position data, [B, N, 3] 150 | points: input points data, [B, N, D] 151 | Return: 152 | new_xyz: sampled points position data, [B, 1, 3] 153 | new_points: sampled points data, [B, 1, N, 3+D] 154 | """ 155 | device = xyz.device 156 | B, N, C = xyz.shape 157 | new_xyz = torch.zeros(B, 1, C).to(device) 158 | grouped_xyz = xyz.view(B, 1, N, C) 159 | if points is not None: 160 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 161 | else: 162 | new_points = grouped_xyz 163 | return new_xyz, new_points 164 | 165 | 166 | class PointNetSetAbstraction(nn.Module): 167 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 168 | super(PointNetSetAbstraction, self).__init__() 169 | self.npoint = npoint 170 | self.radius = radius 171 | self.nsample = nsample 172 | self.mlp_convs = nn.ModuleList() 173 | self.mlp_bns = nn.ModuleList() 174 | last_channel = in_channel 175 | for out_channel in mlp: 176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 178 | last_channel = out_channel 179 | self.group_all = group_all 180 | 181 | def forward(self, xyz, points): 182 | """ 183 | Input: 184 | xyz: input points position data, [B, C, N] 185 | points: input points data, [B, D, N] 186 | Return: 187 | new_xyz: sampled points position data, [B, C, S] 188 | new_points_concat: sample points feature data, [B, D', S] 189 | """ 190 | xyz = xyz.permute(0, 2, 1) 191 | if points is not None: 192 | points = points.permute(0, 2, 1) 193 | 194 | if self.group_all: 195 | new_xyz, new_points = sample_and_group_all(xyz, points) 196 | else: 197 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 198 | # new_xyz: sampled points position data, [B, npoint, C] 199 | # new_points: sampled points data, [B, npoint, nsample, C+D] 200 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 201 | for i, conv in enumerate(self.mlp_convs): 202 | bn = self.mlp_bns[i] 203 | new_points = F.relu(bn(conv(new_points))) 204 | 205 | new_points = torch.max(new_points, 2)[0] 206 | new_xyz = new_xyz.permute(0, 2, 1) 207 | return new_xyz, new_points 208 | 209 | 210 | class PointNetSetAbstractionMsg(nn.Module): 211 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 212 | super(PointNetSetAbstractionMsg, self).__init__() 213 | 
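# Multi-scale grouping (MSG): the forward pass below ball-queries
# nsample_list[i] neighbors within radius_list[i] around the same npoint FPS
# centroids, runs each scale through its own shared MLP, max-pools over the
# neighborhood, and concatenates the per-scale features along the channel
# dimension.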
self.npoint = npoint 214 | self.radius_list = radius_list 215 | self.nsample_list = nsample_list 216 | self.conv_blocks = nn.ModuleList() 217 | self.bn_blocks = nn.ModuleList() 218 | for i in range(len(mlp_list)): 219 | convs = nn.ModuleList() 220 | bns = nn.ModuleList() 221 | last_channel = in_channel + 3 222 | for out_channel in mlp_list[i]: 223 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 224 | bns.append(nn.BatchNorm2d(out_channel)) 225 | last_channel = out_channel 226 | self.conv_blocks.append(convs) 227 | self.bn_blocks.append(bns) 228 | 229 | def forward(self, xyz, points): 230 | """ 231 | Input: 232 | xyz: input points position data, [B, C, N] 233 | points: input points data, [B, D, N] 234 | Return: 235 | new_xyz: sampled points position data, [B, C, S] 236 | new_points_concat: sample points feature data, [B, D', S] 237 | """ 238 | xyz = xyz.permute(0, 2, 1) 239 | if points is not None: 240 | points = points.permute(0, 2, 1) 241 | 242 | B, N, C = xyz.shape 243 | S = self.npoint 244 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 245 | new_points_list = [] 246 | for i, radius in enumerate(self.radius_list): 247 | K = self.nsample_list[i] 248 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 249 | grouped_xyz = index_points(xyz, group_idx) 250 | grouped_xyz -= new_xyz.view(B, S, 1, C) 251 | if points is not None: 252 | grouped_points = index_points(points, group_idx) 253 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 254 | else: 255 | grouped_points = grouped_xyz 256 | 257 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 258 | for j in range(len(self.conv_blocks[i])): 259 | conv = self.conv_blocks[i][j] 260 | bn = self.bn_blocks[i][j] 261 | grouped_points = F.relu(bn(conv(grouped_points))) 262 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 263 | new_points_list.append(new_points) 264 | 265 | new_xyz = new_xyz.permute(0, 2, 1) 266 | new_points_concat = torch.cat(new_points_list, dim=1) 267 | return new_xyz, new_points_concat 268 | 269 | 270 | class PointNetFeaturePropagation(nn.Module): 271 | def __init__(self, in_channel, mlp): 272 | super(PointNetFeaturePropagation, self).__init__() 273 | self.mlp_convs = nn.ModuleList() 274 | self.mlp_bns = nn.ModuleList() 275 | last_channel = in_channel 276 | for out_channel in mlp: 277 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 278 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 279 | last_channel = out_channel 280 | 281 | def forward(self, xyz1, xyz2, points1, points2): 282 | """ 283 | Input: 284 | xyz1: input points position data, [B, C, N] 285 | xyz2: sampled input points position data, [B, C, S] 286 | points1: input points data, [B, D, N] 287 | points2: sampled input points data, [B, D, S] 288 | Return: 289 | new_points: upsampled points data, [B, D', N] 290 | """ 291 | xyz1 = xyz1.permute(0, 2, 1) 292 | xyz2 = xyz2.permute(0, 2, 1) 293 | 294 | points2 = points2.permute(0, 2, 1) 295 | B, N, C = xyz1.shape 296 | _, S, _ = xyz2.shape 297 | 298 | if S == 1: 299 | interpolated_points = points2.repeat(1, N, 1) 300 | else: 301 | dists = square_distance(xyz1, xyz2) 302 | dists, idx = dists.sort(dim=-1) 303 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 304 | 305 | dist_recip = 1.0 / (dists + 1e-8) 306 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 307 | weight = dist_recip / norm 308 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 309 | 310 | if points1 is not None: 311 | points1 
= points1.permute(0, 2, 1) 312 | new_points = torch.cat([points1, interpolated_points], dim=-1) 313 | else: 314 | new_points = interpolated_points 315 | 316 | new_points = new_points.permute(0, 2, 1) 317 | for i, conv in enumerate(self.mlp_convs): 318 | bn = self.mlp_bns[i] 319 | new_points = F.relu(bn(conv(new_points))) 320 | return new_points 321 | 322 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft pointnet2_ops/_ext-src 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: pointnet2-ops 3 | Version: 3.0.0 4 | Summary: UNKNOWN 5 | Author: Erik Wijmans 6 | License: UNKNOWN 7 | Platform: UNKNOWN 8 | 9 | UNKNOWN 10 | 11 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | MANIFEST.in 2 | setup.py 3 | pointnet2_ops/__init__.py 4 | pointnet2_ops/_version.py 5 | pointnet2_ops/pointnet2_modules.py 6 | pointnet2_ops/pointnet2_utils.py 7 | pointnet2_ops.egg-info/PKG-INFO 8 | pointnet2_ops.egg-info/SOURCES.txt 9 | pointnet2_ops.egg-info/dependency_links.txt 10 | pointnet2_ops.egg-info/requires.txt 11 | pointnet2_ops.egg-info/top_level.txt 12 | pointnet2_ops/_ext-src/include/ball_query.h 13 | pointnet2_ops/_ext-src/include/cuda_utils.h 14 | pointnet2_ops/_ext-src/include/group_points.h 15 | pointnet2_ops/_ext-src/include/interpolate.h 16 | pointnet2_ops/_ext-src/include/sampling.h 17 | pointnet2_ops/_ext-src/include/utils.h 18 | pointnet2_ops/_ext-src/src/ball_query.cpp 19 | pointnet2_ops/_ext-src/src/ball_query_gpu.cu 20 | pointnet2_ops/_ext-src/src/bindings.cpp 21 | pointnet2_ops/_ext-src/src/group_points.cpp 22 | pointnet2_ops/_ext-src/src/group_points_gpu.cu 23 | pointnet2_ops/_ext-src/src/interpolate.cpp 24 | pointnet2_ops/_ext-src/src/interpolate_gpu.cu 25 | pointnet2_ops/_ext-src/src/sampling.cpp 26 | pointnet2_ops/_ext-src/src/sampling_gpu.cu -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch>=1.4 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | pointnet2_ops 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int 
tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 
{ 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; 
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 | 
4 | #include "cuda_utils.h"
5 | 
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 |                                     int nsample,
10 |                                     const float *__restrict__ points,
11 |                                     const int *__restrict__ idx,
12 |                                     float *__restrict__ out) {
13 |   int batch_index = blockIdx.x;
14 |   points += batch_index * n * c;
15 |   idx += batch_index * npoints * nsample;
16 |   out += batch_index * npoints * nsample * c;
17 | 
18 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 |   const int stride = blockDim.y * blockDim.x;
20 |   for (int i = index; i < c * npoints; i += stride) {
21 |     const int l = i / npoints;
22 |     const int j = i % npoints;
23 |     for (int k = 0; k < nsample; ++k) {
24 |       int ii = idx[j * nsample + k];
25 |       out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 |     }
27 |   }
28 | }
29 | 
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 |                                  const float *points, const int *idx,
32 |                                  float *out) {
33 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 | 
35 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 |       b, c, n, npoints, nsample, points, idx, out);
37 | 
38 |   CUDA_CHECK_ERRORS();
39 | }
40 | 
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 |                                          int nsample,
45 |                                          const float *__restrict__ grad_out,
46 |                                          const int *__restrict__ idx,
47 |                                          float *__restrict__ grad_points) {
48 |   int batch_index = blockIdx.x;
49 |   grad_out += batch_index * npoints * nsample * c;
50 |   idx += batch_index * npoints * nsample;
51 |   grad_points += batch_index * n * c;
52 | 
53 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 |   const int stride = blockDim.y * blockDim.x;
55 |   for (int i = index; i < c * npoints; i += stride) {
56 |     const int l = i / npoints;
57 |     const int j = i % npoints;
58 |     for (int k = 0; k < nsample; ++k) {
59 |       int ii = idx[j * nsample + k];
60 |       atomicAdd(grad_points + l * n + ii,
61 |                 grad_out[(l * npoints + j) * nsample + k]);
62 |     }
63 |   }
64 | }
65 | 
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 |                                       int nsample, const float *grad_out,
68 |                                       const int *idx, float *grad_points) {
69 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 | 
71 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 |       b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 | 
74 |   CUDA_CHECK_ERRORS();
75 | }
76 | 
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 | 
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 |                              const float *known, float *dist2, int *idx);
6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7 |                                       const float *points, const int *idx,
8 |                                       const float *weight, float *out);
9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10 |                                            const float *grad_out,
11 |                                            const int *idx, const float *weight,
12 |                                            float *grad_points);
13 | 
14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15 |   CHECK_CONTIGUOUS(unknowns);
16 |   CHECK_CONTIGUOUS(knows);
17 |   CHECK_IS_FLOAT(unknowns);
18 |   CHECK_IS_FLOAT(knows);
19 | 
20 |   if (unknowns.is_cuda()) {
21 |     CHECK_CUDA(knows);
22 |   }
23 | 
24 |   at::Tensor idx =
25 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26 |                    at::device(unknowns.device()).dtype(at::ScalarType::Int));
27 |   at::Tensor dist2 =
28 |       torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29 |                    at::device(unknowns.device()).dtype(at::ScalarType::Float));
30 | 
31 |   if (unknowns.is_cuda()) {
32 |     three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33 |                             unknowns.data_ptr<float>(), knows.data_ptr<float>(),
34 |                             dist2.data_ptr<float>(), idx.data_ptr<int>());
35 |   } else {
36 |     AT_ASSERT(false, "CPU not supported");
37 |   }
38 | 
39 |   return {dist2, idx};
40 | }
41 | 
42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43 |                              at::Tensor weight) {
44 |   CHECK_CONTIGUOUS(points);
45 |   CHECK_CONTIGUOUS(idx);
46 |   CHECK_CONTIGUOUS(weight);
47 |   CHECK_IS_FLOAT(points);
48 |   CHECK_IS_INT(idx);
49 |   CHECK_IS_FLOAT(weight);
50 | 
51 |   if (points.is_cuda()) {
52 |     CHECK_CUDA(idx);
53 |     CHECK_CUDA(weight);
54 |   }
55 | 
56 |   at::Tensor output =
57 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
58 |                    at::device(points.device()).dtype(at::ScalarType::Float));
59 | 
60 |   if (points.is_cuda()) {
61 |     three_interpolate_kernel_wrapper(
62 |         points.size(0), points.size(1), points.size(2), idx.size(1),
63 |         points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
64 |         output.data_ptr<float>());
65 |   } else {
66 |     AT_ASSERT(false, "CPU not supported");
67 |   }
68 | 
69 |   return output;
70 | }
71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72 |                                   at::Tensor weight, const int m) {
73 |   CHECK_CONTIGUOUS(grad_out);
74 |   CHECK_CONTIGUOUS(idx);
75 |   CHECK_CONTIGUOUS(weight);
76 |   CHECK_IS_FLOAT(grad_out);
77 |   CHECK_IS_INT(idx);
78 |   CHECK_IS_FLOAT(weight);
79 | 
80 |   if (grad_out.is_cuda()) {
81 |     CHECK_CUDA(idx);
82 |     CHECK_CUDA(weight);
83 |   }
84 | 
85 |   at::Tensor output =
86 |       torch::zeros({grad_out.size(0), grad_out.size(1), m},
87 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
88 | 
89 |   if (grad_out.is_cuda()) {
90 |     three_interpolate_grad_kernel_wrapper(
91 |         grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
93 |         weight.data_ptr<float>(), output.data_ptr<float>());
94 |   } else {
95 |     AT_ASSERT(false, "CPU not supported");
96 |   }
97 | 
98 |   return output;
99 | }
100 | 
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 | 
5 | #include "cuda_utils.h"
6 | 
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 |                                 const float *__restrict__ unknown,
11 |                                 const float *__restrict__ known,
12 |                                 float *__restrict__ dist2,
13 |                                 int *__restrict__ idx) {
14 |   int batch_index = blockIdx.x;
15 |   unknown += batch_index * n * 3;
16 |   known += batch_index * m * 3;
17 |   dist2 += batch_index * n * 3;
18 |   idx += batch_index * n * 3;
19 | 
20 |   int index = threadIdx.x;
21 |   int stride = blockDim.x;
22 |   for (int j = index; j < n; j += stride) {
23 |     float ux = unknown[j * 3 + 0];
24 |     float uy = unknown[j * 3 + 1];
25 |     float uz = unknown[j * 3 + 2];
26 | 
27 |     double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 |     int besti1 = 0, besti2 = 0, besti3 = 0;
29 |     for (int k = 0; k < m; ++k) {
30 |       float x = known[k * 3 + 0];
31 |       float y = known[k * 3 + 1];
32 |       float z = known[k * 3 + 2];
33 |       float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 |       if (d < best1) {
35 |         best3 = best2;
36 |         besti3 = besti2;
37 |         best2 = best1;
38 |         besti2 = besti1;
39 |         best1 = d;
40 |         besti1 = k;
41 |       } else if (d < best2) {
42 |         best3 = best2;
43 |         besti3 = besti2;
44 |         best2 = d;
45 |         besti2 = k;
46 |       } else if (d < best3) {
47 |         best3 = d;
48 |         besti3 = k;
49 |       }
50 |     }
51 |     dist2[j * 3 + 0] = best1;
52 |     dist2[j * 3 + 1] = best2;
53 |     dist2[j * 3 + 2] = best3;
54 | 
55 |     idx[j * 3 + 0] = besti1;
56 |     idx[j * 3 + 1] = besti2;
57 |     idx[j * 3 + 2] = besti3;
58 |   }
59 | }
60 | 
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 |                              const float *known, float *dist2, int *idx) {
63 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 |                                                        dist2, idx);
66 | 
67 |   CUDA_CHECK_ERRORS();
68 | }
69 | 
70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71 | // output: out(b, c, n)
72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73 |                                          const float *__restrict__ points,
74 |                                          const int *__restrict__ idx,
75 |                                          const float *__restrict__ weight,
76 |                                          float *__restrict__ out) {
77 |   int batch_index = blockIdx.x;
78 |   points += batch_index * m * c;
79 | 
80 |   idx += batch_index * n * 3;
81 |   weight += batch_index * n * 3;
82 | 
83 |   out += batch_index * n * c;
84 | 
85 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
86 |   const int stride = blockDim.y * blockDim.x;
87 |   for (int i = index; i < c * n; i += stride) {
88 |     const int l = i / n;
89 |     const int j = i % n;
90 |     float w1 = weight[j * 3 + 0];
91 |     float w2 = weight[j * 3 + 1];
92 |     float w3 = weight[j * 3 + 2];
93 | 
94 |     int i1 = idx[j * 3 + 0];
95 |     int i2 = idx[j * 3 + 1];
96 |     int i3 = idx[j * 3 + 2];
97 | 
98 |     out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99 |              points[l * m + i3] * w3;
100 |   }
101 | }
102 | 
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104 |                                       const float *points, const int *idx,
105 |                                       const float *weight, float *out) {
106 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108 |       b, c, m, n, points, idx, weight, out);
109 | 
110 |   CUDA_CHECK_ERRORS();
111 | }
112 | 
113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114 | // output: grad_points(b, c, m)
115 | 
116 | __global__ void three_interpolate_grad_kernel(
117 |     int b, int c, int n, int m, const float *__restrict__ grad_out,
118 |     const int *__restrict__ idx, const float *__restrict__ weight,
119 |     float *__restrict__ grad_points) {
120 |   int batch_index = blockIdx.x;
121 |   grad_out += batch_index * n * c;
122 |   idx += batch_index * n * 3;
123 |   weight += batch_index * n * 3;
124 |   grad_points += batch_index * m * c;
125 | 
126 |   const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 |   const int stride = blockDim.y * blockDim.x;
128 |   for (int i = index; i < c * n; i += stride) {
129 |     const int l = i / n;
130 |     const int j = i % n;
131 |     float w1 = weight[j * 3 + 0];
132 |     float w2 = weight[j * 3 + 1];
133 |     float w3 = weight[j * 3 + 2];
134 | 
135 |     int i1 = idx[j * 3 + 0];
136 |     int i2 = idx[j * 3 + 1];
137 |     int i3 = idx[j * 3 + 2];
138 | 
139 |     atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140 |     atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141 |     atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142 |   }
143 | }
144 | 
145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146 |                                            const float *grad_out,
147 |                                            const int *idx, const float *weight,
148 |                                            float *grad_points) {
149 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151 |       b, c, n, m, grad_out, idx, weight, grad_points);
152 | 
153 |   CUDA_CHECK_ERRORS();
154 | }
155 | 
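The interpolation kernel computes out[b, c, j] = sum_k weight[b, j, k] * points[b, c, idx[b, j, k]] over the three neighbors. A pure-PyTorch reference of the same computation (illustrative only):

    import torch

    def three_interpolate_reference(points, idx, weight):
        # points: (B, C, m) float; idx: (B, n, 3) long; weight: (B, n, 3) float
        B, C, m = points.shape
        n = idx.shape[1]
        flat = idx.reshape(B, 1, n * 3).expand(B, C, -1)            # (B, C, n*3)
        gathered = torch.gather(points, 2, flat).reshape(B, C, n, 3)
        return (gathered * weight.unsqueeze(1)).sum(dim=-1)         # (B, C, n)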
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 | 
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 |                                   const float *points, const int *idx,
6 |                                   float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 |                                        const float *grad_out, const int *idx,
9 |                                        float *grad_points);
10 | 
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 |                                             const float *dataset, float *temp,
13 |                                             int *idxs);
14 | 
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 |   CHECK_CONTIGUOUS(points);
17 |   CHECK_CONTIGUOUS(idx);
18 |   CHECK_IS_FLOAT(points);
19 |   CHECK_IS_INT(idx);
20 | 
21 |   if (points.is_cuda()) {
22 |     CHECK_CUDA(idx);
23 |   }
24 | 
25 |   at::Tensor output =
26 |       torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 |                    at::device(points.device()).dtype(at::ScalarType::Float));
28 | 
29 |   if (points.is_cuda()) {
30 |     gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 |                                  idx.size(1), points.data_ptr<float>(),
32 |                                  idx.data_ptr<int>(), output.data_ptr<float>());
33 |   } else {
34 |     AT_ASSERT(false, "CPU not supported");
35 |   }
36 | 
37 |   return output;
38 | }
39 | 
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 |                               const int n) {
42 |   CHECK_CONTIGUOUS(grad_out);
43 |   CHECK_CONTIGUOUS(idx);
44 |   CHECK_IS_FLOAT(grad_out);
45 |   CHECK_IS_INT(idx);
46 | 
47 |   if (grad_out.is_cuda()) {
48 |     CHECK_CUDA(idx);
49 |   }
50 | 
51 |   at::Tensor output =
52 |       torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 |                    at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 | 
55 |   if (grad_out.is_cuda()) {
56 |     gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 |                                       idx.size(1), grad_out.data_ptr<float>(),
58 |                                       idx.data_ptr<int>(),
59 |                                       output.data_ptr<float>());
60 |   } else {
61 |     AT_ASSERT(false, "CPU not supported");
62 |   }
63 | 
64 |   return output;
65 | }
66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67 |   CHECK_CONTIGUOUS(points);
68 |   CHECK_IS_FLOAT(points);
69 | 
70 |   at::Tensor output =
71 |       torch::zeros({points.size(0), nsamples},
72 |                    at::device(points.device()).dtype(at::ScalarType::Int));
73 | 
74 |   at::Tensor tmp =
75 |       torch::full({points.size(0), points.size(1)}, 1e10,
76 |                   at::device(points.device()).dtype(at::ScalarType::Float));
77 | 
78 |   if (points.is_cuda()) {
79 |     furthest_point_sampling_kernel_wrapper(
80 |         points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81 |         tmp.data_ptr<float>(), output.data_ptr<int>());
82 |   } else {
83 |     AT_ASSERT(false, "CPU not supported");
84 |   }
85 | 
86 |   return output;
87 | }
88 | 
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 | 
4 | #include "cuda_utils.h"
5 | 
6 | // input: points(b, c, n) idx(b, m)
7 | // output: out(b, c, m)
8 | __global__ void gather_points_kernel(int b, int c, int n, int m,
9 |                                      const float *__restrict__ points,
10 |                                      const int *__restrict__ idx,
11 |                                      float *__restrict__ out) {
12 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
13 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
14 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
15 |         int a = idx[i * m + j];
16 |         out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
17 |       }
18 |     }
19 |   }
20 | }
21 | 
22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
23 |                                   const float *points, const int *idx,
24 |                                   float *out) {
25 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
26 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
27 |                                                              points, idx, out);
28 | 
29 |   CUDA_CHECK_ERRORS();
30 | }
31 | 
32 | // input: grad_out(b, c, m) idx(b, m)
33 | // output: grad_points(b, c, n)
34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
35 |                                           const float *__restrict__ grad_out,
36 |                                           const int *__restrict__ idx,
37 |                                           float *__restrict__ grad_points) {
38 |   for (int i = blockIdx.x; i < b; i += gridDim.x) {
39 |     for (int l = blockIdx.y; l < c; l += gridDim.y) {
40 |       for (int j = threadIdx.x; j < m; j += blockDim.x) {
41 |         int a = idx[i * m + j];
42 |         atomicAdd(grad_points + (i * c + l) * n + a,
43 |                   grad_out[(i * c + l) * m + j]);
44 |       }
45 |     }
46 |   }
47 | }
48 | 
49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
50 |                                        const float *grad_out, const int *idx,
51 |                                        float *grad_points) {
52 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
53 |                               at::cuda::getCurrentCUDAStream()>>>(
54 |       b, c, n, npoints, grad_out, idx, grad_points);
55 | 
56 |   CUDA_CHECK_ERRORS();
57 | }
58 | 
59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
60 |                          int idx1, int idx2) {
61 |   const float v1 = dists[idx1], v2 = dists[idx2];
62 |   const int i1 = dists_i[idx1], i2 = dists_i[idx2];
63 |   dists[idx1] = max(v1, v2);
64 |   dists_i[idx1] = v2 > v1 ? i2 : i1;
65 | }
66 | 
67 | // Input dataset: (b, n, 3), tmp: (b, n)
68 | // Output idxs: (b, m)
69 | template <unsigned int block_size>
70 | __global__ void furthest_point_sampling_kernel(
71 |     int b, int n, int m, const float *__restrict__ dataset,
72 |     float *__restrict__ temp, int *__restrict__ idxs) {
73 |   if (m <= 0) return;
74 |   __shared__ float dists[block_size];
75 |   __shared__ int dists_i[block_size];
76 | 
77 |   int batch_index = blockIdx.x;
78 |   dataset += batch_index * n * 3;
79 |   temp += batch_index * n;
80 |   idxs += batch_index * m;
81 | 
82 |   int tid = threadIdx.x;
83 |   const int stride = block_size;
84 | 
85 |   int old = 0;
86 |   if (threadIdx.x == 0) idxs[0] = old;
87 | 
88 |   __syncthreads();
89 |   for (int j = 1; j < m; j++) {
90 |     int besti = 0;
91 |     float best = -1;
92 |     float x1 = dataset[old * 3 + 0];
93 |     float y1 = dataset[old * 3 + 1];
94 |     float z1 = dataset[old * 3 + 2];
95 |     for (int k = tid; k < n; k += stride) {
96 |       float x2, y2, z2;
97 |       x2 = dataset[k * 3 + 0];
98 |       y2 = dataset[k * 3 + 1];
99 |       z2 = dataset[k * 3 + 2];
100 |       float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
101 |       if (mag <= 1e-3) continue;
102 | 
103 |       float d =
104 |           (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
105 | 
106 |       float d2 = min(d, temp[k]);
107 |       temp[k] = d2;
108 |       besti = d2 > best ? k : besti;
109 |       best = d2 > best ? d2 : best;
110 |     }
111 |     dists[tid] = best;
112 |     dists_i[tid] = besti;
113 |     __syncthreads();
114 | 
115 |     if (block_size >= 512) {
116 |       if (tid < 256) {
117 |         __update(dists, dists_i, tid, tid + 256);
118 |       }
119 |       __syncthreads();
120 |     }
121 |     if (block_size >= 256) {
122 |       if (tid < 128) {
123 |         __update(dists, dists_i, tid, tid + 128);
124 |       }
125 |       __syncthreads();
126 |     }
127 |     if (block_size >= 128) {
128 |       if (tid < 64) {
129 |         __update(dists, dists_i, tid, tid + 64);
130 |       }
131 |       __syncthreads();
132 |     }
133 |     if (block_size >= 64) {
134 |       if (tid < 32) {
135 |         __update(dists, dists_i, tid, tid + 32);
136 |       }
137 |       __syncthreads();
138 |     }
139 |     if (block_size >= 32) {
140 |       if (tid < 16) {
141 |         __update(dists, dists_i, tid, tid + 16);
142 |       }
143 |       __syncthreads();
144 |     }
145 |     if (block_size >= 16) {
146 |       if (tid < 8) {
147 |         __update(dists, dists_i, tid, tid + 8);
148 |       }
149 |       __syncthreads();
150 |     }
151 |     if (block_size >= 8) {
152 |       if (tid < 4) {
153 |         __update(dists, dists_i, tid, tid + 4);
154 |       }
155 |       __syncthreads();
156 |     }
157 |     if (block_size >= 4) {
158 |       if (tid < 2) {
159 |         __update(dists, dists_i, tid, tid + 2);
160 |       }
161 |       __syncthreads();
162 |     }
163 |     if (block_size >= 2) {
164 |       if (tid < 1) {
165 |         __update(dists, dists_i, tid, tid + 1);
166 |       }
167 |       __syncthreads();
168 |     }
169 | 
170 |     old = dists_i[0];
171 |     if (tid == 0) idxs[j] = old;
172 |   }
173 | }
174 | 
175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
176 |                                             const float *dataset, float *temp,
177 |                                             int *idxs) {
178 |   unsigned int n_threads = opt_n_threads(n);
179 | 
180 |   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
181 | 
182 |   switch (n_threads) {
183 |     case 512:
184 |       furthest_point_sampling_kernel<512>
185 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
186 |       break;
187 |     case 256:
188 |       furthest_point_sampling_kernel<256>
189 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
190 |       break;
191 |     case 128:
192 |       furthest_point_sampling_kernel<128>
193 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
194 |       break;
195 |     case 64:
196 |       furthest_point_sampling_kernel<64>
197 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
198 |       break;
199 |     case 32:
200 |       furthest_point_sampling_kernel<32>
201 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
202 |       break;
203 |     case 16:
204 |       furthest_point_sampling_kernel<16>
205 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
206 |       break;
207 |     case 8:
208 |       furthest_point_sampling_kernel<8>
209 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
210 |       break;
211 |     case 4:
212 |       furthest_point_sampling_kernel<4>
213 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
214 |       break;
215 |     case 2:
216 |       furthest_point_sampling_kernel<2>
217 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
218 |       break;
219 |     case 1:
220 |       furthest_point_sampling_kernel<1>
221 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
222 |       break;
223 |     default:
224 |       furthest_point_sampling_kernel<512>
225 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
226 |   }
227 | 
228 |   CUDA_CHECK_ERRORS();
229 | }
230 | 
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "3.0.0"
2 | 
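The farthest point sampling kernel in sampling_gpu.cu above works per batch element: it seeds with point 0, keeps each point's minimum squared distance to the already-selected set in `temp`, and picks the argmax with a shared-memory reduction (points with squared norm <= 1e-3 are skipped as padding). A NumPy reference of the same O(n*m) selection (illustrative only; omits the padding skip):

    import numpy as np

    def furthest_point_sampling_reference(xyz, m):
        # xyz: (n, 3); returns m indices, seeded with index 0 like the kernel
        n = xyz.shape[0]
        idxs = np.zeros(m, dtype=np.int64)
        min_d = np.full(n, np.inf)
        for j in range(1, m):
            d = np.sum((xyz - xyz[idxs[j - 1]]) ** 2, axis=1)
            min_d = np.minimum(min_d, d)
            idxs[j] = int(np.argmax(min_d))
        return idxs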
--------------------------------------------------------------------------------
/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from pointnet2_ops import pointnet2_utils
7 | 
8 | 
9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
10 |     layers = []
11 |     for i in range(1, len(mlp_spec)):
12 |         layers.append(
13 |             nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
14 |         )
15 |         if bn:
16 |             layers.append(nn.BatchNorm2d(mlp_spec[i]))
17 |         layers.append(nn.ReLU(True))
18 | 
19 |     return nn.Sequential(*layers)
20 | 
21 | 
22 | class _PointnetSAModuleBase(nn.Module):
23 |     def __init__(self):
24 |         super(_PointnetSAModuleBase, self).__init__()
25 |         self.npoint = None
26 |         self.groupers = None
27 |         self.mlps = None
28 | 
29 |     def forward(
30 |         self, xyz: torch.Tensor, features: Optional[torch.Tensor]
31 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
32 |         r"""
33 |         Parameters
34 |         ----------
35 |         xyz : torch.Tensor
36 |             (B, N, 3) tensor of the xyz coordinates of the features
37 |         features : torch.Tensor
38 |             (B, C, N) tensor of the descriptors of the features
39 | 
40 |         Returns
41 |         -------
42 |         new_xyz : torch.Tensor
43 |             (B, npoint, 3) tensor of the new features' xyz
44 |         new_features : torch.Tensor
45 |             (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
46 |         """
47 | 
48 |         new_features_list = []
49 | 
50 |         xyz_flipped = xyz.transpose(1, 2).contiguous()
51 |         new_xyz = (
52 |             pointnet2_utils.gather_operation(
53 |                 xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
54 |             )
55 |             .transpose(1, 2)
56 |             .contiguous()
57 |             if self.npoint is not None
58 |             else None
59 |         )
60 | 
61 |         for i in range(len(self.groupers)):
62 |             new_features = self.groupers[i](
63 |                 xyz, new_xyz, features
64 |             )  # (B, C, npoint, nsample)
65 | 
66 |             new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
67 |             new_features = F.max_pool2d(
68 |                 new_features, kernel_size=[1, new_features.size(3)]
69 |             )  # (B, mlp[-1], npoint, 1)
70 |             new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
71 | 
72 |             new_features_list.append(new_features)
73 | 
74 |         return new_xyz, torch.cat(new_features_list, dim=1)
75 | 
76 | 
77 | class PointnetSAModuleMSG(_PointnetSAModuleBase):
78 |     r"""Pointnet set abstraction layer with multiscale grouping
79 | 
80 |     Parameters
81 |     ----------
82 |     npoint : int
83 |         Number of features
84 |     radii : list of float32
85 |         list of radii to group with
86 |     nsamples : list of int32
87 |         Number of samples in each ball query
88 |     mlps : list of list of int32
89 |         Spec of the pointnet before the global max_pool for each scale
90 |     bn : bool
91 |         Use batchnorm
92 |     """
93 | 
94 |     def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
95 |         # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
96 |         super(PointnetSAModuleMSG, self).__init__()
97 | 
98 |         assert len(radii) == len(nsamples) == len(mlps)
99 | 
100 |         self.npoint = npoint
101 |         self.groupers = nn.ModuleList()
102 |         self.mlps = nn.ModuleList()
103 |         for i in range(len(radii)):
104 |             radius = radii[i]
105 |             nsample = nsamples[i]
106 |             self.groupers.append(
107 |                 pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
108 |                 if npoint is not None
109 |                 else pointnet2_utils.GroupAll(use_xyz)
110 |             )
111 |             mlp_spec = mlps[i]
112 |             if use_xyz:
113 |                 mlp_spec[0] += 3
114 | 
115 |             self.mlps.append(build_shared_mlp(mlp_spec, bn))
116 | 
117 | 
118 | class PointnetSAModule(PointnetSAModuleMSG):
119 |     r"""Pointnet set abstraction layer
120 | 
121 |     Parameters
122 |     ----------
123 |     npoint : int
124 |         Number of features
125 |     radius : float
126 |         Radius of ball
127 |     nsample : int
128 |         Number of samples in the ball query
129 |     mlp : list
130 |         Spec of the pointnet before the global max_pool
131 |     bn : bool
132 |         Use batchnorm
133 |     """
134 | 
135 |     def __init__(
136 |         self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
137 |     ):
138 |         # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
139 |         super(PointnetSAModule, self).__init__(
140 |             mlps=[mlp],
141 |             npoint=npoint,
142 |             radii=[radius],
143 |             nsamples=[nsample],
144 |             bn=bn,
145 |             use_xyz=use_xyz,
146 |         )
147 | 
148 | 
149 | class PointnetFPModule(nn.Module):
150 |     r"""Propagates the features of one set to another
151 | 
152 |     Parameters
153 |     ----------
154 |     mlp : list
155 |         Pointnet module parameters
156 |     bn : bool
157 |         Use batchnorm
158 |     """
159 | 
160 |     def __init__(self, mlp, bn=True):
161 |         # type: (PointnetFPModule, List[int], bool) -> None
162 |         super(PointnetFPModule, self).__init__()
163 |         self.mlp = build_shared_mlp(mlp, bn=bn)
164 | 
165 |     def forward(self, unknown, known, unknow_feats, known_feats):
166 |         # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
167 |         r"""
168 |         Parameters
169 |         ----------
170 |         unknown : torch.Tensor
171 |             (B, n, 3) tensor of the xyz positions of the unknown features
172 |         known : torch.Tensor
173 |             (B, m, 3) tensor of the xyz positions of the known features
174 |         unknow_feats : torch.Tensor
175 |             (B, C1, n) tensor of the features to be propagated to
176 |         known_feats : torch.Tensor
177 |             (B, C2, m) tensor of features to be propagated
178 | 
179 |         Returns
180 |         -------
181 |         new_features : torch.Tensor
182 |             (B, mlp[-1], n) tensor of the features of the unknown features
183 |         """
184 | 
185 |         if known is not None:
186 |             dist, idx = pointnet2_utils.three_nn(unknown, known)
187 |             dist_recip = 1.0 / (dist + 1e-8)
188 |             norm = torch.sum(dist_recip, dim=2, keepdim=True)
189 |             weight = dist_recip / norm
190 | 
191 |             interpolated_feats = pointnet2_utils.three_interpolate(
192 |                 known_feats, idx, weight
193 |             )
194 |         else:
195 |             interpolated_feats = known_feats.expand(
196 |                 *(known_feats.size()[0:2] + [unknown.size(1)])
197 |             )
198 | 
199 |         if unknow_feats is not None:
200 |             new_features = torch.cat(
201 |                 [interpolated_feats, unknow_feats], dim=1
202 |             )  # (B, C2 + C1, n)
203 |         else:
204 |             new_features = interpolated_feats
205 | 
206 |         new_features = new_features.unsqueeze(-1)
207 |         new_features = self.mlp(new_features)
208 | 
209 |         return new_features.squeeze(-1)
210 | 
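A minimal usage sketch of the SA/FP pair defined above (shapes in comments; assumes the compiled CUDA extension is available):

    import torch
    from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetFPModule

    sa = PointnetSAModule(mlp=[0, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
    fp = PointnetFPModule(mlp=[128, 128]).cuda()

    xyz = torch.rand(2, 1024, 3).cuda()
    new_xyz, feats = sa(xyz, None)        # (2, 512, 3), (2, 128, 512)
    up = fp(xyz, new_xyz, None, feats)    # (2, 128, 1024)

Note that with use_xyz=True the module adds 3 to the first MLP width internally, so mlp=[0, 64, 128] becomes a [3, 64, 128] shared MLP.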
JIT Compiling.") 16 | 17 | _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src") 18 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 19 | osp.join(_ext_src_root, "src", "*.cu") 20 | ) 21 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 22 | 23 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 24 | _ext = load( 25 | "_ext", 26 | sources=_ext_sources, 27 | extra_include_paths=[osp.join(_ext_src_root, "include")], 28 | extra_cflags=["-O3"], 29 | extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"], 30 | with_cuda=True, 31 | ) 32 | 33 | 34 | class FurthestPointSampling(Function): 35 | @staticmethod 36 | def forward(ctx, xyz, npoint): 37 | # type: (Any, torch.Tensor, int) -> torch.Tensor 38 | r""" 39 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 40 | minimum distance 41 | 42 | Parameters 43 | ---------- 44 | xyz : torch.Tensor 45 | (B, N, 3) tensor where N > npoint 46 | npoint : int32 47 | number of features in the sampled set 48 | 49 | Returns 50 | ------- 51 | torch.Tensor 52 | (B, npoint) tensor containing the set 53 | """ 54 | out = _ext.furthest_point_sampling(xyz, npoint) 55 | 56 | ctx.mark_non_differentiable(out) 57 | 58 | return out 59 | 60 | @staticmethod 61 | def backward(ctx, grad_out): 62 | return () 63 | 64 | 65 | furthest_point_sample = FurthestPointSampling.apply 66 | 67 | 68 | class GatherOperation(Function): 69 | @staticmethod 70 | def forward(ctx, features, idx): 71 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 72 | r""" 73 | 74 | Parameters 75 | ---------- 76 | features : torch.Tensor 77 | (B, C, N) tensor 78 | 79 | idx : torch.Tensor 80 | (B, npoint) tensor of the features to gather 81 | 82 | Returns 83 | ------- 84 | torch.Tensor 85 | (B, C, npoint) tensor 86 | """ 87 | 88 | ctx.save_for_backward(idx, features) 89 | 90 | return _ext.gather_points(features, idx) 91 | 92 | @staticmethod 93 | def backward(ctx, grad_out): 94 | idx, features = ctx.saved_tensors 95 | N = features.size(2) 96 | 97 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 98 | return grad_features, None 99 | 100 | 101 | gather_operation = GatherOperation.apply 102 | 103 | 104 | class ThreeNN(Function): 105 | @staticmethod 106 | def forward(ctx, unknown, known): 107 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 108 | r""" 109 | Find the three nearest neighbors of unknown in known 110 | Parameters 111 | ---------- 112 | unknown : torch.Tensor 113 | (B, n, 3) tensor of known features 114 | known : torch.Tensor 115 | (B, m, 3) tensor of unknown features 116 | 117 | Returns 118 | ------- 119 | dist : torch.Tensor 120 | (B, n, 3) l2 distance to the three nearest neighbors 121 | idx : torch.Tensor 122 | (B, n, 3) index of 3 nearest neighbors 123 | """ 124 | dist2, idx = _ext.three_nn(unknown, known) 125 | dist = torch.sqrt(dist2) 126 | 127 | ctx.mark_non_differentiable(dist, idx) 128 | 129 | return dist, idx 130 | 131 | @staticmethod 132 | def backward(ctx, grad_dist, grad_idx): 133 | return () 134 | 135 | 136 | three_nn = ThreeNN.apply 137 | 138 | 139 | class ThreeInterpolate(Function): 140 | @staticmethod 141 | def forward(ctx, features, idx, weight): 142 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 143 | r""" 144 | Performs weight linear interpolation on 3 features 145 | Parameters 146 | ---------- 147 | features : torch.Tensor 148 | (B, c, m) Features descriptors to be interpolated 
139 | class ThreeInterpolate(Function):
140 |     @staticmethod
141 |     def forward(ctx, features, idx, weight):
142 |         # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
143 |         r"""
144 |         Performs weighted linear interpolation on 3 features
145 |         Parameters
146 |         ----------
147 |         features : torch.Tensor
148 |             (B, c, m) Features descriptors to be interpolated from
149 |         idx : torch.Tensor
150 |             (B, n, 3) three nearest neighbors of the target features in features
151 |         weight : torch.Tensor
152 |             (B, n, 3) weights
153 | 
154 |         Returns
155 |         -------
156 |         torch.Tensor
157 |             (B, c, n) tensor of the interpolated features
158 |         """
159 |         ctx.save_for_backward(idx, weight, features)
160 | 
161 |         return _ext.three_interpolate(features, idx, weight)
162 | 
163 |     @staticmethod
164 |     def backward(ctx, grad_out):
165 |         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
166 |         r"""
167 |         Parameters
168 |         ----------
169 |         grad_out : torch.Tensor
170 |             (B, c, n) tensor with gradients of outputs
171 | 
172 |         Returns
173 |         -------
174 |         grad_features : torch.Tensor
175 |             (B, c, m) tensor with gradients of features
176 | 
177 |         None
178 | 
179 |         None
180 |         """
181 |         idx, weight, features = ctx.saved_tensors
182 |         m = features.size(2)
183 | 
184 |         grad_features = _ext.three_interpolate_grad(
185 |             grad_out.contiguous(), idx, weight, m
186 |         )
187 | 
188 |         return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
189 | 
190 | 
191 | three_interpolate = ThreeInterpolate.apply
192 | 
193 | 
194 | class GroupingOperation(Function):
195 |     @staticmethod
196 |     def forward(ctx, features, idx):
197 |         # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
198 |         r"""
199 | 
200 |         Parameters
201 |         ----------
202 |         features : torch.Tensor
203 |             (B, C, N) tensor of features to group
204 |         idx : torch.Tensor
205 |             (B, npoint, nsample) tensor containing the indices of features to group with
206 | 
207 |         Returns
208 |         -------
209 |         torch.Tensor
210 |             (B, C, npoint, nsample) tensor
211 |         """
212 |         ctx.save_for_backward(idx, features)
213 | 
214 |         return _ext.group_points(features, idx)
215 | 
216 |     @staticmethod
217 |     def backward(ctx, grad_out):
218 |         # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
219 |         r"""
220 | 
221 |         Parameters
222 |         ----------
223 |         grad_out : torch.Tensor
224 |             (B, C, npoint, nsample) tensor of the gradients of the output from forward
225 | 
226 |         Returns
227 |         -------
228 |         torch.Tensor
229 |             (B, C, N) gradient of the features
230 |         None
231 |         """
232 |         idx, features = ctx.saved_tensors
233 |         N = features.size(2)
234 | 
235 |         grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
236 | 
237 |         return grad_features, torch.zeros_like(idx)
238 | 
239 | 
240 | grouping_operation = GroupingOperation.apply
241 | 
242 | 
243 | class BallQuery(Function):
244 |     @staticmethod
245 |     def forward(ctx, radius, nsample, xyz, new_xyz):
246 |         # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
247 |         r"""
248 | 
249 |         Parameters
250 |         ----------
251 |         radius : float
252 |             radius of the balls
253 |         nsample : int
254 |             maximum number of features in the balls
255 |         xyz : torch.Tensor
256 |             (B, N, 3) xyz coordinates of the features
257 |         new_xyz : torch.Tensor
258 |             (B, npoint, 3) centers of the ball query
259 | 
260 |         Returns
261 |         -------
262 |         torch.Tensor
263 |             (B, npoint, nsample) tensor with the indices of the features that form the query balls
264 |         """
265 |         output = _ext.ball_query(new_xyz, xyz, radius, nsample)
266 | 
267 |         ctx.mark_non_differentiable(output)
268 | 
269 |         return output
270 | 
271 |     @staticmethod
272 |     def backward(ctx, grad_out):
273 |         return ()
274 | 
275 | 
276 | ball_query = BallQuery.apply
277 | 
278 | 
279 | class QueryAndGroup(nn.Module):
280 |     r"""
281 |     Groups with a ball query of radius
282 | 
283 |     Parameters
284 |     ---------
285 |     radius : float32
286 |         Radius of ball
287 |     nsample : int32
288 |         Maximum number of features to gather in the ball
289 |     """
290 | 
291 |     def __init__(self, radius, nsample, use_xyz=True):
292 |         # type: (QueryAndGroup, float, int, bool) -> None
293 |         super(QueryAndGroup, self).__init__()
294 |         self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
295 | 
296 |     def forward(self, xyz, new_xyz, features=None):
297 |         # type: (QueryAndGroup, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
298 |         r"""
299 |         Parameters
300 |         ----------
301 |         xyz : torch.Tensor
302 |             xyz coordinates of the features (B, N, 3)
303 |         new_xyz : torch.Tensor
304 |             centroids (B, npoint, 3)
305 |         features : torch.Tensor
306 |             Descriptors of the features (B, C, N)
307 | 
308 |         Returns
309 |         -------
310 |         new_features : torch.Tensor
311 |             (B, 3 + C, npoint, nsample) tensor
312 |         """
313 | 
314 |         idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
315 |         xyz_trans = xyz.transpose(1, 2).contiguous()
316 |         grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
317 |         grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
318 | 
319 |         if features is not None:
320 |             grouped_features = grouping_operation(features, idx)
321 |             if self.use_xyz:
322 |                 new_features = torch.cat(
323 |                     [grouped_xyz, grouped_features], dim=1
324 |                 )  # (B, C + 3, npoint, nsample)
325 |             else:
326 |                 new_features = grouped_features
327 |         else:
328 |             assert (
329 |                 self.use_xyz
330 |             ), "Cannot have features=None and use_xyz=False; there is nothing to group!"
331 |             new_features = grouped_xyz
332 | 
333 |         return new_features
334 | 
335 | 
336 | class GroupAll(nn.Module):
337 |     r"""
338 |     Groups all features
339 | 
340 |     Parameters
341 |     ---------
342 |     """
343 | 
344 |     def __init__(self, use_xyz=True):
345 |         # type: (GroupAll, bool) -> None
346 |         super(GroupAll, self).__init__()
347 |         self.use_xyz = use_xyz
348 | 
349 |     def forward(self, xyz, new_xyz, features=None):
350 |         # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
351 |         r"""
352 |         Parameters
353 |         ----------
354 |         xyz : torch.Tensor
355 |             xyz coordinates of the features (B, N, 3)
356 |         new_xyz : torch.Tensor
357 |             Ignored
358 |         features : torch.Tensor
359 |             Descriptors of the features (B, C, N)
360 | 
361 |         Returns
362 |         -------
363 |         new_features : torch.Tensor
364 |             (B, C + 3, 1, N) tensor
365 |         """
366 | 
367 |         grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
368 |         if features is not None:
369 |             grouped_features = features.unsqueeze(2)
370 |             if self.use_xyz:
371 |                 new_features = torch.cat(
372 |                     [grouped_xyz, grouped_features], dim=1
373 |                 )  # (B, 3 + C, 1, N)
374 |             else:
375 |                 new_features = grouped_features
376 |         else:
377 |             new_features = grouped_xyz
378 | 
379 |         return new_features
380 | 
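A short sketch tying the sampling ops together: furthest_point_sample picks indices and gather_operation pulls the corresponding coordinates (illustrative, CUDA assumed):

    import torch
    from pointnet2_ops.pointnet2_utils import furthest_point_sample, gather_operation

    xyz = torch.rand(2, 1024, 3).cuda()
    idx = furthest_point_sample(xyz, 256)                          # (2, 256) int32
    sub = gather_operation(xyz.transpose(1, 2).contiguous(), idx)  # (2, 3, 256)
    sub = sub.transpose(1, 2).contiguous()                         # (2, 256, 3)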
os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="pointnet2_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="pointnet2_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data, use coordinates of the block centered at origin, 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | 35 | def shuffle_points(batch_data): 36 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 37 | Use the same shuffling idx for the entire batch. 38 | Input: 39 | BxNxC array 40 | Output: 41 | BxNxC array 42 | """ 43 | idx = np.arange(batch_data.shape[1]) 44 | np.random.shuffle(idx) 45 | return batch_data[:, idx, :] 46 | 47 | 48 | def rotate_point_cloud(batch_data): 49 | """ Randomly rotate the point clouds to augument the dataset 50 | rotation is per shape based along up direction 51 | Input: 52 | BxNx3 array, original batch of point clouds 53 | Return: 54 | BxNx3 array, rotated batch of point clouds 55 | """ 56 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 57 | for k in range(batch_data.shape[0]): 58 | rotation_angle = np.random.uniform() * 2 * np.pi 59 | cosval = np.cos(rotation_angle) 60 | sinval = np.sin(rotation_angle) 61 | rotation_matrix = np.array([[cosval, 0, sinval], 62 | [0, 1, 0], 63 | [-sinval, 0, cosval]]) 64 | shape_pc = batch_data[k, ...] 65 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 66 | return rotated_data 67 | 68 | 69 | def rotate_point_cloud_z(batch_data): 70 | """ Randomly rotate the point clouds to augument the dataset 71 | rotation is per shape based along up direction 72 | Input: 73 | BxNx3 array, original batch of point clouds 74 | Return: 75 | BxNx3 array, rotated batch of point clouds 76 | """ 77 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 78 | for k in range(batch_data.shape[0]): 79 | rotation_angle = np.random.uniform() * 2 * np.pi 80 | cosval = np.cos(rotation_angle) 81 | sinval = np.sin(rotation_angle) 82 | rotation_matrix = np.array([[cosval, sinval, 0], 83 | [-sinval, cosval, 0], 84 | [0, 0, 1]]) 85 | shape_pc = batch_data[k, ...] 86 | rotated_data[k, ...] 
69 | def rotate_point_cloud_z(batch_data):
70 |     """ Randomly rotate the point clouds to augment the dataset;
71 |         rotation is per shape, around the z-axis.
72 |         Input:
73 |             BxNx3 array, original batch of point clouds
74 |         Return:
75 |             BxNx3 array, rotated batch of point clouds
76 |     """
77 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
78 |     for k in range(batch_data.shape[0]):
79 |         rotation_angle = np.random.uniform() * 2 * np.pi
80 |         cosval = np.cos(rotation_angle)
81 |         sinval = np.sin(rotation_angle)
82 |         rotation_matrix = np.array([[cosval, sinval, 0],
83 |                                     [-sinval, cosval, 0],
84 |                                     [0, 0, 1]])
85 |         shape_pc = batch_data[k, ...]
86 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
87 |     return rotated_data
88 | 
89 | 
90 | def rotate_point_cloud_with_normal(batch_xyz_normal):
91 |     ''' Randomly rotate XYZ, normal point cloud.
92 |         Input:
93 |             batch_xyz_normal: B,N,6, first three channels are XYZ, last three are the normal
94 |         Output:
95 |             B,N,6, rotated XYZ, normal point cloud
96 |     '''
97 |     for k in range(batch_xyz_normal.shape[0]):
98 |         rotation_angle = np.random.uniform() * 2 * np.pi
99 |         cosval = np.cos(rotation_angle)
100 |         sinval = np.sin(rotation_angle)
101 |         rotation_matrix = np.array([[cosval, 0, sinval],
102 |                                     [0, 1, 0],
103 |                                     [-sinval, 0, cosval]])
104 |         shape_pc = batch_xyz_normal[k, :, 0:3]
105 |         shape_normal = batch_xyz_normal[k, :, 3:6]
106 |         batch_xyz_normal[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
107 |         batch_xyz_normal[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
108 |     return batch_xyz_normal
109 | 
110 | 
111 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
112 |     """ Randomly perturb the point clouds by small rotations
113 |         Input:
114 |             BxNx6 array, original batch of point clouds and point normals
115 |         Return:
116 |             BxNx6 array, rotated batch of point clouds and point normals
117 |     """
118 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
119 |     for k in range(batch_data.shape[0]):
120 |         angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
121 |         Rx = np.array([[1, 0, 0],
122 |                        [0, np.cos(angles[0]), -np.sin(angles[0])],
123 |                        [0, np.sin(angles[0]), np.cos(angles[0])]])
124 |         Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
125 |                        [0, 1, 0],
126 |                        [-np.sin(angles[1]), 0, np.cos(angles[1])]])
127 |         Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
128 |                        [np.sin(angles[2]), np.cos(angles[2]), 0],
129 |                        [0, 0, 1]])
130 |         R = np.dot(Rz, np.dot(Ry, Rx))
131 |         shape_pc = batch_data[k, :, 0:3]
132 |         shape_normal = batch_data[k, :, 3:6]
133 |         rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
134 |         rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
135 |     return rotated_data
136 | 
137 | 
138 | def rotate_point_cloud_by_angle(batch_data, rotation_angle):
139 |     """ Rotate the point cloud along up direction with certain angle.
140 |         Input:
141 |             BxNx3 array, original batch of point clouds
142 |         Return:
143 |             BxNx3 array, rotated batch of point clouds
144 |     """
145 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
146 |     for k in range(batch_data.shape[0]):
147 |         # rotation_angle = np.random.uniform() * 2 * np.pi
148 |         cosval = np.cos(rotation_angle)
149 |         sinval = np.sin(rotation_angle)
150 |         rotation_matrix = np.array([[cosval, 0, sinval],
151 |                                     [0, 1, 0],
152 |                                     [-sinval, 0, cosval]])
153 |         shape_pc = batch_data[k, :, 0:3]
154 |         rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
155 |     return rotated_data
156 | 
157 | 
158 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
159 |     """ Rotate the point cloud along up direction with certain angle.
160 |         Input:
161 |             BxNx6 array, original batch of point clouds with normal
162 |             scalar, angle of rotation
163 |         Return:
164 |             BxNx6 array, rotated batch of point clouds with normal
165 |     """
166 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
167 |     for k in range(batch_data.shape[0]):
168 |         # rotation_angle = np.random.uniform() * 2 * np.pi
169 |         cosval = np.cos(rotation_angle)
170 |         sinval = np.sin(rotation_angle)
171 |         rotation_matrix = np.array([[cosval, 0, sinval],
172 |                                     [0, 1, 0],
173 |                                     [-sinval, 0, cosval]])
174 |         shape_pc = batch_data[k, :, 0:3]
175 |         shape_normal = batch_data[k, :, 3:6]
176 |         rotated_data[k, :, 0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
177 |         rotated_data[k, :, 3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
178 |     return rotated_data
179 | 
180 | 
181 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
182 |     """ Randomly perturb the point clouds by small rotations
183 |         Input:
184 |             BxNx3 array, original batch of point clouds
185 |         Return:
186 |             BxNx3 array, rotated batch of point clouds
187 |     """
188 |     rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
189 |     for k in range(batch_data.shape[0]):
190 |         angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
191 |         Rx = np.array([[1, 0, 0],
192 |                        [0, np.cos(angles[0]), -np.sin(angles[0])],
193 |                        [0, np.sin(angles[0]), np.cos(angles[0])]])
194 |         Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
195 |                        [0, 1, 0],
196 |                        [-np.sin(angles[1]), 0, np.cos(angles[1])]])
197 |         Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
198 |                        [np.sin(angles[2]), np.cos(angles[2]), 0],
199 |                        [0, 0, 1]])
200 |         R = np.dot(Rz, np.dot(Ry, Rx))
201 |         shape_pc = batch_data[k, ...]
202 |         rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
203 |     return rotated_data
204 | 
205 | 
206 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
207 |     """ Randomly jitter points. jittering is per point.
208 |         Input:
209 |             BxNx3 array, original batch of point clouds
210 |         Return:
211 |             BxNx3 array, jittered batch of point clouds
212 |     """
213 |     B, N, C = batch_data.shape
214 |     assert (clip > 0)
215 |     jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
216 |     jittered_data += batch_data
217 |     return jittered_data
218 | 
219 | 
220 | def shift_point_cloud(batch_data, shift_range=0.1):
221 |     """ Randomly shift point cloud. Shift is per point cloud.
222 |         Input:
223 |             BxNx3 array, original batch of point clouds
224 |         Return:
225 |             BxNx3 array, shifted batch of point clouds
226 |     """
227 |     B, N, C = batch_data.shape
228 |     shifts = np.random.uniform(-shift_range, shift_range, (B, 3))
229 |     for batch_index in range(B):
230 |         batch_data[batch_index, :, :] += shifts[batch_index, :]
231 |     return batch_data
232 | 
233 | 
234 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
235 |     """ Randomly scale the point cloud. Scale is per point cloud.
236 | Input: 237 | BxNx3 array, original batch of point clouds 238 | Return: 239 | BxNx3 array, scaled batch of point clouds 240 | """ 241 | B, N, C = batch_data.shape 242 | scales = np.random.uniform(scale_low, scale_high, B) 243 | for batch_index in range(B): 244 | batch_data[batch_index, :, :] *= scales[batch_index] 245 | return batch_data 246 | 247 | 248 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 249 | ''' batch_pc: BxNx3 ''' 250 | for b in range(batch_pc.shape[0]): 251 | dropout_ratio = np.random.random() * max_dropout_ratio # 0~0.875 252 | drop_idx = np.where(np.random.random((batch_pc.shape[1])) <= dropout_ratio)[0] 253 | if len(drop_idx) > 0: 254 | batch_pc[b, drop_idx, :] = batch_pc[b, 0, :] # set to the first point 255 | return batch_pc 256 | -------------------------------------------------------------------------------- /test_correspondence.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import importlib 3 | import sys 4 | import torch 5 | import os 6 | import argparse 7 | import utils.check_points_utils as checkpoint_util 8 | import numpy as np 9 | from tqdm import tqdm 10 | from data_utils.keypointnet_dataloader import KeyPointNetDataLoader 11 | from models.torch_pointnet_utils import knn_point 12 | 13 | 14 | class PointcloudJitter(object): 15 | def __init__(self, std=0.01, clip=0.05): 16 | self.std, self.clip = std, clip 17 | 18 | def __call__(self, points): 19 | jittered_data = ( 20 | points.new(points.size(0), 3) 21 | .normal_(mean=0.0, std=self.std) 22 | .clamp_(-self.clip, self.clip) 23 | ) 24 | points[:, 0:3] += jittered_data 25 | return points 26 | 27 | 28 | def main(args): 29 | experiment_dir = 'log/' + args.log_dir 30 | sys.path.append(experiment_dir) 31 | model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0] 32 | 33 | epoch, iters, checkpoint = checkpoint_util.load_checkpoint(model_3d=None, filename=str( 34 | experiment_dir) + '/checkpoints/' + args.model) 35 | category = checkpoint['category'] 36 | num_structure_points = checkpoint['num_structure_points'] 37 | multi_distribution = checkpoint['multi_distribution'] 38 | offset = checkpoint['offset'] 39 | 40 | model = importlib.import_module(model_name) 41 | model = model.Pointnet2StructurePointNet(num_structure_points=num_structure_points, input_channels=0, 42 | multi_distribution_num=multi_distribution, 43 | offset=offset) 44 | model.load_state_dict(checkpoint['model_state_3d']) 45 | 46 | model.cuda() 47 | model.eval() 48 | 49 | if os.path.exists(args.output_dir) is False: 50 | os.makedirs(args.output_dir) 51 | 52 | test_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, category + '.json'), 53 | pcd_path=args.pcd_path, split='val') 54 | 55 | testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False, 56 | num_workers=2, pin_memory=True, 57 | persistent_workers=True) 58 | 59 | thresholds = np.linspace(0., 1) 60 | 61 | model_datas = [] 62 | point_cloud_jitter = PointcloudJitter(std=args.gauss_rate, clip=0.1) 63 | 64 | for batch_id, (batch_points, key_points, key_points_num, _) in tqdm(enumerate(testDataLoader, 0), 65 | total=len(testDataLoader), 66 | smoothing=0.9): 67 | 68 | with torch.no_grad(): 69 | batch_points_jitter = torch.Tensor(batch_points) 70 | if not args.gauss_rate == 0: 71 | for pc in range(batch_points.shape[0]): 72 | batch_points_jitter[pc] = point_cloud_jitter(batch_points[pc]) 73 | 74 | structure_points, fps_points, cos_similarity, 
stpts_prob_map = model(batch_points_jitter.cuda()) 75 | for i in range(0, batch_points.shape[0]): 76 | diameter_shape = torch.sqrt(torch.sum((torch.max(batch_points[i], dim=0)[0] - torch.min(batch_points[i], dim=0)[0]) ** 2)) 77 | 78 | model_datas.append({'structure_pts': structure_points[i] + (torch.rand(structure_points[i].shape) * args.noise_rate).cuda(), 79 | 'gt_feat_pts': key_points[i][:key_points_num[i], :].cuda(), 'diameter_shape': diameter_shape}) 80 | 81 | dis_ratios, dis_thresholds = compute_correspondence_accuracy(model_datas) 82 | 83 | if not os.path.exists('corrs_model/'): 84 | os.mkdir('corrs_model/') 85 | np.savez('corrs_model/' + args.corres_name, dis_ratios, dis_thresholds) 86 | 87 | 88 | def compute_correspondence_dis(model_data_a, model_data_b): 89 | structure_pts_a = model_data_a['structure_pts'] 90 | gt_feat_pts_a = model_data_a['gt_feat_pts'] 91 | structure_pts_b = model_data_b['structure_pts'] 92 | gt_feat_pts_b = model_data_b['gt_feat_pts'] 93 | diameter_shape_b = model_data_b['diameter_shape'] 94 | res_dis = [] 95 | 96 | if not gt_feat_pts_a.shape[0] == gt_feat_pts_b.shape[0]: 97 | return res_dis 98 | 99 | # knn_a_idxs, knn_a_dis = query_KNN_tensor(structure_pts_a, gt_feat_pts_a, 1) 100 | knn_a_idxs = knn_point(1, structure_pts_a[None, :, :], gt_feat_pts_a[None, :, :]) 101 | 102 | corres_pts_in_b = structure_pts_b[knn_a_idxs[0, :, 0], :] 103 | diff = corres_pts_in_b - gt_feat_pts_b 104 | tmp_dis = torch.sqrt(torch.sum(diff * diff, dim=1)) / diameter_shape_b 105 | 106 | for i in range(tmp_dis.shape[0]): 107 | # nan means this feature point is missing on groundtruth model 108 | if torch.isnan(gt_feat_pts_a[i, 0]) == False and torch.isnan(gt_feat_pts_b[i, 0]) == False: 109 | res_dis.append(tmp_dis[i].item()) 110 | 111 | return res_dis 112 | 113 | 114 | def compute_correspondence_accuracy(model_datas): 115 | dis_list = [] 116 | for i in tqdm(range(len(model_datas)), total=len(model_datas)): 117 | for j in range(len(model_datas)): 118 | if i == j: 119 | continue 120 | model_data_i = model_datas[i] 121 | model_data_j = model_datas[j] 122 | corres_dis = compute_correspondence_dis(model_data_i, model_data_j) 123 | dis_list = dis_list + corres_dis 124 | 125 | dis_array = np.array(dis_list) 126 | 127 | dis_thresholds = np.arange(0, 0.26, 0.01) 128 | dis_ratios = [] 129 | 130 | for i in range(dis_thresholds.shape[0]): 131 | threshold = dis_thresholds[i] 132 | ratio = dis_array[dis_array <= threshold].shape[0] / dis_array.shape[0] 133 | dis_ratios.append(ratio) 134 | 135 | dis_ratios = np.array(dis_ratios) 136 | 137 | return dis_ratios, dis_thresholds 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser(description="Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) 142 | parser.add_argument("-num_inputs", type=int, default=1024, help="sample points from initial point cloud") 143 | parser.add_argument("-log_dir", type=str, default='4.8.2', help="path to the trained model log") 144 | parser.add_argument("-model", type=str, default='model_min_test_loss', 145 | help="the trained model[default: model_min_test_loss]") 146 | parser.add_argument("-output_dir", type=str, default='out', help="output dir") 147 | parser.add_argument('-prediction_output', type=str, default='merger_prediction.npz', 148 | help='Output file where prediction results are written.') 149 | parser.add_argument('-pcd_path', type=str, default='./keypointnet/pcds', 150 | help='Point cloud file folder path from KeypointNet dataset.') 151 | 
88 | def compute_correspondence_dis(model_data_a, model_data_b):
89 |     structure_pts_a = model_data_a['structure_pts']
90 |     gt_feat_pts_a = model_data_a['gt_feat_pts']
91 |     structure_pts_b = model_data_b['structure_pts']
92 |     gt_feat_pts_b = model_data_b['gt_feat_pts']
93 |     diameter_shape_b = model_data_b['diameter_shape']
94 |     res_dis = []
95 | 
96 |     if gt_feat_pts_a.shape[0] != gt_feat_pts_b.shape[0]:
97 |         return res_dis
98 | 
99 |     # knn_a_idxs, knn_a_dis = query_KNN_tensor(structure_pts_a, gt_feat_pts_a, 1)
100 |     knn_a_idxs = knn_point(1, structure_pts_a[None, :, :], gt_feat_pts_a[None, :, :])
101 | 
102 |     corres_pts_in_b = structure_pts_b[knn_a_idxs[0, :, 0], :]
103 |     diff = corres_pts_in_b - gt_feat_pts_b
104 |     tmp_dis = torch.sqrt(torch.sum(diff * diff, dim=1)) / diameter_shape_b
105 | 
106 |     for i in range(tmp_dis.shape[0]):
107 |         # NaN marks a feature point that is missing on the ground-truth model
108 |         if not torch.isnan(gt_feat_pts_a[i, 0]) and not torch.isnan(gt_feat_pts_b[i, 0]):
109 |             res_dis.append(tmp_dis[i].item())
110 | 
111 |     return res_dis
112 | 
113 | 
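# Sketch of the metric above (illustrative, not part of this file):
# compute_correspondence_dis maps each ground-truth keypoint of shape A to its
# nearest structure point, transfers that structure point's index to shape B,
# and measures the distance to B's matching keypoint, normalized by B's
# bounding-box diagonal. A toy sanity check, assuming this file's imports:
# comparing a shape against itself, with keypoints taken from its own
# structure points, must give zero error.
#
#     toy = {'structure_pts': torch.rand(12, 3), 'diameter_shape': torch.tensor(1.0)}
#     toy['gt_feat_pts'] = toy['structure_pts'][:5]  # keypoints coincide with structure points
#     print(compute_correspondence_dis(toy, toy))   # [0.0, 0.0, 0.0, 0.0, 0.0]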
114 | def compute_correspondence_accuracy(model_datas):
115 |     dis_list = []
116 |     for i in tqdm(range(len(model_datas)), total=len(model_datas)):
117 |         for j in range(len(model_datas)):
118 |             if i == j:
119 |                 continue
120 |             model_data_i = model_datas[i]
121 |             model_data_j = model_datas[j]
122 |             corres_dis = compute_correspondence_dis(model_data_i, model_data_j)
123 |             dis_list = dis_list + corres_dis
124 | 
125 |     dis_array = np.array(dis_list)
126 | 
127 |     dis_thresholds = np.arange(0, 0.26, 0.01)
128 |     dis_ratios = []
129 | 
130 |     for i in range(dis_thresholds.shape[0]):
131 |         threshold = dis_thresholds[i]
132 |         ratio = dis_array[dis_array <= threshold].shape[0] / dis_array.shape[0]
133 |         dis_ratios.append(ratio)
134 | 
135 |     dis_ratios = np.array(dis_ratios)
136 | 
137 |     return dis_ratios, dis_thresholds
138 | 
139 | 
140 | if __name__ == '__main__':
141 |     parser = argparse.ArgumentParser(description="Arguments", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
142 |     parser.add_argument("-num_inputs", type=int, default=1024, help="number of points sampled from the initial point cloud")
143 |     parser.add_argument("-log_dir", type=str, default='4.8.2', help="path to the trained model log")
144 |     parser.add_argument("-model", type=str, default='model_min_test_loss',
145 |                         help="the trained model [default: model_min_test_loss]")
146 |     parser.add_argument("-output_dir", type=str, default='out', help="output dir")
147 |     parser.add_argument('-prediction_output', type=str, default='merger_prediction.npz',
148 |                         help='Output file where prediction results are written.')
149 |     parser.add_argument('-pcd_path', type=str, default='./keypointnet/pcds',
150 |                         help='Point cloud file folder path from KeypointNet dataset.')
151 |     parser.add_argument('-json_path', default='./keypointnet/annotations/', help='Annotation json folder path from KeypointNet dataset.')
152 |     parser.add_argument('-gauss_rate', type=float, default=0, help='std of the Gaussian jitter applied to input points (0 disables)')
153 |     parser.add_argument('-noise_rate', type=float, default=0, help='scale of the uniform noise added to predicted structure points (0 disables)')
154 |     parser.add_argument('-corres_name', type=str, help='file name of the correspondence accuracy results saved under corrs_model/')
155 | 
156 |     args = parser.parse_args()
157 |     main(args)
158 | 
--------------------------------------------------------------------------------
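main() above stores the accuracy curve with positional np.savez, so the arrays come back as arr_0/arr_1. A reading sketch (not repository code; the file name is whatever -corres_name was, 'chair_eval.npz' here is hypothetical):

import numpy as np

data = np.load('corrs_model/chair_eval.npz')
dis_ratios, dis_thresholds = data['arr_0'], data['arr_1']
for t, r in zip(dis_thresholds, dis_ratios):
    print('<= %.2f of diameter: %.1f%% of correspondences' % (t, 100.0 * r))
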
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import shutil
3 | import torch
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | import os
7 | import argparse
8 | import gc
9 | import importlib
10 | 
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | ROOT_DIR = BASE_DIR
13 | sys.path.append(os.path.join(ROOT_DIR, 'models'))
14 | 
15 | import utils.check_points_utils as checkpoint_util
16 | from tqdm import tqdm
17 | import logging
18 | import datetime
19 | from pathlib import Path
20 | import provider
21 | from data_utils.dataset import Dataset
22 | from data_utils.keypointnet_dataloader import KeyPointNetDataLoader
23 | 
24 | torch.backends.cudnn.enabled = True
25 | torch.backends.cudnn.benchmark = True
26 | scaler = torch.cuda.amp.GradScaler()
27 | autocast = torch.cuda.amp.autocast
28 | 
29 | 
30 | def train_one_epoch(model, optimizer, data_loader, current_iter, criterions, lr_scheduler, num_of_trans,
31 |                     num_inputs, logger):
32 |     model.train()
33 |     loss_dict = {}
34 |     loss_dict['Loss'] = 0
35 |     for l in criterions.keys():
36 |         loss_dict[l] = 0
37 |     count = 0
38 | 
39 |     for batch_id, (batch_points, _, _, _) in tqdm(enumerate(data_loader, 0), total=len(data_loader),
40 |                                                   smoothing=0.9):
41 |         optimizer.zero_grad()
42 |         # print(batch_points.shape)
43 |         batch_points = batch_points.data.numpy()
44 |         batch_points[:, :, 0:3] = provider.random_scale_point_cloud(batch_points[:, :, 0:3])
45 |         batch_points[:, :, 0:3] = provider.shift_point_cloud(batch_points[:, :, 0:3])
46 |         batch_points = torch.Tensor(batch_points)
47 |         batch_points = batch_points.cuda()
48 | 
49 |         if args.use_half:  # `args` is the module-level namespace set in __main__
50 |             with autocast():
51 |                 structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
52 | 
53 |                 ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
54 |                 WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
55 |                 # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
56 |                 loss = ComputeLoss3dLoss + WeightedChamferLoss
57 | 
58 |             scaler.scale(loss).backward()
59 |             scaler.step(optimizer)
60 |             scaler.update()
61 |         else:
62 |             structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
63 | 
64 |             ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
65 |             WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
66 |             # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
67 |             loss = ComputeLoss3dLoss + WeightedChamferLoss
68 | 
69 |             loss.backward()
70 |             optimizer.step()
71 | 
72 |         current_iter += 1
73 |         loss_dict['Loss'] += loss.item()
74 |         loss_dict['ComputeLoss3d'] += ComputeLoss3dLoss.item()
75 |         loss_dict['WeightedChamferLoss'] += WeightedChamferLoss.item()
76 |         # loss_dict['VecLoss'] += loss_Vec.item()
77 | 
78 | 
79 |         # gc.collect()
80 |         count += 1
81 | 
82 |     lr_scheduler.step()
83 |     for k in loss_dict.keys():
84 |         loss_dict[k] /= count
85 | 
86 |     return loss_dict, current_iter
87 | 
88 | 
89 | def test(model, data_loader, criterions):
90 |     model.eval()
91 |     count = 0
92 |     loss_dict = {}
93 |     loss_dict['Loss'] = 0
94 |     for l in criterions.keys():
95 |         loss_dict[l] = 0
96 |     for batch_id, (batch_points, _, _, _) in tqdm(enumerate(data_loader, 0), total=len(data_loader),
97 |                                                   smoothing=0.9):
98 |         batch_points = batch_points.cuda()
99 |         structure_points, fps_points, cos_similarity, stpts_prob_map = model(batch_points)
100 | 
101 |         ComputeLoss3dLoss = criterions['ComputeLoss3d'](batch_points, structure_points)
102 |         WeightedChamferLoss = criterions['WeightedChamferLoss'](fps_points, structure_points, stpts_prob_map, batch_points)
103 |         # loss_Vec = criterions['VecLoss'](structure_points, cos_similarity, 0.85)
104 |         loss = WeightedChamferLoss  # the test loss used for model selection is the weighted chamfer term only
105 | 
106 |         loss_dict['Loss'] += loss.item()
107 |         loss_dict['ComputeLoss3d'] += ComputeLoss3dLoss.item()
108 |         loss_dict['WeightedChamferLoss'] += WeightedChamferLoss.item()
109 |         # loss_dict['VecLoss'] += loss_Vec.item()
110 | 
111 |         count += 1
112 | 
113 |     for k in loss_dict.keys():
114 |         loss_dict[k] /= count
115 |     return loss_dict
116 | 
117 | 
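# Illustrative sketch (not part of this file): the mixed-precision path in
# train_one_epoch above follows the standard torch.cuda.amp recipe. Condensed,
# with a generic loss_fn standing in for this repo's criterion dict:
#
#     def amp_step(model, optimizer, scaler, batch, loss_fn):
#         optimizer.zero_grad()
#         with torch.cuda.amp.autocast():
#             out = model(batch)            # forward runs in mixed fp16/fp32
#             loss = loss_fn(out)
#         scaler.scale(loss).backward()     # scale up to avoid fp16 gradient underflow
#         scaler.step(optimizer)            # unscales first; skips the step on inf/NaN
#         scaler.update()                   # adapts the scale factor for the next step
#         return loss.item()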
118 | def create_logger(args):
119 |     '''CREATE DIR'''
120 |     timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
121 |     exp_dir = Path('./log/')
122 |     exp_dir.mkdir(exist_ok=True)
123 |     # exp_dir = exp_dir.joinpath('')
124 |     # exp_dir.mkdir(exist_ok=True)
125 |     if args.log_dir is None:
126 |         exp_dir = exp_dir.joinpath(timestr)
127 |     else:
128 |         exp_dir = exp_dir.joinpath(args.log_dir)
129 |     exp_dir.mkdir(exist_ok=True)
130 |     checkpoints_dir = exp_dir.joinpath('checkpoints/')
131 |     checkpoints_dir.mkdir(exist_ok=True)
132 |     log_dir = exp_dir.joinpath('logs/')
133 |     log_dir.mkdir(exist_ok=True)
134 | 
135 |     '''LOG'''
136 |     logger = logging.getLogger("Model")
137 |     logger.setLevel(logging.INFO)
138 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
139 |     file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model), mode='w+')
140 |     file_handler.setLevel(logging.INFO)
141 |     file_handler.setFormatter(formatter)
142 |     logger.addHandler(file_handler)
143 | 
144 |     return logger, checkpoints_dir, exp_dir
145 | 
146 | 
147 | def log_string(message, logger):
148 |     logger.info(message)
149 |     print(message)
150 | 
151 | 
152 | def train(args):
153 |     # lr_clip = 1e-5
154 |     # bnm_clip = 1e-2
155 |     torch.cuda.empty_cache()
156 |     logger, checkpoints_dir, exp_dir = create_logger(args)
157 |     log_string('PARAMETER ...', logger=logger)
158 |     log_string(args, logger=logger)
159 | 
160 |     '''DATA LOADING'''
161 |     log_string('Load dataset ...', logger=logger)
162 | 
163 |     # train_dataset = bhcp_dataloader(args.data_path, args.category, is_pts_aligned=False, split='train')
164 |     # test_dataset = bhcp_dataloader(args.data_path, args.category, is_pts_aligned=False, split='test')
165 |     # train_dataset = KeyPointNetDataLoader(json_path=cmd_args.json_path, pcd_path=cmd_args.pcd_path, split='train')
166 |     # test_dataset = KeyPointNetDataLoader(json_path=cmd_args.json_path, pcd_path=cmd_args.pcd_path, split='val')
167 | 
168 |     if args.dataset_name == 'keypointnet':
169 |         train_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, args.category + '.json'),
170 |                                               pcd_path=args.pcd_path, split='train')
171 |         test_dataset = KeyPointNetDataLoader(num_points=args.num_inputs, json_path=os.path.join(args.json_path, args.category + '.json'),
172 |                                              pcd_path=args.pcd_path, split='val')
173 |     else:
174 |         train_dataset = Dataset(root=args.data_path, dataset_name=args.dataset_name, class_choice=args.category,
175 |                                 num_points=args.num_inputs, split='train',
176 |                                 segmentation=args.segmentation)
177 |         test_dataset = Dataset(root=args.data_path, dataset_name=args.dataset_name, class_choice=args.category,
178 |                                num_points=args.num_inputs, split='test',
179 |                                segmentation=args.segmentation)
180 | 
181 |     trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
182 |                                                   num_workers=args.num_workers, drop_last=True, pin_memory=True,
183 |                                                   persistent_workers=True)
184 |     testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
185 |                                                  num_workers=args.num_workers, pin_memory=True, persistent_workers=True)
186 | 
187 |     shutil.copy('./models/%s.py' % args.model, str(exp_dir))
188 |     shutil.copy('./models/chamfer_distance.py', str(exp_dir))
189 |     shutil.copy('./models/torch_pointnet_utils.py', str(exp_dir))
190 |     shutil.copy('./train.py', str(exp_dir))
191 | 
192 |     '''MODEL LOADING'''
193 |     model = importlib.import_module(args.model)
194 |     criterions = {'ComputeLoss3d': model.ComputeLoss3d(), 'WeightedChamferLoss': model.WeightedChamferLoss(),
195 |                   'VecLoss': model.VecLoss()}
196 |     # criterions = {'ComputeLoss3d': model.ComputeLoss3d(), 'VecLoss': model.VecLoss()}
197 | 
198 |     model = model.Pointnet2StructurePointNet(num_structure_points=args.num_structure_points, input_channels=0,
199 |                                              multi_distribution_num=args.multi_distribution,
200 |                                              offset=args.offset)
201 | 
202 |     model.cuda()
203 | 
204 |     optimizer = optim.Adam(
205 |         model.parameters(),
206 |         lr=args.lr,
207 |         betas=(0.9, 0.999),
208 |         eps=1e-08,
209 |         weight_decay=args.weight_decay
210 |     )
211 | 
212 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_batch, gamma=args.lr_decay)
213 | 
214 |     iters = -1
215 |     min_test_loss = float('inf')
216 | 
217 |     # load status from checkpoint
218 |     start_epoch = 0
219 |     if args.checkpoint is not None:
220 |         start_epoch, iters, _ = checkpoint_util.load_checkpoint(model_3d=model, optimizer=optimizer,
221 |                                                                 filename=args.checkpoint)
222 |         start_epoch += 1
223 | 
224 |     log_string('Start Training Unsupervised Structure Points for %s...' % args.dataset_name, logger=logger)
225 |     iters = max(iters, 0)
226 | 
227 |     for epoch_i in range(start_epoch, args.max_epochs):
228 |         log_string('-------------------------------------------', logger=logger)
229 |         log_string('Epoch %d/%s, Learning Rate %f:' % (epoch_i + 1, args.max_epochs, lr_scheduler.get_last_lr()[0]),
230 |                    logger=logger)
231 | 
232 |         loss_dict, iters = train_one_epoch(model,
233 |                                            optimizer,
234 |                                            trainDataLoader,
235 |                                            iters,
236 |                                            criterions,
237 |                                            lr_scheduler,
238 |                                            num_of_trans=args.num_of_transform,
239 |                                            num_inputs=args.num_inputs,
240 |                                            logger=logger)
241 |         loss_str = ''
242 |         for i in loss_dict.keys():
243 |             loss_str += '%s: %f \t' % (i, loss_dict[i])
244 |         log_string(loss_str, logger)
245 | 
246 |         with torch.no_grad():
247 |             loss_dict = test(model,
248 |                              data_loader=testDataLoader,
249 |                              criterions=criterions,
250 |                              )
251 |         loss_str = ''
252 |         for i in loss_dict.keys():
253 |             loss_str += '%s: %f \t' % (i, loss_dict[i])
254 |         log_string(loss_str, logger)
255 | 
256 |         if loss_dict['Loss'] < min_test_loss:
257 |             min_test_loss = loss_dict['Loss']
258 |             log_string('Min Test Loss: %f' % (loss_dict['Loss']), logger=logger)
259 | 
260 |             log_string('Save model...', logger=logger)
261 |             fname = os.path.join(checkpoints_dir, 'model_min_test_loss')
262 |             checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
263 |                                             epoch=epoch_i, category=args.category,
264 |                                             num_structure_points=args.num_structure_points,
265 |                                             multi_distribution=args.multi_distribution,
266 |                                             offset=args.offset)
267 |         else:
268 |             log_string('Min Test Loss: %f' % (min_test_loss), logger=logger)
269 | 
270 |         if (epoch_i + 1) % args.checkpoint_save_step == 0:
271 |             fname = os.path.join(checkpoints_dir, 'model_%d' % (epoch_i + 1))
272 |             checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
273 |                                             epoch=epoch_i, category=args.category,
274 |                                             num_structure_points=args.num_structure_points,
275 |                                             multi_distribution=args.multi_distribution,
276 |                                             offset=args.offset)
277 |         fname = os.path.join(checkpoints_dir, 'model')
278 |         checkpoint_util.save_checkpoint(filename=fname, model_3d=model, optimizer=optimizer, iters=iters,
279 |                                         epoch=epoch_i, category=args.category,
280 |                                         num_structure_points=args.num_structure_points,
281 |                                         multi_distribution=args.multi_distribution,
282 |                                         offset=args.offset)
283 | 
284 | 
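# Illustrative sketch (not part of this file): reading back one of the
# checkpoints the loop above writes. load_checkpoint returns
# (epoch, iters, checkpoint), and the raw dict keeps the metadata fields
# passed to save_checkpoint. The run name 'my_run' is hypothetical.
#
#     import utils.check_points_utils as checkpoint_util
#
#     epoch, iters, ckpt = checkpoint_util.load_checkpoint(
#         model_3d=None, filename='log/my_run/checkpoints/model_min_test_loss')
#     print(epoch, iters, ckpt['category'], ckpt['num_structure_points'])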
help="Root of the log") 298 | parser.add_argument("-multi_distribution", type=int, default=3, help="Multivariate normal distribution nums") 299 | parser.add_argument('-num_workers', type=int, default=4, help='dataload num worker') 300 | parser.add_argument('-model', default='model_weightchamfer', help='model name [default: model_weightchamfer Structure_pointnet]') 301 | parser.add_argument('-use_half', action='store_true', default=True, help='use mix half mode') 302 | parser.add_argument('-json_path', default='./keypointnet/annotations/', help='') 303 | parser.add_argument('-pcd_path', type=str, default='./keypointnet/pcds', 304 | help='Point cloud file folder path from KeypointNet dataset.') 305 | parser.add_argument("-lr", type=float, default=1e-3, help="Initial learning rate") 306 | parser.add_argument("-lr_decay", type=float, default=0.7, help="Learning rate decay gamma") 307 | parser.add_argument("-decay_batch", type=float, default=20, help="Learning rate decay batch") 308 | parser.add_argument("-bn_momentum", type=float, default=0.5, help="Initial batch norm momentum") 309 | parser.add_argument("-bnm_decay", type=float, default=0.5, help="Batch norm momentum decay gamma") 310 | parser.add_argument("-checkpoint_save_step", type=int, default=50, help="Step for saving Checkpoint") 311 | parser.add_argument("-checkpoint", type=str, default=None, help="Checkpoint to start from") 312 | parser.add_argument("-num_of_transform", type=int, default=0, 313 | help="Number of transforms for rotation data augmentation. Useful when testing on shapes without alignment") 314 | args = parser.parse_args() 315 | return args 316 | 317 | 318 | if __name__ == "__main__": 319 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 320 | args = parse_args() 321 | import platform 322 | 323 | sys = platform.system() 324 | if sys == "Windows": 325 | args.batch_size = 2 326 | 327 | train(args=args) 328 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import sys 9 | sys.path.append("..") -------------------------------------------------------------------------------- /utils/check_points_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | 5 | def save_checkpoint(filename, model_3d=None, model_2d=None, optimizer=None, iters=None, epoch=None, meta_data=None, category=None, num_structure_points=None, multi_distribution=None,offset=None): 6 | print("save checkpoint '{}'".format(filename)) 7 | optim_state = optimizer.state_dict() if optimizer is not None else None 8 | if model_3d is not None: 9 | if isinstance(model_3d, torch.nn.DataParallel): 10 | model_state_3d = model_3d.module.state_dict() 11 | else: 12 | model_state_3d = model_3d.state_dict() 13 | else: 14 | model_state_3d = None 15 | 16 | if model_2d is not None: 17 | if isinstance(model_2d, torch.nn.DataParallel): 18 | model_state_2d = model_2d.module.state_dict() 19 | else: 20 | model_state_2d = model_2d.state_dict() 21 | else: 22 | model_state_2d = None 23 | 24 | state = { 25 | 'iter': iters, 26 | 'epoch': epoch, 27 | 'model_state_2d': model_state_2d, 28 | 'model_state_3d': model_state_3d, 29 | 'optimizer_state': optim_state, 30 | 'meta_data': meta_data, 31 | 'category': category, 32 | 'num_structure_points': num_structure_points, 
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import (
2 |     division,
3 |     absolute_import,
4 |     with_statement,
5 |     print_function,
6 |     unicode_literals,
7 | )
8 | import sys
9 | sys.path.append("..")
--------------------------------------------------------------------------------
/utils/check_points_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | 
4 | 
5 | def save_checkpoint(filename, model_3d=None, model_2d=None, optimizer=None, iters=None, epoch=None, meta_data=None, category=None, num_structure_points=None, multi_distribution=None, offset=None):
6 |     print("save checkpoint '{}'".format(filename))
7 |     optim_state = optimizer.state_dict() if optimizer is not None else None
8 |     if model_3d is not None:
9 |         if isinstance(model_3d, torch.nn.DataParallel):
10 |             model_state_3d = model_3d.module.state_dict()
11 |         else:
12 |             model_state_3d = model_3d.state_dict()
13 |     else:
14 |         model_state_3d = None
15 | 
16 |     if model_2d is not None:
17 |         if isinstance(model_2d, torch.nn.DataParallel):
18 |             model_state_2d = model_2d.module.state_dict()
19 |         else:
20 |             model_state_2d = model_2d.state_dict()
21 |     else:
22 |         model_state_2d = None
23 | 
24 |     state = {
25 |         'iter': iters,
26 |         'epoch': epoch,
27 |         'model_state_2d': model_state_2d,
28 |         'model_state_3d': model_state_3d,
29 |         'optimizer_state': optim_state,
30 |         'meta_data': meta_data,
31 |         'category': category,
32 |         'num_structure_points': num_structure_points,
33 |         'multi_distribution': multi_distribution,
34 |         'offset': offset
35 |     }
36 |     torch.save(state, filename)
37 | 
38 | 
39 | def load_checkpoint(filename, model_3d=None, optimizer=None, meta_data=None):
40 |     if os.path.isfile(filename):
41 |         checkpoint = torch.load(filename)
42 |         iters = checkpoint.get('iter', 0)
43 |         epoch = checkpoint['epoch']
44 | 
45 |         if model_3d is not None and checkpoint['model_state_3d'] is not None:
46 |             model_3d.load_state_dict(checkpoint['model_state_3d'])
47 |         if optimizer is not None and checkpoint['optimizer_state'] is not None:
48 |             optimizer.load_state_dict(checkpoint['optimizer_state'])
49 | 
50 |         if meta_data is not None and 'meta_data' in checkpoint:
51 |             for key in checkpoint['meta_data']:
52 |                 meta_data[key] = checkpoint['meta_data'][key]
53 |         return epoch, iters, checkpoint
54 |     else:
55 |         raise FileNotFoundError("==> Checkpoint '{}' not found".format(filename))
56 | 
--------------------------------------------------------------------------------
/utils/logutils.py:
--------------------------------------------------------------------------------
1 | class LogUtils():
2 |     def __init__(self, fname, filemode):
3 |         self.logf = open(fname, filemode)
4 | 
5 |     def write(self, text, need_display=True):
6 |         if need_display:
7 |             print(text)
8 | 
9 |         self.logf.write(text + '\n')
10 |         self.logf.flush()
11 | 
12 |     def close(self):
13 |         self.logf.close()
14 | 
15 |     def write_args(self, cmd_args):
16 |         self.logf.write('cmd arguments:\n')
17 |         for k in cmd_args.__dict__:
18 |             val = cmd_args.__dict__[k]
19 |             self.logf.write('{0}: {1}\n'.format(k, val))
20 | 
21 | def write_correspondence_accuracy(fname, dis_ratio, dis_threshold):
22 |     with open(fname, 'w') as f:
23 |         f.write('Correspondence Accuracy:\n')
24 |         f.write('distance ratio: ')
25 |         for i in range(dis_ratio.shape[0]):
26 |             f.write(' {0}'.format(dis_ratio[i]))
27 |         f.write('\n')
28 | 
29 |         f.write('distance threshold: ')
30 |         for i in range(dis_threshold.shape[0]):
31 |             f.write(' {0}'.format(dis_threshold[i]))
32 |         f.write('\n')
33 | 
--------------------------------------------------------------------------------
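A quick usage sketch for the helpers above (illustrative, not repository code; the file names are hypothetical). LogUtils mirrors every line to stdout and to the log file, and write_correspondence_accuracy dumps the curve produced by test_correspondence.py:

import numpy as np
from utils.logutils import LogUtils, write_correspondence_accuracy

log = LogUtils('run_log.txt', 'w')
log.write('starting evaluation')
log.close()

write_correspondence_accuracy('corrs.txt', np.array([0.1, 0.5, 0.9]),
                              np.array([0.0, 0.05, 0.25]))
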
/utils/mesh_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def read_obj_file(fname):
3 |     vertices = []
4 |     faces = []
5 |     try:
6 |         f = open(fname)
7 | 
8 |         for line in f:
9 |             if line[:2] == "v ":
10 |                 strs = line.split()
11 |                 v0 = float(strs[1])
12 |                 v1 = float(strs[2])
13 |                 v2 = float(strs[3])
14 |                 vertex = [v0, v1, v2]
15 |                 vertices.append(vertex)
16 | 
17 |             elif line[0] == "f":
18 |                 strs = line.split()
19 |                 f0 = int(strs[1].split('/')[0]) - 1
20 |                 f1 = int(strs[2].split('/')[0]) - 1
21 |                 f2 = int(strs[3].split('/')[0]) - 1
22 |                 face = [f0, f1, f2]
23 | 
24 |                 faces.append(face)
25 | 
26 |         f.close()
27 |     except IOError:
28 |         print(".obj file not found.")
29 | 
30 |     vertices = np.array(vertices)
31 |     faces = np.array(faces)
32 | 
33 |     return vertices, faces
34 | 
35 | 
36 | def rotate_obj_file(fname, rot_mat):
37 | 
38 |     lines = []
39 |     with open(fname) as fin:
40 | 
41 |         for line in fin:
42 |             lines.append(line)
43 | 
44 |     for i in range(len(lines)):
45 |         line = lines[i]
46 |         if line[:2] == "v ":
47 |             strs = line.split(' ')
48 |             vertex = np.array([float(strs[1]), float(strs[2]), float(strs[3])])
49 |             vertex = np.matmul(rot_mat, vertex)
50 |             line = "v {0} {1} {2}\n".format(vertex[0], vertex[1], vertex[2])
51 |             lines[i] = line
52 |         elif line[:3] == 'vn ':
53 |             strs = line.split(' ')
54 |             vn = np.array([float(strs[1]), float(strs[2]), float(strs[3])])
55 |             vn = np.matmul(rot_mat, vn)
56 |             line = "vn {0} {1} {2}\n".format(vn[0], vn[1], vn[2])
57 |             lines[i] = line
58 | 
59 |     with open(fname, 'w') as f:
60 |         for line in lines:
61 |             f.write(line)
62 | 
63 | 
64 | def write_off_file(fname, vertices, faces):
65 |     with open(fname, 'w') as f:
66 |         vnum = len(vertices)
67 |         fnum = len(faces)
68 |         f.write('COFF\n')
69 |         f.write('{0} {1} {2}\n'.format(vnum, fnum, 0))
70 |         for i in range(0, vnum):
71 |             f.write('{0} {1} {2}\n'.format(vertices[i][0], vertices[i][1], vertices[i][2]))
72 | 
73 | 
74 |         for i in range(0, fnum):
75 |             f.write('3 {0} {1} {2}\n'.format(faces[i][0], faces[i][1], faces[i][2]))
76 | 
77 | def read_off_file(fname):
78 |     vertices = []
79 |     faces = []
80 |     try:
81 |         f = open(fname)
82 |         head = f.readline()
83 |         strline = f.readline()
84 |         strs = strline.split(' ')
85 |         vnum = int(strs[0])
86 |         fnum = int(strs[1])
87 |         for i in range(0, vnum):
88 |             strline = f.readline()
89 |             strs = strline.split(' ')
90 |             v0 = float(strs[0])
91 |             v1 = float(strs[1])
92 |             v2 = float(strs[2])
93 |             vertex = [v0, v1, v2]
94 |             vertices.append(vertex)
95 | 
96 |         for i in range(0, fnum):
97 |             strline = f.readline()
98 |             strs = strline.split(' ')
99 |             f0 = int(strs[1])
100 |             f1 = int(strs[2])
101 |             f2 = int(strs[3])
102 |             face = [f0, f1, f2]
103 |             faces.append(face)
104 | 
105 |         f.close()
106 |     except IOError:
107 |         print(".off file not found.")
108 | 
109 |     vertices = np.array(vertices)
110 |     faces = np.array(faces)
111 |     return vertices, faces
112 | 
113 | 
114 | def write_obj_file(fname, vertices, faces):
115 |     with open(fname, 'w') as f:
116 |         vnum = len(vertices)
117 |         for i in range(0, vnum):
118 |             f.write('v {0} {1} {2}\n'.format(vertices[i][0], vertices[i][1], vertices[i][2]))
119 | 
120 |         fnum = len(faces)
121 |         for i in range(0, fnum):
122 |             f.write('f {0} {1} {2}\n'.format(faces[i][0] + 1, faces[i][1] + 1, faces[i][2] + 1))
123 | 
--------------------------------------------------------------------------------
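An editorial round-trip sketch for the OBJ helpers above ('tri.obj' is an illustrative path). read_obj_file returns 0-based face indices; write_obj_file adds back the 1-based offset that the OBJ format expects:

import numpy as np
from utils.mesh_utils import read_obj_file, write_obj_file

verts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])
write_obj_file('tri.obj', verts, faces)
v2, f2 = read_obj_file('tri.obj')
assert np.allclose(verts, v2) and (faces == f2).all()
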
/utils/point_cloud_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import cv2
4 | # from pointnet2 import _ext
5 | 
6 | def write_points_off(fname, points, colors=None):
7 | 
8 |     with open(fname, 'w') as f:
9 | 
10 |         num = points.shape[0]
11 |         f.write('COFF\n')
12 |         f.write('{0} 0 0\n'.format(num))
13 |         for i in range(0, num):
14 |             if colors is not None:
15 |                 f.write('{0} {1} {2} {3} {4} {5}\n'.format(points[i, 0], points[i, 1], points[i, 2], int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2])))
16 |             else:
17 |                 f.write('{0} {1} {2}\n'.format(points[i, 0], points[i, 1], points[i, 2]))
18 | 
19 | 
20 | def write_points_obj(fname, points, colors=None):
21 | 
22 |     with open(fname, 'w') as f:
23 | 
24 |         num = points.shape[0]
25 |         for i in range(0, num):
26 |             if colors is not None:
27 |                 f.write('v {0} {1} {2} {3} {4} {5}\n'.format(points[i, 0], points[i, 1], points[i, 2], int(colors[i, 0]), int(colors[i, 1]), int(colors[i, 2])))
28 |             else:
29 |                 f.write('v {0} {1} {2}\n'.format(points[i, 0], points[i, 1], points[i, 2]))
30 | 
31 | 
32 | def compute_pca(points):
33 |     mean, eigvec = cv2.PCACompute(points, mean=None)
34 |     if np.dot(np.cross(eigvec[0], eigvec[1]), eigvec[2]) < 0:
35 |         eigvec[2] = -eigvec[2]  # enforce a right-handed eigenvector basis
36 | 
37 |     eigvec[0] = eigvec[0] / np.linalg.norm(eigvec[0])
38 |     eigvec[1] = eigvec[1] / np.linalg.norm(eigvec[1])
39 |     eigvec[2] = eigvec[2] / np.linalg.norm(eigvec[2])
40 | 
41 |     return eigvec
42 | 
43 | def query_KNN(points, query_pts, k, return_dis=True):
44 |     '''
45 | 
46 |     :param points: n x 3
47 |     :param query_pts: m x 3
48 |     :param k: num of neighbors
49 |     :return: m x k ids, sorted_dis
50 |     '''
51 | 
52 |     diff = query_pts[:, None, :] - points[None, :, :]
53 |     dis = np.sqrt(np.sum(diff * diff, axis=2))  # m x n
54 |     sorted_idx = np.argsort(dis, axis=1)
55 |     sorted_idx = sorted_idx[:, :k]
56 | 
57 |     if return_dis:
58 |         sorted_dis = dis[None, 0, sorted_idx[0, :]]
59 |         for i in range(1, query_pts.shape[0]):
60 |             sorted_dis = np.concatenate((sorted_dis, dis[None, i, sorted_idx[i, :]]), axis=0)
61 | 
62 |         return sorted_idx, sorted_dis
63 |     else:
64 |         return sorted_idx
65 | 
66 | 
67 | def query_KNN_tensor(points, query_pts, k):
68 |     '''
69 | 
70 |     :param points: n x 3
71 |     :param query_pts: m x 3
72 |     :param k: num of neighbors
73 |     :return: m x k ids, sorted_dis
74 |     '''
75 | 
76 |     diff = query_pts[:, None, :] - points[None, :, :]
77 |     dis = torch.sqrt(torch.sum(diff * diff, dim=2))  # m x n
78 |     sorted_idx = torch.argsort(dis, dim=1)
79 |     sorted_idx = sorted_idx[:, :k]
80 | 
81 |     sorted_dis = dis[None, 0, sorted_idx[0, :]]
82 |     for i in range(1, query_pts.shape[0]):
83 |         sorted_dis = torch.cat((sorted_dis, dis[None, i, sorted_idx[i, :]]), dim=0)
84 | 
85 |     return sorted_idx, sorted_dis
86 | 
87 | 
88 | 
89 | def read_pointcloud_obj(fname):
90 |     vertices = []
91 |     try:
92 |         f = open(fname)
93 | 
94 |         for line in f:
95 |             if line[:2] == "v ":
96 |                 strs = line.split(' ')
97 |                 v0 = float(strs[1])
98 |                 v1 = float(strs[2])
99 |                 v2 = float(strs[3])
100 |                 vertex = [v0, v1, v2]
101 |                 vertices.append(vertex)
102 | 
103 |         f.close()
104 |     except IOError:
105 |         print(".obj file not found.")
106 | 
107 |     vertices = np.array(vertices)
108 | 
109 | 
110 |     return vertices
111 | 
112 | 
113 | def read_points_off(fname, read_color=False):
114 |     vertices = []
115 |     colors = []
116 | 
117 |     try:
118 |         f = open(fname)
119 |         head = f.readline()
120 |         strline = f.readline()
121 |         strs = strline.split(' ')
122 |         vnum = int(strs[0])
123 |         fnum = int(strs[1])
124 |         for i in range(0, vnum):
125 |             strline = f.readline()
126 |             strs = strline.split(' ')
127 |             v0 = float(strs[0])
128 |             v1 = float(strs[1])
129 |             v2 = float(strs[2])
130 |             vertex = [v0, v1, v2]
131 |             vertices.append(vertex)
132 | 
133 |             if len(strs) > 3:
134 |                 c0 = float(strs[3])
135 |                 c1 = float(strs[4])
136 |                 c2 = float(strs[5])
137 |                 color = [c0, c1, c2]
138 |                 colors.append(color)
139 | 
140 |         f.close()
141 |     except IOError:
142 |         print(".off file not found.")
143 | 
144 |     pts = np.array(vertices).astype(np.float32)
145 | 
146 |     if len(colors) > 0 and read_color:
147 |         colors = np.array(colors).astype(np.float32)
148 |         return pts, colors
149 |     else:
150 |         return pts
151 | 
152 | def trans_pointcloud(rot_mat, trans_mat, points):
153 |     '''
154 | 
155 |     :param rot_mat: 3 x 3
156 |     :param trans_mat: 3
157 |     :param points: n x 3
158 |     :return: n x 3
159 |     '''
160 |     tmp_points = np.matmul(rot_mat, np.transpose(points, (1, 0)))
161 |     tmp_points = tmp_points + trans_mat[:, None]
162 |     tmp_points = np.transpose(tmp_points, (1, 0))
163 |     return tmp_points
164 | 
--------------------------------------------------------------------------------
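A final illustrative sketch (not repository code) for query_KNN above, which returns, for each query point, the indices of its k nearest stored points and the matching distances. Querying with points taken from the set itself must return each point as its own nearest neighbor:

import numpy as np
from utils.point_cloud_utils import query_KNN

points = np.random.rand(100, 3)
queries = points[:5]
idx, dis = query_KNN(points, queries, k=1)
print(idx[:, 0])                    # [0 1 2 3 4]
print(np.allclose(dis[:, 0], 0.0))  # True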