├── LICENSE ├── README.md ├── arguments.py ├── common ├── colors.py ├── math │ ├── __pycache__ │ │ ├── random.cpython-38.pyc │ │ ├── se3.cpython-38.pyc │ │ └── so3.cpython-38.pyc │ ├── random.py │ ├── se3.py │ └── so3.py ├── math_torch │ ├── __pycache__ │ │ └── se3.cpython-38.pyc │ └── se3.py ├── misc.py └── torch.py ├── dataloader ├── NuScenesDataLoader.py ├── SemanticKITTYDataLoader.py ├── data_prepare_SemanticKITTI.py ├── data_prepare_nuScenes.py ├── semantic-kitti.yaml ├── semantic-nuscenes.yaml ├── transforms.py └── utils.py ├── metrics.py ├── models ├── PointUtils │ ├── points_utils.py │ ├── setup.py │ └── src │ │ ├── cuda_utils.h │ │ ├── furthest_point_sampling.cpp │ │ ├── furthest_point_sampling_gpu.cu │ │ ├── furthest_point_sampling_gpu.h │ │ └── point_utils_api.cpp ├── RandLA_Net.py ├── attention.py ├── compute_rigid_transform.py ├── key_point_dectector.py ├── loss.py ├── model.py ├── semanticCNN.py └── utils.py ├── requirements.txt ├── script ├── test_kitti.sh ├── test_nuscenes.sh ├── train_kitti.sh └── train_nuscenes.sh ├── test.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SARNet: Semantic Augmented Registration of Large-Scale Urban Point Clouds 2 | 3 | ## Environments 4 | Our source code was developed using Python 3.8 with PyTorch 1.8, maybe the other version is also suitable. 5 | The code mainly requires the following libraries and you can check `requirements.txt` for more environment requirements. 
6 | - PyTorch 7 | - [nuscenes](https://github.com/nutonomy/nuscenes-devkit) 8 | 9 | Please run the following commands to install `point_utils`: 10 | ``` 11 | cd models/PointUtils 12 | python setup.py install 13 | ``` 14 | 15 | ## Datasets 16 | After downloading the two datasets, we need to preprocess them: 17 | ### [SemanticKITTI](http://www.semantic-kitti.org/) 18 | We need to download both the point data and the semantic labels. The initial dataset should be organized as: 19 | ``` 20 | SemanticKITTI 21 | |-- velodyne/sequences 22 | |-- 00/velodyne 23 | |--000000.bin 24 | ... 25 | |-- 01/velodyne 26 | ... 27 | |-- label/sequences 28 | |-- 00/labels 29 | |--000000.label 30 | ... 31 | |-- 01/labels 32 | ... 33 | ``` 34 | After using data_prepare_SemanticKITTI.py to preprocess the initial data, we write the coordinates and the semantic labels of the points into one file: 35 | ``` 36 | SemanticKITTI 37 | |-- process 38 | |--train 39 | |--000000.npy 40 | ... 41 | |--val 42 | |--0003178.npy 43 | ... 44 | |--test 45 | |--0003632.npy 46 | ... 47 | ``` 48 | ### [NuScenes](https://www.nuscenes.org/) 49 | Similar to SemanticKITTI, we need to download both the point data and the semantic labels (nuScenes-lidarseg), 50 | and extract the lidarseg and v1.0-* folders to our nuScenes root directory. The initial dataset should be organized as: 51 | ``` 52 | NuScenes 53 | |-- lidarseg 54 | |-- v1.0-{mini, test, trainval} <- Contains the .bin files; a .bin file contains the labels of the points in a point cloud 55 | |-- samples <- Sensor data for keyframes. 56 | |-- LIDAR_TOP 57 | |-- sweeps <- Sensor data for intermediate frames 58 | |-- LIDAR_TOP 59 | |-- v1.0-{mini, test, trainval} <- JSON tables that include all the metadata and annotations. 60 | 61 | ``` 62 | After using data_prepare_nuScenes.py to preprocess the initial data, we write the coordinates and the semantic labels of the points into one file: 63 | ``` 64 | NuScenes 65 | |-- process 66 | |--train 67 | |--0000fa1a7bfb46dc872045096181303e.npy 68 | ... 69 | |--val 70 | |--00ad11d5341a4ed3a300bc557dd00e73.npy 71 | ... 72 | |--test 73 | |--00b3af4bba044f0ea42adc2bbb5733ff.npy 74 | ...
75 | ``` 76 | 77 | ## Training 78 | We provide scripts to train on both SemanticKITTI and NuScenes. You should specify your own dataset path, and you can either use the default parameters or specify your own, as in train_nuscenes.sh: 79 | ``` 80 | sh script/train_kitti.sh -> train SemanticKITTI 81 | sh script/train_nuscenes.sh -> train NuScenes 82 | ``` 83 | 84 | ## Testing 85 | ``` 86 | sh script/test_kitti.sh -> test SemanticKITTI 87 | sh script/test_nuscenes.sh -> test NuScenes 88 | ``` 89 | 90 | ## Acknowledgments 91 | We want to thank the following open-source projects for their help with the implementation: 92 | - [RPMNet](https://github.com/yewzijian/RPMNet) 93 | - [PointNet++](https://github.com/sshaoshuai/Pointnet2.PyTorch) (unofficial implementation, for Furthest Point Sampling) 94 | - [RandLA-Net](https://github.com/aRI0U/RandLA-Net-pytorch.git) (unofficial implementation, for Semantic Segmentation) 95 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | def dataset_arguments(): 5 | parser = argparse.ArgumentParser( add_help=False ) 6 | parser.add_argument('--dataset', type=str, default='SemanticKitti', choices=['SemanticKitti', 'NuScenes'], help='Which dataset to use; the choices are SemanticKitti and NuScenes') 7 | parser.add_argument( '--sample_point_num', default=16000, type=int, help="The number of points sampled in the voxel sampling" ) 8 | parser.add_argument( '--nsample', default=1024, type=int, help="The number of points sampled in the detector" ) 9 | parser.add_argument('--sample_voxel_size', default = 0.3, type=float, help="The voxel size used in the voxel sampling" ) 10 | parser.add_argument('--boundingbox_diagonal', default = 102, type=float, help="The average diagonal of the dataset's bounding box" ) 11 | 12 | #SemanticKITTI 13 | parser.add_argument('--kitty_root', type=str, metavar='PATH', default = "/.../SemanticKITTY/process", help="Directory of SemanticKitti after preprocessing" ) 14 | parser.add_argument('--kitti_ignore_label', default=[1,4,5,6,7,8], type=list, help="The semantic categories in SemanticKitti that will be ignored, e.g., the moving categories" ) 15 | #NuScenes 16 | parser.add_argument('--nuscenes_root', type=str, metavar='PATH', 17 | default = '/.../nuscenes/process', help="Directory of NuScenes after preprocessing" ) 18 | parser.add_argument('--nuscenes_ignore_label', default=[1,2,3,4,5,7], type=list, help="The semantic categories in NuScenes that will be ignored, e.g., the moving categories" ) 19 | return parser 20 | 21 | def model_arguments(): 22 | parser = argparse.ArgumentParser( add_help=False ) 23 | parser.add_argument( '--nb', default=20, type=int, help="The number of neighbor points in the k-nearest-neighbor search" ) 24 | parser.add_argument( '--init_dims', default=3, type=int, help="The input dimension of the network" ) 25 | parser.add_argument( '--emb_dims', default=512, type=int, help="The embedding dimension of the network" ) 26 | parser.add_argument( '--attention_head_num', default=4, type=int, help="The number of heads in the multi-head attention" ) 27 | parser.add_argument('--trans_loss_type', type=str, choices=['mse', 'mae'], default='mae', 28 | help='Transformation loss to be optimized') 29 | parser.add_argument('--dev', action='store_true', help='If true, will ignore logdir and log to ../logdev instead') 30 | parser.add_argument('--name', type=str, help='Prefix to add to logging directory') 31 |
parser.add_argument('--semantic_classes_num', default = 20, type=int, help="Number of semantic classes, used in both segmentation and registration" ) 32 | parser.add_argument('--iter_num', default = 2, type=int, help="Iteration number in the RNN network" ) 33 | parser.add_argument("--local_rank", type=int, help="Used in multi-GPU training") 34 | parser.add_argument('--rot_mag', default = 45.0, type=float, help="Rotation limit (in degrees) in the data processing" ) 35 | parser.add_argument('--trans_mag', default = 5.0, type=float, help="Translation limit in the data processing" ) 36 | parser.add_argument('--partial_p_keep', default = [0.7,0.7], type=list, help="The ratio of the part kept in the cropping process" ) 37 | return parser 38 | 39 | 40 | def train_arguments(): 41 | parser = argparse.ArgumentParser( parents=[dataset_arguments(), model_arguments()] ) 42 | parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate during training') 43 | parser.add_argument('--scheduler_step_size', default=10, type=int, help='The scheduler step size; the learning rate decreases every this many epochs' ) 44 | parser.add_argument('--scheduler_gamma', default=0.5, type=float, help='The decay factor applied to the learning rate' ) 45 | parser.add_argument('--augment', default = 0.5, type=float, 46 | help="The probability that the data will be further augmented, e.g., by adding another random rotation and translation to the moving categories" ) 47 | parser.add_argument('--num_workers', default=4, type=int, help='Number of workers for the data loader') 48 | parser.add_argument( '--epoch_num', default=60, type=int, metavar='N', help="The number of epochs in training" ) 49 | parser.add_argument('--train_batch_size', default=3, type=int, metavar='N', 50 | help='Training mini-batch size (default 3)') 51 | parser.add_argument( '--val_batch_size', default=12, type=int, metavar='N', 52 | help='The mini-batch size during validation or testing') 53 | parser.add_argument( '--noise_type', default='jitter', type=str, help="The type of noise to add to the data ('clean', 'jitter', or 'crop')" ) 54 | parser.add_argument( '--save_checkpoints_path', type=str, 55 | default=osp.join( '/.../checkpoints', 'ckpt' ), help="Directory to save checkpoints" ) 56 | parser.add_argument( '--load_checkpoints_path', type=str, 57 | default=osp.join('/.../checkpoints', 'ckpt-best.pth'), help="Directory to load checkpoints" ) 58 | parser.add_argument( '--validate_every', default=-1, type=int, 59 | help="The step size for validation; a negative number means validating every * epochs" ) 60 | parser.add_argument( '--summary_every', default=200, type=int, help="The step size for summary; a negative number means summarizing every * epochs" ) 61 | parser.add_argument('--RRE_thresholds', default = 2.0, type=float, 62 | help="The relative rotation error threshold used to judge whether a registration is successful" ) 63 | parser.add_argument('--RTE_thresholds', default = 0.5, type=float, 64 | help="The relative translation error threshold used to judge whether a registration is successful" ) 65 | parser.add_argument('--logdir', default='/.../logs', type=str, 66 | help='Directory to store logs, summaries, checkpoints.') 67 | return parser 68 | 69 | def test_arguments(): 70 | parser = argparse.ArgumentParser( parents=[dataset_arguments(), model_arguments()] ) 71 | parser.add_argument('--test_batch_size', default=16, type=int, metavar='N', 72 | help='The test mini-batch size') 73 | parser.add_argument('--num_workers', default=4, type=int, 74 | help='Number of workers for the data loader.') 75 |
parser.add_argument('--eval_save_path', type=str, default='./eval_results', 76 | help='Output path to save evaluation results') 77 | parser.add_argument('--transform_file', type=str, 78 | help='If provided, will use transforms from this pickle file') 79 | parser.add_argument('--RRE_thresholds', default = 2.0, type=float, 80 | help="The relative rotation error threshold used to judge whether a registration is successful" ) 81 | parser.add_argument('--RTE_thresholds', default = 0.5, type=float, 82 | help="The relative translation error threshold used to judge whether a registration is successful" ) 83 | parser.add_argument('--checkpoints_path', default='/.../checkpoints', type=str, 84 | help='Directory to load checkpoints.') 85 | parser.add_argument('--augment', default = 0.5, type=float, 86 | help="The probability that the data will be further augmented, e.g., by adding another random rotation and translation to the moving categories" ) 87 | parser.add_argument( '--noise_type', default='clean', type=str, 88 | help="The type of noise to add to the data ('clean', 'jitter', or 'crop')" ) 89 | #parser.add_argument( '--transform_file', default=None, type=str, help="" ) 90 | return parser 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = train_arguments() 95 | _args = parser.parse_args() 96 | -------------------------------------------------------------------------------- /common/colors.py: -------------------------------------------------------------------------------- 1 | """Useful color codes""" 2 | ORANGE = [239, 124, 0] 3 | BLUE = [0, 61, 124] -------------------------------------------------------------------------------- /common/math/__pycache__/random.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WinterCodeForEverything/SARNet/e60558a5a45016f84972be0df01b4490489d3e75/common/math/__pycache__/random.cpython-38.pyc -------------------------------------------------------------------------------- /common/math/__pycache__/se3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WinterCodeForEverything/SARNet/e60558a5a45016f84972be0df01b4490489d3e75/common/math/__pycache__/se3.cpython-38.pyc -------------------------------------------------------------------------------- /common/math/__pycache__/so3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WinterCodeForEverything/SARNet/e60558a5a45016f84972be0df01b4490489d3e75/common/math/__pycache__/so3.cpython-38.pyc -------------------------------------------------------------------------------- /common/math/random.py: -------------------------------------------------------------------------------- 1 | """Functions for random sampling""" 2 | import numpy as np 3 | from scipy.spatial.transform import Rotation 4 | 5 | def generate_rand_rotm(x_lim=5.0, y_lim=5.0, z_lim=180.0): 6 | ''' 7 | Input: 8 | x_lim 9 | y_lim 10 | z_lim 11 | return: 12 | rotm: [3,3] 13 | ''' 14 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 15 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 16 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 17 | 18 | rand_eul = np.array([rand_z, rand_y, rand_x]) 19 | r = Rotation.from_euler('zyx', rand_eul, degrees=True) 20 | rotm = r.as_matrix() 21 | return rotm 22 | 23 | def generate_rand_trans(x_lim=10.0, y_lim=1.0, z_lim=0.1): 24 | ''' 25 | Input: 26 | x_lim 27 | y_lim 28 | z_lim 29 | return: 30 | trans [3] 31 | ''' 32 |
rand_x = np.random.uniform(low=-x_lim, high=x_lim) 33 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 34 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 35 | 36 | rand_trans = np.array([rand_x, rand_y, rand_z]) 37 | 38 | return rand_trans 39 | 40 | 41 | def uniform_2_sphere(num: int = None): 42 | """Uniform sampling on a 2-sphere 43 | 44 | Source: https://gist.github.com/andrewbolster/10274979 45 | 46 | Args: 47 | num: Number of vectors to sample (or None if single) 48 | 49 | Returns: 50 | Random Vector (np.ndarray) of size (num, 3) with norm 1. 51 | If num is None returned value will have size (3,) 52 | 53 | """ 54 | if num is not None: 55 | phi = np.random.uniform(0.0, 2 * np.pi, num) 56 | cos_theta = np.random.uniform(-1.0, 1.0, num) 57 | else: 58 | phi = np.random.uniform(0.0, 2 * np.pi) 59 | cos_theta = np.random.uniform(-1.0, 1.0) 60 | 61 | theta = np.arccos(cos_theta) 62 | x = np.sin(theta) * np.cos(phi) 63 | y = np.sin(theta) * np.sin(phi) 64 | z = np.cos(theta) 65 | 66 | return np.stack((x, y, z), axis=-1) 67 | 68 | 69 | -------------------------------------------------------------------------------- /common/math/se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | 5 | def identity(): 6 | return np.eye(3, 4) 7 | 8 | 9 | def transform(g: np.ndarray, pts: np.ndarray): 10 | """ Applies the SE3 transform 11 | 12 | Args: 13 | g: SE3 transformation matrix of size ([B,] 3/4, 4) 14 | pts: Points to be transformed ([B,] N, 3) 15 | 16 | Returns: 17 | transformed points of size (N, 3) 18 | """ 19 | rot = g[..., :3, :3] # (3, 3) 20 | trans = g[..., :3, 3] # (3) 21 | 22 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :] 23 | return transformed 24 | 25 | 26 | def inverse(g: np.ndarray): 27 | """Returns the inverse of the SE3 transform 28 | 29 | Args: 30 | g: ([B,] 3/4, 4) transform 31 | 32 | Returns: 33 | ([B,] 3/4, 4) matrix containing the inverse 34 | 35 | """ 36 | rot = g[..., :3, :3] # (3, 3) 37 | trans = g[..., :3, 3] # (3) 38 | 39 | inv_rot = np.swapaxes(rot, -1, -2) 40 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1) 41 | if g.shape[-2] == 4: 42 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 43 | 44 | return inverse_transform 45 | 46 | 47 | def concatenate(a: np.ndarray, b: np.ndarray): 48 | """ Concatenate two SE3 transforms 49 | 50 | Args: 51 | a: First transform ([B,] 3/4, 4) 52 | b: Second transform ([B,] 3/4, 4) 53 | 54 | Returns: 55 | a*b ([B, ] 3/4, 4) 56 | 57 | """ 58 | 59 | r_a, t_a = a[..., :3, :3], a[..., :3, 3] 60 | r_b, t_b = b[..., :3, :3], b[..., :3, 3] 61 | 62 | r_ab = r_a @ r_b 63 | t_ab = r_a @ t_b[..., None] + t_a[..., None] 64 | 65 | concatenated = np.concatenate([r_ab, t_ab], axis=-1) 66 | 67 | if a.shape[-2] == 4: 68 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 69 | 70 | return concatenated 71 | 72 | 73 | def from_xyzquat(xyzquat): 74 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw 75 | 76 | Args: 77 | xyzquat: np.array (7,) containing translation and quaterion 78 | 79 | Returns: 80 | SE3 matrix (4, 4) 81 | """ 82 | rot = Rotation.from_quat(xyzquat[3:]) 83 | trans = rot.apply(-xyzquat[:3]) 84 | transform = np.concatenate([rot.as_dcm(), trans[:, None]], axis=1) 85 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0) 86 | 87 | return transform 
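# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): a quick sanity
# check that transform / inverse / concatenate above are mutually consistent.
# Assumes the repository root is on PYTHONPATH so that common.math.random
# can be imported when this module is run directly.
if __name__ == '__main__':
    import common.math.random as rdm

    rotm = rdm.generate_rand_rotm()                       # (3, 3) random rotation
    trans = rdm.generate_rand_trans()                     # (3,) random translation
    g = np.concatenate([rotm, trans[:, None]], axis=-1)   # (3, 4) rigid transform

    pts = np.random.rand(100, 3)
    pts_transformed = transform(g, pts)

    # The inverse undoes the transform, and inverse(g) composed with g acts as the identity.
    assert np.allclose(transform(inverse(g), pts_transformed), pts, atol=1e-6)
    assert np.allclose(transform(concatenate(inverse(g), g), pts), pts, atol=1e-6)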
-------------------------------------------------------------------------------- /common/math/so3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rotation related functions for numpy arrays 3 | """ 4 | 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | 9 | def dcm2euler(mats: np.ndarray, seq: str = 'zyx', degrees: bool = True): 10 | """Converts rotation matrix to euler angles 11 | 12 | Args: 13 | mats: (B, 3, 3) containing the B rotation matricecs 14 | seq: Sequence of euler rotations (default: 'zyx') 15 | degrees (bool): If true (default), will return in degrees instead of radians 16 | 17 | Returns: 18 | 19 | """ 20 | 21 | eulers = [] 22 | for i in range(mats.shape[0]): 23 | r = Rotation.from_matrix(mats[i]) 24 | eulers.append(r.as_euler(seq, degrees=degrees)) 25 | return np.stack(eulers) 26 | 27 | 28 | def transform(g: np.ndarray, pts: np.ndarray): 29 | """ Applies the SO3 transform 30 | 31 | Args: 32 | g: SO3 transformation matrix of size (3, 3) 33 | pts: Points to be transformed (N, 3) 34 | 35 | Returns: 36 | transformed points of size (N, 3) 37 | 38 | """ 39 | rot = g[:3, :3] # (3, 3) 40 | transformed = pts @ rot.transpose() 41 | return transformed 42 | -------------------------------------------------------------------------------- /common/math_torch/__pycache__/se3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WinterCodeForEverything/SARNet/e60558a5a45016f84972be0df01b4490489d3e75/common/math_torch/__pycache__/se3.cpython-38.pyc -------------------------------------------------------------------------------- /common/math_torch/se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transformation group 2 | """ 3 | import torch 4 | 5 | 6 | def identity(batch_size): 7 | return torch.eye(3, 4)[None, ...].repeat(batch_size, 1, 1) 8 | 9 | 10 | def inverse(g): 11 | """ Returns the inverse of the SE3 transform 12 | 13 | Args: 14 | g: (B, 3/4, 4) transform 15 | 16 | Returns: 17 | (B, 3, 4) matrix containing the inverse 18 | 19 | """ 20 | # Compute inverse 21 | rot = g[..., 0:3, 0:3] 22 | trans = g[..., 0:3, 3] 23 | inverse_transform = torch.cat([rot.transpose(-1, -2), rot.transpose(-1, -2) @ -trans[..., None]], dim=-1) 24 | 25 | return inverse_transform 26 | 27 | 28 | def concatenate(a, b): 29 | """Concatenate two SE3 transforms, 30 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix) 31 | 32 | Args: 33 | a: (B, 3/4, 4) 34 | b: (B, 3/4, 4) 35 | 36 | Returns: 37 | (B, 3/4, 4) 38 | """ 39 | 40 | rot1 = a[..., :3, :3] 41 | trans1 = a[..., :3, 3] 42 | rot2 = b[..., :3, :3] 43 | trans2 = b[..., :3, 3] 44 | 45 | rot_cat = rot1 @ rot2 46 | trans_cat = rot1 @ trans2[..., None] + trans1[..., None] 47 | concatenated = torch.cat([rot_cat, trans_cat], dim=-1) 48 | 49 | return concatenated 50 | 51 | 52 | def transform(g, a, normals=None): 53 | """ Applies the SE3 transform 54 | 55 | Args: 56 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4) 57 | a: Points to be transformed (N, 3) or (B, N, 3) 58 | normals: (Optional). 
If provided, normals will be transformed 59 | 60 | Returns: 61 | transformed points of size (N, 3) or (B, N, 3) 62 | 63 | """ 64 | R = g[..., :3, :3] # (B, 3, 3) 65 | p = g[..., :3, 3] # (B, 3) 66 | 67 | if len(g.size()) == len(a.size()): 68 | b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :] 69 | else: 70 | raise NotImplementedError 71 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. Not checked 72 | 73 | if normals is not None: 74 | rotated_normals = normals @ R.transpose(-1, -2) 75 | return b, rotated_normals 76 | 77 | else: 78 | return b 79 | -------------------------------------------------------------------------------- /common/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc utilities 3 | """ 4 | 5 | import argparse 6 | from datetime import datetime 7 | import logging 8 | import os 9 | import shutil 10 | import subprocess 11 | import sys 12 | 13 | import coloredlogs 14 | import git 15 | 16 | 17 | _logger = logging.getLogger() 18 | 19 | 20 | def print_info(opt, log_dir=None): 21 | """ Logs source code configuration 22 | """ 23 | _logger.info('Command: {}'.format(' '.join(sys.argv))) 24 | 25 | # Print commit ID 26 | try: 27 | repo = git.Repo(search_parent_directories=True) 28 | git_sha = repo.head.object.hexsha 29 | git_date = datetime.fromtimestamp(repo.head.object.committed_date).strftime('%Y-%m-%d') 30 | git_message = repo.head.object.message 31 | _logger.info('Source is from Commit {} ({}): {}'.format(git_sha[:8], git_date, git_message.strip())) 32 | 33 | # Also create diff file in the log directory 34 | if log_dir is not None: 35 | with open(os.path.join(log_dir, 'compareHead.diff'), 'w') as fid: 36 | subprocess.run(['git', 'diff'], stdout=fid) 37 | 38 | except git.exc.InvalidGitRepositoryError: 39 | pass 40 | 41 | # Arguments 42 | arg_str = ['{}: {}'.format(key, value) for key, value in vars(opt).items()] 43 | arg_str = ', '.join(arg_str) 44 | _logger.info('Arguments: {}'.format(arg_str)) 45 | 46 | 47 | def prepare_logger(opt: argparse.Namespace, log_path: str = None): 48 | """Creates logging directory, and installs colorlogs 49 | 50 | Args: 51 | opt: Program arguments, should include --dev and --logdir flag. 52 | See get_parent_parser() 53 | log_path: Logging path (optional). 
This serves to overwrite the settings in 54 | argparse namespace 55 | 56 | Returns: 57 | logger (logging.Logger) 58 | log_path (str): Logging directory 59 | """ 60 | 61 | if log_path is None: 62 | if opt.dev: 63 | log_path = '../logdev' 64 | shutil.rmtree(log_path, ignore_errors=True) 65 | else: 66 | datetime_str = datetime.now().strftime('%y%m%d_%H%M%S') 67 | if opt.name is not None: 68 | log_path = os.path.join(opt.logdir, datetime_str + '_' + opt.name) 69 | else: 70 | log_path = os.path.join(opt.logdir, datetime_str) 71 | 72 | os.makedirs(log_path, exist_ok=True) 73 | logger = logging.getLogger() 74 | coloredlogs.install(level='INFO', logger=logger) 75 | file_handler = logging.FileHandler('{}/log.txt'.format(log_path)) 76 | log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s') 77 | file_handler.setFormatter(log_formatter) 78 | logger.addHandler(file_handler) 79 | print_info(opt, log_path) 80 | logger.info('Output and logs will be saved to {}'.format(log_path)) 81 | 82 | return logger, log_path 83 | -------------------------------------------------------------------------------- /common/torch.py: -------------------------------------------------------------------------------- 1 | """PyTorch related utility functions 2 | """ 3 | 4 | import logging 5 | import os 6 | import pdb 7 | import shutil 8 | import sys 9 | import time 10 | import traceback 11 | import random 12 | 13 | import numpy as np 14 | import torch 15 | from torch.optim.optimizer import Optimizer 16 | 17 | def set_seed(seed): 18 | ''' 19 | Set random seed for torch, numpy and python 20 | ''' 21 | random.seed(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | if torch.cuda.is_available(): 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | torch.backends.cudnn.benchmark=False 29 | torch.backends.cudnn.deterministic=True 30 | 31 | 32 | def dict_all_to_device(tensor_dict, device): 33 | """Sends everything into a certain device """ 34 | for k in tensor_dict: 35 | if isinstance(tensor_dict[k], torch.Tensor): 36 | tensor_dict[k] = tensor_dict[k].to(device) 37 | 38 | 39 | def to_numpy(tensor): 40 | """Wrapper around .detach().cpu().numpy() """ 41 | if isinstance(tensor, torch.Tensor): 42 | return tensor.detach().cpu().numpy() 43 | elif isinstance(tensor, np.ndarray): 44 | return tensor 45 | else: 46 | raise NotImplementedError 47 | 48 | 49 | class CheckPointManager(object): 50 | """Manager for saving/managing pytorch checkpoints. 
51 | 52 | Provides functionality similar to tf.Saver such as 53 | max_to_keep and keep_checkpoint_every_n_hours 54 | """ 55 | def __init__(self, save_path: str = None, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0): 56 | 57 | if max_to_keep <= 0: 58 | raise ValueError('max_to_keep must be at least 1') 59 | 60 | self._max_to_keep = max_to_keep 61 | self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 62 | 63 | self._ckpt_dir = os.path.dirname(save_path) 64 | self._save_path = save_path + '-{}.pth' if save_path is not None else None 65 | self._logger = logging.getLogger(self.__class__.__name__) 66 | self._checkpoints_fname = os.path.join(self._ckpt_dir, 'checkpoints.txt') 67 | 68 | self._checkpoints_permanent = [] # Will not be deleted 69 | self._checkpoints_buffer = [] # Those which might still be deleted 70 | self._next_save_time = time.time() 71 | self._best_score = -float('inf') 72 | self._best_step = None 73 | 74 | os.makedirs(self._ckpt_dir, exist_ok=True) 75 | self._update_checkpoints_file() 76 | 77 | def _save_checkpoint(self, step, model, optimizer, score): 78 | save_name = self._save_path.format(step) 79 | state = {'state_dict': model.state_dict(), 80 | 'optimizer': optimizer.state_dict(), 81 | 'step': step} 82 | torch.save(state, save_name) 83 | self._logger.info('Saved checkpoint: {}'.format(save_name)) 84 | 85 | self._checkpoints_buffer.append((save_name, time.time())) 86 | 87 | if score > self._best_score: 88 | best_save_name = self._save_path.format('best') 89 | shutil.copyfile(save_name, best_save_name) 90 | self._best_score = score 91 | self._best_step = step 92 | self._logger.info('Checkpoint is current best, score={:.3g}'.format(self._best_score)) 93 | 94 | def _remove_old_checkpoints(self): 95 | while len(self._checkpoints_buffer) > self._max_to_keep: 96 | to_remove = self._checkpoints_buffer.pop(0) 97 | 98 | if to_remove[1] > self._next_save_time: 99 | self._checkpoints_permanent.append(to_remove) 100 | self._next_save_time = to_remove[1] + self._keep_checkpoint_every_n_hours * 3600 101 | else: 102 | os.remove(to_remove[0]) 103 | 104 | def _update_checkpoints_file(self): 105 | checkpoints = [os.path.basename(c[0]) for c in self._checkpoints_permanent + self._checkpoints_buffer] 106 | with open(self._checkpoints_fname, 'w') as fid: 107 | fid.write('\n'.join(checkpoints)) 108 | fid.write('\nBest step: {}'.format(self._best_step)) 109 | 110 | def save(self, model: torch.nn.Module, optimizer: Optimizer, step: int, 111 | score: float = 0.0): 112 | """Save model checkpoint to file 113 | 114 | Args: 115 | model: Torch model 116 | optimizer: Torch optimizer 117 | step (int): Step, model will be saved as model-[step].pth 118 | score (float, optional): To determine which model is the best 119 | """ 120 | if self._save_path is None: 121 | raise AssertionError('Checkpoint manager must be initialized with save path for save().') 122 | 123 | self._save_checkpoint(step, model, optimizer, score) 124 | self._remove_old_checkpoints() 125 | self._update_checkpoints_file() 126 | 127 | def load(self, save_path, model: torch.nn.Module = None, optimizer: Optimizer = None, distributed: bool=False): 128 | """Loads saved model from file 129 | 130 | Args: 131 | save_path: Path to saved model (.pth). 
If a directory is provided instead, model-best.pth is used 132 | model: Torch model to restore weights to 133 | optimizer: Optimizer 134 | """ 135 | if os.path.isdir(save_path): 136 | save_path = os.path.join(save_path, 'model-best.pth') 137 | 138 | state = torch.load(save_path) 139 | 140 | step = 0 141 | if 'step' in state: 142 | step = state['step'] 143 | 144 | if 'state_dict' in state and model is not None: 145 | if distributed: 146 | from collections import OrderedDict 147 | new_state_dict = OrderedDict() 148 | for k, v in state['state_dict'].items(): 149 | name = k[7:] # remove `module.` 150 | new_state_dict[name] = v 151 | model.load_state_dict(new_state_dict) 152 | else: 153 | model.load_state_dict(state['state_dict']) 154 | 155 | 156 | if 'optimizer' in state and optimizer is not None: 157 | optimizer.load_state_dict(state['optimizer']) 158 | 159 | self._logger.info('Loaded models from {}'.format(save_path)) 160 | return step 161 | 162 | 163 | class TorchDebugger(torch.autograd.detect_anomaly): 164 | """Enters debugger when anomaly detected""" 165 | def __enter__(self) -> None: 166 | super().__enter__() 167 | 168 | def __exit__(self, type, value, trace): 169 | super().__exit__() 170 | if isinstance(value, RuntimeError): 171 | traceback.print_tb(trace) 172 | print(value) 173 | if sys.gettrace() is None: 174 | pdb.set_trace() 175 | -------------------------------------------------------------------------------- /dataloader/NuScenesDataLoader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import open3d as o3d 3 | import numpy as np 4 | from os.path import join, exists 5 | from torch.utils.data import Dataset 6 | 7 | import yaml 8 | 9 | import common.math.se3 as se3 10 | import common.math.random as rdm 11 | 12 | label_yaml_config = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'semantic-nuscenes.yaml' ) 13 | DATA = yaml.safe_load(open(label_yaml_config, 'r')) 14 | remap_dict = DATA["learning_map"] 15 | max_key = max(remap_dict.keys()) 16 | remap_lut = np.zeros((max_key + 20), dtype=np.int32) 17 | remap_lut[list(remap_dict.keys())] = list(remap_dict.values()) 18 | 19 | def pc_normalize(pc): 20 | centroid = np.mean(pc, axis=0) 21 | pc = pc - centroid 22 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 23 | pc = pc / m 24 | return pc 25 | 26 | class NuScenesDataSet(Dataset): 27 | def __init__(self, root, transform, split = 'train', ignore_label=None, augment = 0.5 ): 28 | super(NuScenesDataSet, self).__init__() 29 | self.root = root 30 | self.transform = transform 31 | self.split = split 32 | self.dataset = self.make_dataset() 33 | self.ignore_label = ignore_label 34 | self.augment = augment 35 | 36 | def make_dataset(self): 37 | data_path = join( self.root, self.split ) 38 | dataset = os.listdir(data_path) 39 | return dataset 40 | 41 | def __getitem__(self, index): 42 | fn = self.dataset[index] 43 | data = np.load(join( self.root, self.split, fn )) 44 | points = data[:,:3] 45 | label = data[:, 3] 46 | sample = { 'points' : points.astype('float32'), 'seg' : remap_lut[label.astype('int32')], 47 | 'idx': np.array(index, dtype=np.int32) } 48 | sample = self.transform( sample ) 49 | points_src, points_ref = sample['points_src'], sample['points_ref'] 50 | labels_src, labels_ref = sample['seg_src'], sample['seg_ref'] 51 | intersect_elm = np.intersect1d( labels_src, labels_ref ) 52 | if self.ignore_label != None: 53 | intersect_elm = np.setdiff1d( intersect_elm, self.ignore_label ) 54 | if np.random.rand() < self.augment: 55 | for 
il in self.ignore_label: 56 | rand_T = np.zeros((4,4), dtype=np.float32) 57 | rand_T[3,3] = 1.0 58 | rand_rotm = rdm.generate_rand_rotm( 3.0, 3.0, 3.0 ) 59 | rand_T[:3,:3] = rand_rotm 60 | rand_trans = rdm.generate_rand_trans( 10.0, 1.0, 0.1 ) 61 | rand_T[:3,3] = rand_trans 62 | points_src[labels_src == il] = se3.transform( rand_T, 63 | points_src[labels_src == il] ) 64 | 65 | intersect_src = np.isin( labels_src, intersect_elm ).astype(int) 66 | intersect_ref = np.isin( labels_ref, intersect_elm ).astype(int) 67 | sample['intersect_src'] = intersect_src 68 | sample['intersect_ref'] = intersect_ref 69 | 70 | sample['points_src'] = points_src.astype('float32') 71 | sample['points_ref'] = points_ref.astype('float32') 72 | 73 | sample.pop( "seg" ) 74 | sample.pop( "points_raw" ) 75 | return sample 76 | 77 | def __len__(self): 78 | return len(self.dataset) 79 | 80 | -------------------------------------------------------------------------------- /dataloader/SemanticKITTYDataLoader.py: -------------------------------------------------------------------------------- 1 | #@InProceedings{Lu_2021_HRegNet, 2 | # author = {Lu, Fan and Chen, Guang and Liu, Yinlong and Zhang Lijun, Qu Sanqing, Liu Shu, Gu Rongqi}, 3 | # title = {HRegNet: A Hierarchical Network for Large-scale Outdoor LiDAR Point Cloud Registration}, 4 | # booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision}, 5 | # year = {2021} 6 | #} 7 | 8 | import numpy as np 9 | #from scipy.stats.stats import pointbiserialr 10 | import torch 11 | #import torchvision 12 | from torch.utils.data import Dataset 13 | 14 | import os 15 | from os.path import join 16 | import numpy as np 17 | import yaml 18 | 19 | import common.math.se3 as se3 20 | import common.math.random as rdm 21 | 22 | label_yaml_config = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'semantic-kitti.yaml' ) 23 | DATA = yaml.safe_load(open(label_yaml_config, 'r')) 24 | remap_dict = DATA["learning_map"] 25 | max_key = max(remap_dict.keys()) 26 | remap_lut = np.zeros((max_key + 20), dtype=np.int32) 27 | remap_lut[list(remap_dict.keys())] = list(remap_dict.values()) 28 | 29 | def pc_normalize(pc): 30 | centroid = np.mean(pc, axis=0) 31 | pc = pc - centroid 32 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 33 | pc = pc / m 34 | return pc 35 | 36 | def compute_diagonal(pc): 37 | centroid = np.mean(pc, axis=0) 38 | pc = pc - centroid 39 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 40 | return m 41 | 42 | class KittiDataset(Dataset): 43 | ''' 44 | Params: 45 | root 46 | split 47 | transform 48 | ignore_label 49 | ''' 50 | def __init__(self, root, split, transform, ignore_label=None, augment = 1.0 ): 51 | super(KittiDataset, self).__init__() 52 | 53 | self.root = root 54 | self.split = split 55 | self.transform = transform 56 | self.dataset = self.make_dataset() 57 | self.ignore_label = ignore_label 58 | self.augment = augment 59 | 60 | def make_dataset(self): 61 | data_path = join( self.root, self.split ) 62 | dataset = os.listdir(data_path) 63 | return dataset 64 | 65 | def __getitem__(self, index): 66 | fn = self.dataset[index] 67 | data = np.load(join( self.root, self.split, fn )) 68 | points = data[:,:3] 69 | label = data[:, 3] 70 | sample = { 'points' : points.astype('float32'), 'seg' : remap_lut[label.astype('int32')], 71 | 'idx': np.array(index, dtype=np.int32) } 72 | sample = self.transform( sample ) 73 | points_src, points_ref = sample['points_src'], sample['points_ref'] 74 | labels_src, labels_ref = sample['seg_src'], sample['seg_ref'] 75 | 76 
| intersect_elm = np.intersect1d( labels_src, labels_ref ) 77 | if self.ignore_label != None: 78 | intersect_elm = np.setdiff1d( intersect_elm, self.ignore_label ) 79 | if np.random.rand() < self.augment: 80 | for il in self.ignore_label: 81 | rand_T = np.zeros((4,4), dtype=np.float32) 82 | rand_T[3,3] = 1.0 83 | rand_rotm = rdm.generate_rand_rotm( 3.0, 3.0, 3.0 ) 84 | rand_T[:3,:3] = rand_rotm 85 | rand_trans = rdm.generate_rand_trans( 10.0, 1.0, 0.1 ) 86 | rand_T[:3,3] = rand_trans 87 | points_src[labels_src == il] = se3.transform( rand_T, 88 | points_src[labels_src == il] ) 89 | 90 | intersect_src = np.isin( labels_src, intersect_elm ).astype(int) 91 | intersect_ref = np.isin( labels_ref, intersect_elm ).astype(int) 92 | sample['intersect_src'] = intersect_src 93 | sample['intersect_ref'] = intersect_ref 94 | 95 | sample['points_src'] = points_src.astype('float32') 96 | sample['points_ref'] = points_ref.astype('float32') 97 | 98 | sample.pop( "seg" ) 99 | sample.pop( "points_raw" ) 100 | return sample 101 | 102 | def __len__(self): 103 | return len(self.dataset) 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /dataloader/data_prepare_SemanticKITTI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path import join, exists 4 | 5 | data_root = 'your SemanticKitty path' 6 | process_data_path = join( data_root, 'process') 7 | if not exists(process_data_path): 8 | os.makedirs(process_data_path) 9 | train_data_path = join( process_data_path, 'train') 10 | if not exists(train_data_path): 11 | os.makedirs(train_data_path) 12 | val_data_path = join( process_data_path, 'val') 13 | if not exists(val_data_path): 14 | os.makedirs(val_data_path) 15 | test_data_path = join( process_data_path, 'test') 16 | if not exists(test_data_path): 17 | os.makedirs(test_data_path) 18 | 19 | if __name__ == '__main__': 20 | for i in range(11): 21 | points_file_path = os.path.join(data_root, 'velodyne/sequences', f'{i:02d}', 'velodyne' ) 22 | labels_file_path = os.path.join(data_root, 'label/sequences', f'{i:02d}', 'labels' ) 23 | points_file_names = sorted(os.listdir(points_file_path)) 24 | file_num = len(points_file_names) 25 | train_num, val_num = file_num*7//10, file_num*8//10 26 | file_idx = 0 27 | for pfn in points_file_names: 28 | bpfn = os.path.splitext(pfn)[0] 29 | pffn = os.path.join(points_file_path, pfn) 30 | lfn = bpfn + '.label' 31 | lffn = os.path.join(labels_file_path, lfn) 32 | points = np.fromfile( pffn, dtype=np.float32, count=-1).reshape([-1,4] )[:, :3] 33 | label = np.fromfile( lffn, dtype=np.uint32 ).reshape( (-1,1) ) 34 | label = label & 0xFFFF # only semantic label 35 | label = label.astype(np.float32) 36 | data = np.concatenate((points, label), axis=1) 37 | if file_idx < train_num: 38 | save_full_path = join( train_data_path, f'{i:02d}' + bpfn + '.npy') 39 | np.save( save_full_path, data ) 40 | elif file_idx < val_num: 41 | save_full_path = join( val_data_path, f'{i:02d}' + bpfn + '.npy') 42 | np.save( save_full_path, data ) 43 | else: 44 | save_full_path = join( test_data_path, f'{i:02d}' + bpfn + '.npy') 45 | np.save( save_full_path, data ) 46 | print( f'{i:02d}' + bpfn + '.npy' ) 47 | file_idx += 1 -------------------------------------------------------------------------------- /dataloader/data_prepare_nuScenes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path 
import join, exists 4 | #import h5py 5 | from collections import defaultdict 6 | from nuscenes import NuScenes 7 | 8 | data_root = 'your nuscense path' 9 | process_data_path = join( data_root, 'process') 10 | if not exists(process_data_path): 11 | os.makedirs(process_data_path) 12 | train_data_path = join( process_data_path, 'train') 13 | if not exists(train_data_path): 14 | os.makedirs(train_data_path) 15 | val_data_path = join( process_data_path, 'val') 16 | if not exists(val_data_path): 17 | os.makedirs(val_data_path) 18 | test_data_path = join( process_data_path, 'test') 19 | if not exists(test_data_path): 20 | os.makedirs(test_data_path) 21 | 22 | 23 | if __name__ == '__main__': 24 | nusc = NuScenes(version='v1.0-trainval', dataroot=data_root, verbose=False) 25 | scene_sample_data = defaultdict(list) 26 | for lidarseg in nusc.lidarseg: 27 | sample_data = nusc.get('sample_data', lidarseg['sample_data_token']) 28 | points_fn = join(data_root, sample_data['filename']) 29 | sample = nusc.get('sample', sample_data['sample_token']) 30 | scene = nusc.get('scene', sample['scene_token']) 31 | scene_sample_data[scene['name']].append( [sample_data['filename'], 32 | lidarseg['filename'], sample_data['token']] ) 33 | for k in scene_sample_data: 34 | sample_data_idx = 0 35 | for s in scene_sample_data[k]: 36 | points_fn, seg_fn, data_token = s[0], s[1], s[2] 37 | points = np.fromfile(join(data_root, points_fn), dtype=np.float32, count=-1).reshape([-1,5]) 38 | xyz = points[:, :3] 39 | label = np.fromfile(join(data_root, seg_fn), dtype=np.uint8, count=-1).reshape([-1,1]).astype(np.float32) 40 | data = np.concatenate((xyz, label), axis=1) 41 | if sample_data_idx < 30: 42 | save_full_path = join( train_data_path, data_token + '.npy') 43 | np.save( save_full_path, data ) 44 | elif sample_data_idx < 34: 45 | save_full_path = join( val_data_path, data_token + '.npy') 46 | np.save( save_full_path, data ) 47 | else: 48 | save_full_path = join( test_data_path, data_token + '.npy') 49 | np.save( save_full_path, data ) 50 | print( data_token ) 51 | sample_data_idx += 1 52 | -------------------------------------------------------------------------------- /dataloader/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /dataloader/semantic-nuscenes.yaml: -------------------------------------------------------------------------------- 1 | labels: 2 | 0 : "unlabeled" 3 | 1 : "Car or Van or SUV" 4 | 2: "Truck" 5 | 3: "Bendy Bus" 6 | 4: "Rigid Bus" 7 | 5: "Construction Vehicle" 8 | 6: "Motorcycle" 9 | 7: "Bicycle" 10 | 8: "Bicycle Rack" 11 | 9: "Trailer" 12 | 10: "Police Vehicle" 13 | 11: "Ambulance" 14 | 12: "Adult Pedestrian" 15 | 13: "Child Pedestrian" 16 | 14: "Construction Worker" 17 | 15: "Stroller" 18 | 16: "Wheelchair" 19 | 17: "Portable Personal Mobility Vehicle" 20 | 18: "Police Officer" 21 | 19: "Animal" 22 | 20: "Traffic Cone" 23 | 21: "Temporary Traffic Barrier" 24 | 22: "Pushable Pullable Object" 25 | 23: "Debris" 26 | 24: "flat.driveable_surface" 27 | 25: "flat.sidewalk" 28 | 26: "flat.terrain" 29 | 27: "flat.other" 30 | 28: "static.manmade" 31 | 29: "static.vegetation" 32 | 30: "static.other" 33 | 31: "vehicle.ego" 34 | 35 | learning_map: 36 | 0: 0 37 | 1: 1 38 | 2: 1 39 | 3: 1 40 | 4: 1 41 | 5: 1 42 | 6: 2 43 | 7: 2 44 | 8: 2 45 | 9: 1 46 | 10: 2 47 | 11: 1 48 | 
12: 3 49 | 13: 3 50 | 14: 3 51 | 15: 4 52 | 16: 4 53 | 17: 1 54 | 18: 3 55 | 19: 5 56 | 20: 6 57 | 21: 6 58 | 22: 7 59 | 23: 7 60 | 24: 8 61 | 25: 9 62 | 26: 10 63 | 27: 11 64 | 28: 12 65 | 29: 13 66 | 30: 0 67 | 31: 1 68 | 69 | -------------------------------------------------------------------------------- /dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation 6 | from scipy.stats import special_ortho_group 7 | import torch 8 | import torch.utils.data 9 | 10 | import open3d as o3d 11 | 12 | from common.math.random import uniform_2_sphere 13 | import common.math.se3 as se3 14 | import common.math.so3 as so3 15 | 16 | def get_transforms(noise_type: str, 17 | rot_mag: float = 45.0, trans_mag: float = 5, voxel_size: float = 0.3, 18 | num: int = 1024, diagonal: float = 1.0, partial_p_keep: List = None): 19 | """Get the list of transformation to be used for training or evaluating RegNet 20 | 21 | Args: 22 | noise_type: Either 'clean', 'jitter', 'crop'. 23 | Depending on the option, some of the subsequent arguments may be ignored. 24 | rot_mag: Magnitude of rotation perturbation to apply to source, in degrees. 25 | Default: 45.0 (same as Deep Closest Point) 26 | trans_mag: Magnitude of translation perturbation to apply to source. 27 | Default: 0.5 (same as Deep Closest Point) 28 | voxel_size: voxel cell size to do voxel sampling 29 | num: Number of points to uniformly resample to. 30 | Note that this is with respect to the full point cloud. The number of 31 | points will be proportionally less if cropped 32 | partial_p_keep: Proportion to keep during cropping, [src_p, ref_p] 33 | Default: [0.7, 0.7], i.e. 
Crop both source and reference to ~70% 34 | 35 | Returns: 36 | train_transforms, test_transforms: Both contain list of transformations to be applied 37 | """ 38 | 39 | partial_p_keep = partial_p_keep if partial_p_keep is not None else [0.7, 0.7] 40 | 41 | if noise_type == "clean": 42 | # 1-1 correspondence for each point (resample first before splitting), no noise 43 | train_transforms = [VoxelResampler(voxel_size, num), 44 | SplitSourceRef(), 45 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 46 | ShufflePoints()] 47 | 48 | test_transforms = [SetDeterministic(), 49 | FixedVoxelResampler(voxel_size, num), 50 | SplitSourceRef(), 51 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 52 | ShufflePoints()] 53 | 54 | elif noise_type == "jitter": 55 | # Points randomly sampled (might not have perfect correspondence), gaussian noise to position 56 | train_transforms = [SplitSourceRef(), 57 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 58 | VoxelResampler(voxel_size, num), 59 | RandomJitter(diagonal), 60 | ShufflePoints()] 61 | 62 | test_transforms = [SetDeterministic(), 63 | SplitSourceRef(), 64 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 65 | VoxelResampler(voxel_size, num), 66 | RandomJitter(diagonal), 67 | ShufflePoints()] 68 | 69 | elif noise_type == "crop": 70 | # Both source and reference point clouds cropped, plus same noise in "jitter" 71 | train_transforms = [SplitSourceRef(), 72 | RandomCrop(partial_p_keep), 73 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 74 | VoxelResampler(voxel_size, num), 75 | RandomJitter(diagonal), 76 | ShufflePoints()] 77 | 78 | test_transforms = [SetDeterministic(), 79 | SplitSourceRef(), 80 | RandomCrop(partial_p_keep), 81 | RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 82 | VoxelResampler(voxel_size, num), 83 | RandomJitter(diagonal), 84 | ShufflePoints()] 85 | else: 86 | raise NotImplementedError 87 | 88 | return train_transforms, test_transforms 89 | 90 | 91 | 92 | class SplitSourceRef: 93 | """Clones the point cloud into separate source and reference point clouds""" 94 | def __call__(self, sample: Dict): 95 | sample['points_raw'] = sample.pop('points') 96 | if isinstance(sample['points_raw'], torch.Tensor): 97 | sample['points_src'] = sample['points_raw'].detach() 98 | sample['points_ref'] = sample['points_raw'].detach() 99 | if 'seg' in sample: 100 | sample['seg_src'] = sample['seg'].detach() 101 | sample['seg_ref'] = sample['seg'].detach() 102 | 103 | else: # is numpy 104 | sample['points_src'] = sample['points_raw'].copy() 105 | sample['points_ref'] = sample['points_raw'].copy() 106 | if 'seg' in sample: 107 | sample['seg_src'] = sample['seg'].copy() 108 | sample['seg_ref'] = sample['seg'].copy() 109 | 110 | return sample 111 | 112 | 113 | class VoxelResampler: 114 | def __init__(self, voxel_size: float, num: int): 115 | """Resamples a point cloud containing N points to one containing M 116 | 117 | Guaranteed to have no repeated points if M <= N. 118 | Otherwise, it is guaranteed that all points appear at least once. 119 | 120 | Args: 121 | num (int): Number of points to resample to, i.e. 
M 122 | 123 | """ 124 | self.voxel_size = voxel_size 125 | self.num = num 126 | 127 | 128 | def __call__(self, sample): 129 | 130 | if 'deterministic' in sample and sample['deterministic']: 131 | np.random.seed(sample['idx']) 132 | 133 | if 'points' in sample: 134 | if 'seg' in sample: 135 | sample['points'], sample['seg'] = self._resample(sample['points'], sample['seg'], self.voxel_size, self.num) 136 | else: 137 | sample['points'], _ = self._resample(sample['points'], None, self.voxel_size, self.num) 138 | else: 139 | if 'crop_proportion' not in sample: 140 | src_size, ref_size = self.num, self.num 141 | elif len(sample['crop_proportion']) == 1: 142 | src_size = math.ceil(sample['crop_proportion'][0] * self.num) 143 | ref_size = self.num 144 | elif len(sample['crop_proportion']) == 2: 145 | src_size = math.ceil(sample['crop_proportion'][0] * self.num) 146 | ref_size = math.ceil(sample['crop_proportion'][1] * self.num) 147 | else: 148 | raise ValueError('Crop proportion must have 1 or 2 elements') 149 | 150 | if 'seg' in sample: 151 | sample['points_src'], sample['seg_src'] = self._resample(sample['points_src'], sample['seg_src'], 152 | self.voxel_size, src_size ) 153 | sample['points_ref'], sample['seg_ref'] = self._resample(sample['points_ref'], sample['seg_ref'], 154 | self.voxel_size, ref_size ) 155 | else: 156 | sample['points_src'], _ = self._resample(sample['points_src'], None, self.voxel_size, src_size ) 157 | sample['points_ref'], _ = self._resample(sample['points_ref'], None, self.voxel_size, ref_size ) 158 | 159 | return sample 160 | 161 | @staticmethod 162 | def _resample( points, seg, voxel_size, k ): 163 | """Resamples the points such that there is exactly k points. 164 | 165 | If the input point cloud has <= k points, it is guaranteed the 166 | resampled point cloud contains every point in the input. 167 | If the input point cloud has > k points, it is guaranteed the 168 | resampled point cloud does not contain repeated point. 169 | """ 170 | pcd = o3d.geometry.PointCloud() 171 | pcd.points = o3d.utility.Vector3dVector(points) 172 | if voxel_size is not None: 173 | pcd_ds_and_idx = pcd.voxel_down_sample_and_trace(voxel_size, pcd.get_min_bound(), pcd.get_max_bound(), False ) 174 | points = np.asarray( pcd_ds_and_idx[0].points) 175 | idx = np.max( pcd_ds_and_idx[1], axis=1 ) 176 | if seg is not None: 177 | seg = seg[idx] 178 | 179 | if k <= points.shape[0]: 180 | rand_idxs = np.random.choice(points.shape[0], k, replace=False) 181 | if seg is not None: 182 | return points[rand_idxs, :], seg[rand_idxs] 183 | return points[rand_idxs, :], None 184 | elif points.shape[0] == k: 185 | return points, seg 186 | else: 187 | rand_idxs = np.concatenate([np.random.choice(points.shape[0], points.shape[0], replace=False), 188 | np.random.choice(points.shape[0], k - points.shape[0], replace=True)]) 189 | if seg is not None: 190 | return points[rand_idxs, :], seg[rand_idxs] 191 | return points[rand_idxs, :], None 192 | 193 | 194 | class FixedVoxelResampler(VoxelResampler): 195 | """Fixed resampling to always choose the first N points. 
196 | Always deterministic regardless of whether the deterministic flag has been set 197 | """ 198 | @staticmethod 199 | def _resample(points, seg, voxel_size, k): 200 | pcd = o3d.geometry.PointCloud() 201 | pcd.points = o3d.utility.Vector3dVector(points) 202 | if voxel_size is not None: 203 | pcd_ds_and_idx = pcd.voxel_down_sample_and_trace(voxel_size, pcd.get_min_bound(), pcd.get_max_bound(), False ) 204 | points = np.asarray( pcd_ds_and_idx[0].points) 205 | idx = np.max( pcd_ds_and_idx[1], axis=1 ) 206 | if seg is not None: 207 | seg = seg[idx] 208 | 209 | multiple = k // points.shape[0] 210 | remainder = k % points.shape[0] 211 | 212 | points = np.concatenate((np.tile(points, (multiple, 1)), points[:remainder, :]), axis=0) 213 | if seg is not None: 214 | seg = np.concatenate((np.tile(seg, multiple), seg[:remainder]), axis=0) 215 | return points, seg 216 | 217 | 218 | class RandomJitter: 219 | """ generate perturbations """ 220 | def __init__(self, diagonal= 1): #scale=0.01, clip=0.05 221 | self.scale = diagonal * 0.01 222 | self.clip = diagonal * 0.05 223 | 224 | def jitter(self, pts): 225 | 226 | noise = np.clip(np.random.normal(0.0, scale=self.scale, size=(pts.shape[0], 3)), 227 | a_min=-self.clip, a_max=self.clip) 228 | pts[:, :3] += noise # Add noise to xyz 229 | 230 | return pts 231 | 232 | def __call__(self, sample): 233 | 234 | if 'points' in sample: 235 | sample['points'] = self.jitter(sample['points']) 236 | else: 237 | sample['points_src'] = self.jitter(sample['points_src']) 238 | sample['points_ref'] = self.jitter(sample['points_ref']) 239 | 240 | return sample 241 | 242 | 243 | class RandomCrop: 244 | """Randomly crops the *source* point cloud, approximately retaining half the points 245 | 246 | A direction is randomly sampled from S2, and we retain points which lie within the 247 | half-space oriented in this direction. 248 | If p_keep != 0.5, we shift the plane until approximately p_keep points are retained 249 | """ 250 | def __init__(self, p_keep: List = None): 251 | if p_keep is None: 252 | p_keep = [0.7, 0.7] # Crop both clouds to 70% 253 | self.p_keep = np.array(p_keep, dtype=np.float32) 254 | 255 | @staticmethod 256 | def crop(points, p_keep): 257 | rand_xyz = uniform_2_sphere() 258 | centroid = np.mean(points[:, :3], axis=0) 259 | points_centered = points[:, :3] - centroid 260 | 261 | dist_from_plane = np.dot(points_centered, rand_xyz) 262 | if p_keep == 0.5: 263 | mask = dist_from_plane > 0 264 | else: 265 | mask = dist_from_plane > np.percentile(dist_from_plane, (1.0 - p_keep) * 100) 266 | 267 | return points[mask, :] 268 | 269 | def __call__(self, sample): 270 | 271 | sample['crop_proportion'] = self.p_keep 272 | if np.all(self.p_keep == 1.0): 273 | return sample # No need crop 274 | 275 | if 'deterministic' in sample and sample['deterministic']: 276 | np.random.seed(sample['idx']) 277 | 278 | if len(self.p_keep) == 1: 279 | sample['points_src'] = self.crop(sample['points_src'], self.p_keep[0]) 280 | else: 281 | sample['points_src'] = self.crop(sample['points_src'], self.p_keep[0]) 282 | sample['points_ref'] = self.crop(sample['points_ref'], self.p_keep[1]) 283 | return sample 284 | 285 | 286 | class RandomTransformSE3: 287 | def __init__(self, rot_mag: float = 180.0, trans_mag: float = 1.0, random_mag: bool = False): 288 | """Applies a random rigid transformation to the source point cloud 289 | 290 | Args: 291 | rot_mag (float): Maximum rotation in degrees 292 | trans_mag (float): Maximum translation T. 
Random translation will 293 | be in the range [-X,X] in each axis 294 | random_mag (bool): If true, will randomize the maximum rotation, i.e. will bias towards small 295 | perturbations 296 | """ 297 | self._rot_mag = rot_mag 298 | self._trans_mag = trans_mag 299 | self._random_mag = random_mag 300 | 301 | def generate_transform(self): 302 | """Generate a random SE3 transformation (3, 4) """ 303 | 304 | if self._random_mag: 305 | attentuation = np.random.random() 306 | rot_mag, trans_mag = attentuation * self._rot_mag, attentuation * self._trans_mag 307 | else: 308 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 309 | 310 | # Generate rotation 311 | rand_rot = special_ortho_group.rvs(3) 312 | axis_angle = Rotation.as_rotvec(Rotation.from_dcm(rand_rot)) 313 | axis_angle *= rot_mag / 180.0 314 | rand_rot = Rotation.from_rotvec(axis_angle).as_dcm() 315 | 316 | # Generate translation 317 | rand_trans = np.random.uniform(-trans_mag, trans_mag, 3) 318 | rand_SE3 = np.concatenate((rand_rot, rand_trans[:, None]), axis=1).astype(np.float32) 319 | 320 | return rand_SE3 321 | 322 | def apply_transform(self, p0, transform_mat): 323 | p1 = se3.transform(transform_mat, p0[:, :3]) 324 | if p0.shape[1] == 6: # Need to rotate normals also 325 | n1 = so3.transform(transform_mat[:3, :3], p0[:, 3:6]) 326 | p1 = np.concatenate((p1, n1), axis=-1) 327 | 328 | igt = transform_mat 329 | gt = se3.inverse(igt) 330 | 331 | return p1, gt, igt 332 | 333 | def transform(self, tensor): 334 | transform_mat = self.generate_transform() 335 | return self.apply_transform(tensor, transform_mat) 336 | 337 | def __call__(self, sample): 338 | 339 | if 'deterministic' in sample and sample['deterministic']: 340 | np.random.seed(sample['idx']) 341 | 342 | if 'points' in sample: 343 | sample['points'], _, _ = self.transform(sample['points']) 344 | else: 345 | src_transformed, transform_r_s, transform_s_r = self.transform(sample['points_src']) 346 | sample['transform_gt'] = transform_r_s # Apply to source to get reference 347 | sample['points_src'] = src_transformed 348 | 349 | return sample 350 | 351 | 352 | # noinspection PyPep8Naming 353 | class RandomTransformSE3_euler(RandomTransformSE3): 354 | """Same as RandomTransformSE3, but rotates using euler angle rotations 355 | 356 | This transformation is consistent to Deep Closest Point but does not 357 | generate uniform rotations 358 | 359 | """ 360 | def generate_transform(self): 361 | 362 | if self._random_mag: 363 | attentuation = np.random.random() 364 | rot_mag, trans_mag = attentuation * self._rot_mag, attentuation * self._trans_mag 365 | else: 366 | rot_mag, trans_mag = self._rot_mag, self._trans_mag 367 | 368 | # Generate rotation 369 | anglex = np.random.uniform() * np.pi * rot_mag / 180.0 370 | angley = np.random.uniform() * np.pi * rot_mag / 180.0 371 | anglez = np.random.uniform() * np.pi * rot_mag / 180.0 372 | 373 | cosx = np.cos(anglex) 374 | cosy = np.cos(angley) 375 | cosz = np.cos(anglez) 376 | sinx = np.sin(anglex) 377 | siny = np.sin(angley) 378 | sinz = np.sin(anglez) 379 | Rx = np.array([[1, 0, 0], 380 | [0, cosx, -sinx], 381 | [0, sinx, cosx]]) 382 | Ry = np.array([[cosy, 0, siny], 383 | [0, 1, 0], 384 | [-siny, 0, cosy]]) 385 | Rz = np.array([[cosz, -sinz, 0], 386 | [sinz, cosz, 0], 387 | [0, 0, 1]]) 388 | R_ab = Rx @ Ry @ Rz 389 | t_ab = np.random.uniform(-trans_mag, trans_mag, 3) 390 | 391 | rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32) 392 | return rand_SE3 393 | 394 | 395 | class RandomRotatorZ(RandomTransformSE3): 396 | 
"""Applies a random z-rotation to the source point cloud""" 397 | 398 | def __init__(self): 399 | super().__init__(rot_mag=360) 400 | 401 | def generate_transform(self): 402 | """Generate a random SE3 transformation (3, 4) """ 403 | 404 | rand_rot_deg = np.random.random() * self._rot_mag 405 | rand_rot = Rotation.from_euler('z', rand_rot_deg, degrees=True).as_dcm() 406 | rand_SE3 = np.pad(rand_rot, ((0, 0), (0, 1)), mode='constant').astype(np.float32) 407 | 408 | return rand_SE3 409 | 410 | 411 | class ShufflePoints: 412 | """Shuffles the order of the points""" 413 | def __call__(self, sample): 414 | if 'points' in sample: 415 | per_arr = np.random.permutation( sample['points'].shape[0] ) 416 | sample['points'] = sample['points'][per_arr, :] 417 | if 'seg' in sample: 418 | sample['seg'] = sample['seg'][per_arr] 419 | else: 420 | src_arr = np.random.permutation( sample['points_src'].shape[0] ) 421 | sample['points_src'] = sample['points_src'][src_arr, :] 422 | ref_arr = np.random.permutation( sample['points_ref'].shape[0] ) 423 | sample['points_ref'] = sample['points_ref'][ref_arr, :] 424 | if 'seg' in sample: 425 | sample['seg_src'] = sample['seg_src'][src_arr] 426 | sample['seg_ref'] = sample['seg_ref'][ref_arr] 427 | return sample 428 | 429 | 430 | class SetDeterministic: 431 | """Adds a deterministic flag to the sample such that subsequent transforms 432 | use a fixed random seed where applicable. Used for test""" 433 | def __call__(self, sample): 434 | sample['deterministic'] = True 435 | return sample 436 | 437 | 438 | class Dict2DcpList: 439 | """Converts dictionary of tensors into a list of tensors compatible with Deep Closest Point""" 440 | def __call__(self, sample): 441 | 442 | target = sample['points_src'][:, :3].transpose().copy() 443 | src = sample['points_ref'][:, :3].transpose().copy() 444 | 445 | rotation_ab = sample['transform_gt'][:3, :3].transpose().copy() 446 | translation_ab = -rotation_ab @ sample['transform_gt'][:3, 3].copy() 447 | 448 | rotation_ba = sample['transform_gt'][:3, :3].copy() 449 | translation_ba = sample['transform_gt'][:3, 3].copy() 450 | 451 | euler_ab = Rotation.from_dcm(rotation_ab).as_euler('zyx').copy() 452 | euler_ba = Rotation.from_dcm(rotation_ba).as_euler('xyz').copy() 453 | 454 | return src, target, \ 455 | rotation_ab, translation_ab, rotation_ba, translation_ba, \ 456 | euler_ab, euler_ba 457 | 458 | 459 | class Dict2PointnetLKList: 460 | """Converts dictionary of tensors into a list of tensors compatible with PointNet LK""" 461 | def __call__(self, sample): 462 | 463 | if 'points' in sample: 464 | # Train Classifier (pretraining) 465 | return sample['points'][:, :3], sample['label'] 466 | else: 467 | # Train PointNetLK 468 | transform_gt_4x4 = np.concatenate([sample['transform_gt'], 469 | np.array([[0.0, 0.0, 0.0, 1.0]], dtype=np.float32)], axis=0) 470 | return sample['points_src'][:, :3], sample['points_ref'][:, :3], transform_gt_4x4 471 | -------------------------------------------------------------------------------- /dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.autograd import Function 5 | import torch.nn.functional as F 6 | 7 | import point_utils_cuda 8 | from pytorch3d.loss import chamfer_distance 9 | from pytorch3d.ops import knn_points, knn_gather 10 | from scipy.spatial.transform import Rotation 11 | import random 12 | 13 | class FurthestPointSampling(Function): 14 | @staticmethod 
15 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 16 | ''' 17 | ctx: 18 | xyz: [B,N,3] 19 | npoint: int 20 | ''' 21 | assert xyz.is_contiguous() 22 | 23 | B, N, _ = xyz.size() 24 | output = torch.cuda.IntTensor(B, npoint) 25 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 26 | 27 | point_utils_cuda.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 28 | return output 29 | 30 | @staticmethod 31 | def backward(xyz, a=None): 32 | return None, None 33 | 34 | furthest_point_sample = FurthestPointSampling.apply 35 | 36 | class WeightedFurthestPointSampling(Function): 37 | @staticmethod 38 | def forward(ctx, xyz: torch.Tensor, weights: torch.Tensor, npoint: int) -> torch.Tensor: 39 | ''' 40 | ctx: 41 | xyz: [B,N,3] 42 | weights: [B,N] 43 | npoint: int 44 | ''' 45 | assert xyz.is_contiguous() 46 | assert weights.is_contiguous() 47 | B, N, _ = xyz.size() 48 | output = torch.cuda.IntTensor(B, npoint) 49 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 50 | 51 | point_utils_cuda.weighted_furthest_point_sampling_wrapper(B, N, npoint, xyz, weights, temp, output); 52 | return output 53 | 54 | @staticmethod 55 | def backward(xyz, a=None): 56 | return None, None 57 | 58 | weighted_furthest_point_sample = WeightedFurthestPointSampling.apply 59 | 60 | class GatherOperation(Function): 61 | @staticmethod 62 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 63 | ''' 64 | ctx 65 | features: [B,C,N] 66 | idx: [B,npoint] 67 | ''' 68 | assert features.is_contiguous() 69 | assert idx.is_contiguous() 70 | 71 | B, npoint = idx.size() 72 | _, C, N = features.size() 73 | output = torch.cuda.FloatTensor(B, C, npoint) 74 | 75 | point_utils_cuda.gather_points_wrapper(B, C, N, npoint, features, idx, output) 76 | 77 | ctx.for_backwards = (idx, C, N) 78 | return output 79 | 80 | @staticmethod 81 | def backward(ctx, grad_out): 82 | idx, C, N = ctx.for_backwards 83 | B, npoint = idx.size() 84 | grad_features = Variable(torch.cuda.FloatTensor(B,C,N).zero_()) 85 | grad_out_data = grad_out.data.contiguous() 86 | point_utils_cuda.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 87 | return grad_features, None 88 | 89 | gather_operation = GatherOperation.apply 90 | 91 | def generate_rand_rotm(x_lim=5.0, y_lim=5.0, z_lim=180.0): 92 | ''' 93 | Input: 94 | x_lim 95 | y_lim 96 | z_lim 97 | return: 98 | rotm: [3,3] 99 | ''' 100 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 101 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 102 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 103 | 104 | rand_eul = np.array([rand_z, rand_y, rand_x]) 105 | r = Rotation.from_euler('zyx', rand_eul, degrees=True) 106 | rotm = r.as_matrix() 107 | return rotm 108 | 109 | def generate_rand_trans(x_lim=10.0, y_lim=1.0, z_lim=0.1): 110 | ''' 111 | Input: 112 | x_lim 113 | y_lim 114 | z_lim 115 | return: 116 | trans [3] 117 | ''' 118 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 119 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 120 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 121 | 122 | rand_trans = np.array([rand_x, rand_y, rand_z]) 123 | 124 | return rand_trans 125 | 126 | def apply_transform(pts, trans): 127 | R = trans[:3, :3] 128 | T = trans[:3, 3] 129 | pts = pts @ R.T + T 130 | return pts 131 | 132 | def calc_error_np(pred_R, pred_t, gt_R, gt_t): 133 | tmp = (np.trace(pred_R.transpose().dot(gt_R))-1)/2 134 | if np.abs(tmp) > 1.0: 135 | tmp = 1.0 136 | L_rot = np.arccos(tmp) 137 | L_rot = 180 * L_rot / np.pi 138 | 
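    # L_rot above is the geodesic rotation angle between pred_R and gt_R: trace(pred_R^T gt_R)
    # = 1 + 2*cos(theta), so theta = arccos((trace - 1) / 2), converted to degrees; the check
    # above guards arccos against a cosine slightly greater than 1 caused by floating-point error.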
L_trans = np.linalg.norm(pred_t - gt_t) 139 | return L_rot, L_trans 140 | 141 | def set_seed(seed): 142 | ''' 143 | Set random seed for torch, numpy and python 144 | ''' 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | torch.manual_seed(seed) 148 | if torch.cuda.is_available(): 149 | torch.cuda.manual_seed(seed) 150 | torch.cuda.manual_seed_all(seed) 151 | 152 | torch.backends.cudnn.benchmark=False 153 | torch.backends.cudnn.deterministic=True -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from typing import Dict 4 | 5 | from common.math_torch import se3 6 | from common.torch import to_numpy 7 | 8 | 9 | 10 | def compute_metrics( data: Dict, pred_transforms ) -> Dict: 11 | 12 | with torch.no_grad(): 13 | gt_transforms = data['transform_gt'] 14 | 15 | # Rotation, translation errors (isotropic, i.e. doesn't depend on error 16 | # direction, which is more representative of the actual error) 17 | concatenated = se3.concatenate(se3.inverse(gt_transforms), pred_transforms) 18 | rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2] 19 | residual_rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi 20 | residual_transmag = concatenated[:, :, 3].norm(dim=-1) 21 | 22 | 23 | metrics = { 24 | 'err_r_deg': to_numpy(residual_rotdeg), 25 | 'err_t': to_numpy(residual_transmag) 26 | } 27 | 28 | return metrics 29 | 30 | def summarize_metrics(metrics, rot_thres, trans_thres ): 31 | """Summaries computed metrices by taking mean over all data instances""" 32 | summarized = {} 33 | success_list = np.zeros((len(metrics['err_t'])), dtype=np.int32 ) 34 | success_list[ metrics['err_r_deg'] < rot_thres] = 1 35 | success_list[ metrics['err_t'] > trans_thres] = 0 36 | metrics['err_r_deg_right'] = metrics['err_r_deg'][success_list==1] 37 | metrics['err_t_right'] = metrics['err_t'][success_list==1] 38 | success_rate = np.sum(success_list, dtype=np.float32) / len(success_list) 39 | summarized['success_rate'] = success_rate 40 | for k in metrics: 41 | if k.startswith('err'): 42 | summarized[k + '_mean'] = np.mean(metrics[k]) 43 | summarized[k + '_std'] = np.std(metrics[k]) 44 | else: 45 | summarized[k] = np.mean(metrics[k]) 46 | 47 | return summarized 48 | 49 | def print_metrics( logger, summary_metrics: Dict, title: str = 'Metrics' ): 50 | logger.info( title + ':' ) 51 | logger.info('=' * (len(title) + 1)) 52 | 53 | logger.info('Rotation error {:.4f}(deg, mean)+/-{:.4f}(std)'.format( 54 | summary_metrics['err_r_deg_mean'], summary_metrics['err_r_deg_std'])) 55 | logger.info('Translation error {:.4g}(mean)+/-{:.4g}(std)'.format( 56 | summary_metrics['err_t_mean'], summary_metrics['err_t_std'])) 57 | logger.info('error rotation in success {:.4f}(deg, mean)+/-{:.4f}(std)'.format( 58 | summary_metrics['err_r_deg_right_mean'], summary_metrics['err_r_deg_right_std']) ) 59 | logger.info('error translation in success {:.4g}(mean)+/-{:.4g}(std)'.format( 60 | summary_metrics['err_t_right_mean'], summary_metrics['err_t_right_std']) ) 61 | logger.info('success_rate {:.4g}'.format(summary_metrics['success_rate'])) 62 | -------------------------------------------------------------------------------- /models/PointUtils/points_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.autograd 
import Function 4 | from torch.autograd.function import once_differentiable 5 | import torch.nn as nn 6 | 7 | from typing import Union 8 | 9 | import point_utils_cuda 10 | 11 | class FurthestPointSampling(Function): 12 | @staticmethod 13 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 14 | ''' 15 | ctx: 16 | xyz: [B,N,3] 17 | npoint: int 18 | ''' 19 | assert xyz.is_contiguous() 20 | 21 | B, N, _ = xyz.size() 22 | output = torch.cuda.IntTensor(B, npoint) 23 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 24 | 25 | point_utils_cuda.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 26 | return output 27 | 28 | @staticmethod 29 | def backward(xyz, a=None): 30 | return None, None 31 | 32 | furthest_point_sample = FurthestPointSampling.apply 33 | 34 | class WeightedFurthestPointSampling(Function): 35 | @staticmethod 36 | def forward(ctx, xyz: torch.Tensor, weights: torch.Tensor, npoint: int) -> torch.Tensor: 37 | ''' 38 | ctx: 39 | xyz: [B,N,3] 40 | weights: [B,N] 41 | npoint: int 42 | ''' 43 | assert xyz.is_contiguous() 44 | assert weights.is_contiguous() 45 | B, N, _ = xyz.size() 46 | output = torch.cuda.IntTensor(B, npoint) 47 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 48 | 49 | point_utils_cuda.weighted_furthest_point_sampling_wrapper(B, N, npoint, xyz, weights, temp, output); 50 | return output 51 | 52 | @staticmethod 53 | def backward(xyz, a=None): 54 | return None, None 55 | 56 | weighted_furthest_point_sample = WeightedFurthestPointSampling.apply 57 | 58 | class GatherOperation(Function): 59 | @staticmethod 60 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 61 | ''' 62 | ctx 63 | features: [B,C,N] 64 | idx: [B,npoint] 65 | ''' 66 | assert features.is_contiguous() 67 | assert idx.is_contiguous() 68 | 69 | B, npoint = idx.size() 70 | _, C, N = features.size() 71 | output = torch.cuda.FloatTensor(B, C, npoint) 72 | 73 | point_utils_cuda.gather_points_wrapper(B, C, N, npoint, features, idx, output) 74 | 75 | ctx.for_backwards = (idx, C, N) 76 | return output 77 | 78 | @staticmethod 79 | def backward(ctx, grad_out): 80 | idx, C, N = ctx.for_backwards 81 | B, npoint = idx.size() 82 | grad_features = Variable(torch.cuda.FloatTensor(B,C,N).zero_()) 83 | grad_out_data = grad_out.data.contiguous() 84 | point_utils_cuda.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 85 | return grad_features, None 86 | 87 | gather_operation = GatherOperation.apply 88 | -------------------------------------------------------------------------------- /models/PointUtils/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='point_utils', 6 | ext_modules=[ 7 | CUDAExtension('point_utils_cuda', [ 8 | 'src/point_utils_api.cpp', 9 | 10 | 'src/furthest_point_sampling.cpp', 11 | 'src/furthest_point_sampling_gpu.cu', 12 | ], 13 | extra_compile_args={ 14 | 'cxx':['-g'], 15 | 'nvcc': ['-O2'] 16 | }) 17 | ], 18 | cmdclass={'build_ext':BuildExtension} 19 | ) -------------------------------------------------------------------------------- /models/PointUtils/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 1024 8 | #define THREADS_PER_BLOCK 256 9 | #define DIVUP(m,n) ((m) / (n)+((m) % (n) > 0)) 10 | 
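// DIVUP(m, n) is integer ceiling division: it rounds the grid size up so that a launch with
// THREADS_PER_BLOCK threads per block still covers every element when m is not a multiple of n.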
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.") 13 | #define CHECK_CONTIGUOUS_CUDA(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | /*** 18 | * calculate proper thread number 19 | * If work_size < TOTAL_THREADS, number = work_size (2^n) 20 | * Else number = TOTAL_THREADS 21 | ***/ 22 | inline int opt_n_threads(int work_size) { 23 | // log2(work_size) 24 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 25 | // 1 * 2^(pow_2) 26 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 27 | } 28 | 29 | #endif -------------------------------------------------------------------------------- /models/PointUtils/src/furthest_point_sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "furthest_point_sampling_gpu.h" 7 | 8 | // extern THCState *state; 9 | 10 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 12 | const float *points = points_tensor.data(); 13 | const int *idx = idx_tensor.data(); 14 | float *out = out_tensor.data(); 15 | 16 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 17 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 18 | return 1; 19 | } 20 | 21 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 22 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 23 | 24 | const float *grad_out = grad_out_tensor.data(); 25 | const int *idx = idx_tensor.data(); 26 | float *grad_points = grad_points_tensor.data(); 27 | 28 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 29 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 30 | return 1; 31 | } 32 | 33 | int furthest_point_sampling_wrapper(int b, int n, int m, 34 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 35 | 36 | const float *points = points_tensor.data(); 37 | float *temp = temp_tensor.data(); 38 | int *idx = idx_tensor.data(); 39 | 40 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 41 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 42 | return 1; 43 | } 44 | 45 | int weighted_furthest_point_sampling_wrapper(int b, int n, int m, 46 | at::Tensor points_tensor, at::Tensor weights_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 47 | 48 | const float *points = points_tensor.data(); 49 | const float *weights = weights_tensor.data(); 50 | float *temp = temp_tensor.data(); 51 | int *idx = idx_tensor.data(); 52 | 53 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 54 | weighted_furthest_point_sampling_kernel_launcher(b, n, m, points, weights, temp, idx, stream); 55 | return 1; 56 | } -------------------------------------------------------------------------------- /models/PointUtils/src/furthest_point_sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "furthest_point_sampling_gpu.h" 6 | 7 | __global__ void gather_points_kernel_fast(int b, int c, int n, int m, 8 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 9 | // points: [B,C,N] 10 | // idx: [B,M] 11 | 12 | int bs_idx = blockIdx.z; 13 | int 
c_idx = blockIdx.y; 14 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 15 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 16 | // Pointer to current point 17 | out += bs_idx * c * m + c_idx * m + pt_idx; // curr batch + channels + points 18 | idx += bs_idx * m + pt_idx; // curr batch + points 19 | points += bs_idx * c * n + c_idx * n; // batch + channels 20 | out[0] = points[idx[0]]; // curr batch channels -> channel of curr point ? 21 | } 22 | 23 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 24 | const float *points, const int *idx, float *out, cudaStream_t stream) { 25 | // points: [B,C,N] 26 | // idx: [B,npoints] 27 | cudaError_t err; 28 | // dim3 is a type to assign dimension 29 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); // DIVUP: npoints/THREADS_PER_BLOCK 30 | dim3 threads(THREADS_PER_BLOCK); // others assign to 1 31 | 32 | gather_points_kernel_fast<<>>(b, c, n, npoints, points, idx, out); 33 | 34 | err = cudaGetLastError(); 35 | if (cudaSuccess != err) { 36 | fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err)); 37 | exit(-1); 38 | } 39 | } 40 | 41 | __global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, 42 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 43 | // grad_out: [B,C,M] 44 | // idx: [B,M] 45 | int bs_idx = blockIdx.z; 46 | int c_idx = blockIdx.y; 47 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 48 | if (bs_idx > b || c_idx >= c || pt_idx >= m) return; 49 | 50 | grad_out += bs_idx * c * m + c_idx * m + pt_idx; 51 | idx += bs_idx * m + pt_idx; 52 | grad_points += bs_idx * c * n + c_idx * n; 53 | 54 | atomicAdd(grad_points + idx[0], grad_out[0]); // assign the grad of indexed value to grad_points 55 | } 56 | 57 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 58 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 59 | // grad_out: [B,C, npoints] 60 | // idx: [B, npoints] 61 | 62 | cudaError_t err; 63 | dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b); 64 | dim3 threads(THREADS_PER_BLOCK); 65 | 66 | gather_points_grad_kernel_fast<<>>(b, c, n, npoints, grad_out, idx, grad_points); 67 | 68 | err = cudaGetLastError(); 69 | if (cudaSuccess != err) { 70 | fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err)); 71 | exit(-1); 72 | } 73 | } 74 | 75 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2) { 76 | const float v1 = dists[idx1], v2 = dists[idx2]; 77 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 78 | dists[idx1] = max(v1, v2); 79 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 80 | } 81 | 82 | // A kernel runs on single thread and the launcher is defined to launch the kernel 83 | // Grid size and block size are all defined in the launcher 84 | template 85 | __global__ void furthest_point_sampling_kernel(int b, int n, int m, 86 | const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) { 87 | // dataset [B,N,3] 88 | // temp: [B,N] 89 | // idxs: 90 | // All global memory 91 | 92 | if (m <= 0) return; 93 | // assign shared memory 94 | __shared__ float dists[block_size]; 95 | __shared__ int dists_i[block_size]; 96 | 97 | int batch_index = blockIdx.x; 98 | // Point to curr batch (blockIdx of current thread of this kernel) 99 | dataset += batch_index * n * 3; 100 | temp += batch_index * n; 101 | idxs += batch_index * m; 102 | 103 | // threadIdx of current thread 104 | int tid = threadIdx.x; 105 | const int stride = block_size; // number of threads in one block 106 | 107 | int old = 0; 108 | if (threadIdx.x == 0) 109 | idxs[0] = old; // Initialize index 110 | 111 | __syncthreads(); 112 | // for loop m for m sampled points 113 | for (int j = 1; j < m; j++) { 114 | // printf("curr index: %d\n", j); 115 | int besti = 0; 116 | float best = -1; 117 | // Coordinate of last point 118 | float x1 = dataset[old * 3 + 0]; 119 | float y1 = dataset[old * 3 + 1]; 120 | float z1 = dataset[old * 3 + 2]; 121 | // Get global index, parallel calculate distance with multiple blocks 122 | for (int k = tid; k < n; k += stride) { 123 | // calculate distance with the other point 124 | float x2, y2, z2; 125 | x2 = dataset[k * 3 + 0]; 126 | y2 = dataset[k * 3 + 1]; 127 | z2 = dataset[k * 3 + 2]; 128 | 129 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 130 | float d2 = min(d, temp[k]); 131 | temp[k] = d2; // update temp distance 132 | besti = d2 > best ? k : besti; // If d2 > best, besti = k (idx) 133 | best = d2 > best ? d2 : best; // If d2 > best, best = d2 (distance) 134 | } 135 | // dists[tid] stores the largest dist over all blocks for the current threadIdx 136 | dists[tid] = best; 137 | dists_i[tid] = besti; 138 | __syncthreads(); // wait for all threads finishing compute the distance 139 | // calculate the idx of largest distance ? 
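// Parallel tree reduction over the shared-memory arrays: at each step the lower half of the
// active threads folds in the upper half via __update(), which keeps the larger distance and its
// point index. By the end of the cascade, dists[0] / dists_i[0] hold the farthest remaining
// point for this block, i.e. the next sample chosen by FPS.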
140 | if (block_size >= 1024) { 141 | if (tid < 512) { 142 | __update(dists, dists_i, tid, tid + 512); 143 | } 144 | __syncthreads(); 145 | } 146 | if (block_size >= 512) { 147 | if (tid < 256) { 148 | __update(dists, dists_i, tid, tid + 256); 149 | } 150 | __syncthreads(); 151 | } 152 | if (block_size >= 256) { 153 | if (tid < 128) { 154 | __update(dists, dists_i, tid, tid + 128); 155 | } 156 | __syncthreads(); 157 | } 158 | if (block_size >= 128) { 159 | if (tid < 64) { 160 | __update(dists, dists_i, tid, tid + 64); 161 | } 162 | __syncthreads(); 163 | } 164 | if (block_size >= 64) { 165 | if (tid < 32) { 166 | __update(dists, dists_i, tid, tid + 32); 167 | } 168 | __syncthreads(); 169 | } 170 | if (block_size >= 32) { 171 | if (tid < 16) { 172 | __update(dists, dists_i, tid, tid + 16); 173 | } 174 | __syncthreads(); 175 | } 176 | if (block_size >= 16) { 177 | if (tid < 8) { 178 | __update(dists, dists_i, tid, tid + 8); 179 | } 180 | __syncthreads(); 181 | } 182 | if (block_size >= 8) { 183 | if (tid < 4) { 184 | __update(dists, dists_i, tid, tid + 4); 185 | } 186 | __syncthreads(); 187 | } 188 | if (block_size >= 4) { 189 | if (tid < 2) { 190 | __update(dists, dists_i, tid, tid + 2); 191 | } 192 | __syncthreads(); 193 | } 194 | if (block_size >= 2) { 195 | if (tid < 1) { 196 | __update(dists, dists_i, tid, tid + 1); 197 | } 198 | __syncthreads(); 199 | } 200 | 201 | // All threads update a single new point (old). 202 | old = dists_i[0]; // update last point index 203 | if (tid == 0) 204 | idxs[j] = old; 205 | } 206 | } 207 | 208 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 209 | const float *dataset, float *temp, int *idxs, cudaStream_t stream) { 210 | // dataset: [B,N,3] 211 | // tmp: [B,N] 212 | 213 | cudaError_t err; 214 | unsigned int n_threads = opt_n_threads(n); // compute proper thread number 215 | 216 | switch (n_threads) { 217 | // Call kernel functions: Func 218 | // Dg: grid size (how many blocks in the grid) 219 | // Db: block size (how many threads in the block) 220 | // Ns: memory for shared value, default 0 221 | // s: stream 222 | case 1024: 223 | furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, temp, idxs); break; 224 | case 512: 225 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); break; 226 | case 256: 227 | furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, temp, idxs); break; 228 | case 128: 229 | furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, temp, idxs); break; 230 | case 64: 231 | furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, temp, idxs); break; 232 | case 32: 233 | furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, temp, idxs); break; 234 | case 16: 235 | furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, temp, idxs); break; 236 | case 8: 237 | furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, temp, idxs); break; 238 | case 4: 239 | furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, temp, idxs); break; 240 | case 2: 241 | furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, temp, idxs); break; 242 | case 1: 243 | furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, temp, idxs); break; 244 | default: 245 | furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, temp, idxs); 246 | } 247 | err = cudaGetLastError(); 248 | if (cudaSuccess != err) { 249 | fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err)); 250 | exit(-1); 251 | } 252 | } 253 | 254 | template 255 | __global__ void weighted_furthest_point_sampling_kernel(int b, int n, 
int m, 256 | const float *__restrict__ dataset, const float *__restrict__ weights, float *__restrict__ temp, int *__restrict__ idxs) { 257 | // dataset: [B,N,3] 258 | // weights: [B,N] 259 | // temp: [B,N] 260 | 261 | if (m <= 0) return; 262 | 263 | __shared__ float dists[block_size]; 264 | __shared__ int dists_i[block_size]; 265 | 266 | int batch_index = blockIdx.x; 267 | dataset += batch_index * n * 3; 268 | weights += batch_index * n; 269 | temp += batch_index * n; 270 | idxs += batch_index * m; 271 | 272 | int tid = threadIdx.x; 273 | const int stride = block_size; 274 | 275 | int old = 0; 276 | if (threadIdx.x == 0) 277 | idxs[0] = old; 278 | 279 | __syncthreads(); 280 | 281 | for (int j = 1; j < m; j++) { 282 | 283 | int besti = 0; 284 | float best = -1; 285 | 286 | float x1 = dataset[old * 3 + 0]; 287 | float y1 = dataset[old * 3 + 1]; 288 | float z1 = dataset[old * 3 + 2]; 289 | 290 | float w1 = weights[old]; 291 | 292 | for (int k = tid; k < n; k += stride) { 293 | float x2, y2, z2, w2; 294 | x2 = dataset[k * 3 + 0]; 295 | y2 = dataset[k * 3 + 1]; 296 | z2 = dataset[k * 3 + 2]; 297 | w2 = weights[k]; 298 | 299 | float d = w2 * ((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1)); 300 | float d2 = min(d, temp[k]); 301 | temp[k] = d2; 302 | besti = d2 > best ? k : besti; 303 | best = d2 > best ? d2 : best; 304 | } 305 | dists[tid] = best; 306 | dists_i[tid] = besti; 307 | __syncthreads(); 308 | 309 | if (block_size >= 1024) { 310 | if (tid < 512) { 311 | __update(dists, dists_i, tid, tid + 512); 312 | } 313 | __syncthreads(); 314 | } 315 | if (block_size >= 512) { 316 | if (tid < 256) { 317 | __update(dists, dists_i, tid, tid + 256); 318 | } 319 | __syncthreads(); 320 | } 321 | if (block_size >= 256) { 322 | if (tid < 128) { 323 | __update(dists, dists_i, tid, tid + 128); 324 | } 325 | __syncthreads(); 326 | } 327 | if (block_size >= 128) { 328 | if (tid < 64) { 329 | __update(dists, dists_i, tid, tid + 64); 330 | } 331 | __syncthreads(); 332 | } 333 | if (block_size >= 64) { 334 | if (tid < 32) { 335 | __update(dists, dists_i, tid, tid + 32); 336 | } 337 | __syncthreads(); 338 | } 339 | if (block_size >= 32) { 340 | if (tid < 16) { 341 | __update(dists, dists_i, tid, tid + 16); 342 | } 343 | __syncthreads(); 344 | } 345 | if (block_size >= 16) { 346 | if (tid < 8) { 347 | __update(dists, dists_i, tid, tid + 8); 348 | } 349 | __syncthreads(); 350 | } 351 | if (block_size >= 8) { 352 | if (tid < 4) { 353 | __update(dists, dists_i, tid, tid + 4); 354 | } 355 | __syncthreads(); 356 | } 357 | if (block_size >= 4) { 358 | if (tid < 2) { 359 | __update(dists, dists_i, tid, tid + 2); 360 | } 361 | __syncthreads(); 362 | } 363 | if (block_size >= 2) { 364 | if (tid < 1) { 365 | __update(dists, dists_i, tid, tid + 1); 366 | } 367 | __syncthreads(); 368 | } 369 | 370 | // All threads update a single new point (old). 
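// dists_i[0] now indexes the point with the largest weight-scaled distance (d = w2 * squared
// Euclidean distance above), so the weighting biases sampling toward high-weight points; every
// thread reads it as the next query point, and only thread 0 records it in idxs below.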
371 | old = dists_i[0]; // update last point index 372 | if (tid == 0) 373 | idxs[j] = old; 374 | } 375 | } 376 | 377 | void weighted_furthest_point_sampling_kernel_launcher(int b, int n, int m, 378 | const float *dataset, const float *weights, float *temp, int *idxs, cudaStream_t stream) { 379 | 380 | cudaError_t err; 381 | unsigned int n_threads = opt_n_threads(n); // compute proper thread numbere 382 | 383 | switch (n_threads) { 384 | // Call kernel functions: Func 385 | // Dg: grid size (how many blocks in the grid) 386 | // Db: block size (how many threads in the block) 387 | // Ns: memory for shared value, default 0 388 | // s: stream 389 | case 1024: 390 | weighted_furthest_point_sampling_kernel<1024><<>>(b, n, m, dataset, weights, temp, idxs); break; 391 | case 512: 392 | weighted_furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, weights, temp, idxs); break; 393 | case 256: 394 | weighted_furthest_point_sampling_kernel<256><<>>(b, n, m, dataset, weights, temp, idxs); break; 395 | case 128: 396 | weighted_furthest_point_sampling_kernel<128><<>>(b, n, m, dataset, weights, temp, idxs); break; 397 | case 64: 398 | weighted_furthest_point_sampling_kernel<64><<>>(b, n, m, dataset, weights, temp, idxs); break; 399 | case 32: 400 | weighted_furthest_point_sampling_kernel<32><<>>(b, n, m, dataset, weights, temp, idxs); break; 401 | case 16: 402 | weighted_furthest_point_sampling_kernel<16><<>>(b, n, m, dataset, weights, temp, idxs); break; 403 | case 8: 404 | weighted_furthest_point_sampling_kernel<8><<>>(b, n, m, dataset, weights, temp, idxs); break; 405 | case 4: 406 | weighted_furthest_point_sampling_kernel<4><<>>(b, n, m, dataset, weights, temp, idxs); break; 407 | case 2: 408 | weighted_furthest_point_sampling_kernel<2><<>>(b, n, m, dataset, weights, temp, idxs); break; 409 | case 1: 410 | weighted_furthest_point_sampling_kernel<1><<>>(b, n, m, dataset, weights, temp, idxs); break; 411 | default: 412 | weighted_furthest_point_sampling_kernel<512><<>>(b, n, m, dataset, weights, temp, idxs); 413 | } 414 | err = cudaGetLastError(); 415 | if (cudaSuccess != err) { 416 | fprintf(stderr, "CUDA kernel failed: %s\n", cudaGetErrorString(err)); 417 | exit(-1); 418 | } 419 | } -------------------------------------------------------------------------------- /models/PointUtils/src/furthest_point_sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _FURTHEST_POINT_SAMPLING_H 2 | #define _FURTHEST_POINT_SAMPLING_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 9 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 10 | 11 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 12 | const float *points, const int *idx, float *out, cudaStream_t stream); 13 | 14 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 15 | at::Tensor grad_out_Tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 16 | 17 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 18 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 19 | 20 | int furthest_point_sampling_wrapper(int b, int n, int m, 21 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 22 | 23 | int weighted_furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor weights_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void 
furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | void weighted_furthest_point_sampling_kernel_launcher(int b, int n, int m, 30 | const float *dataset, const float *weights, float *temp, int *idxs, cudaStream_t stream); 31 | 32 | #endif -------------------------------------------------------------------------------- /models/PointUtils/src/point_utils_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "furthest_point_sampling_gpu.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | 8 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 9 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 10 | 11 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 12 | m.def("weighted_furthest_point_sampling_wrapper", &weighted_furthest_point_sampling_wrapper, "weighted_furthest_point_sampling_wrapper"); 13 | } -------------------------------------------------------------------------------- /models/RandLA_Net.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from torch_points import knn 8 | except (ModuleNotFoundError, ImportError): 9 | from torch_points_kernels import knn 10 | 11 | class SharedMLP(nn.Module): 12 | def __init__( 13 | self, 14 | in_channels, 15 | out_channels, 16 | kernel_size=1, 17 | stride=1, 18 | transpose=False, 19 | padding_mode='zeros', 20 | bn=False, 21 | activation_fn=None 22 | ): 23 | super(SharedMLP, self).__init__() 24 | 25 | conv_fn = nn.ConvTranspose2d if transpose else nn.Conv2d 26 | 27 | self.conv = conv_fn( 28 | in_channels, 29 | out_channels, 30 | kernel_size, 31 | stride=stride, 32 | padding_mode=padding_mode 33 | ) 34 | self.batch_norm = nn.BatchNorm2d(out_channels, eps=1e-6, momentum=0.99) if bn else None 35 | self.activation_fn = activation_fn 36 | 37 | def forward(self, input): 38 | r""" 39 | Forward pass of the network 40 | 41 | Parameters 42 | ---------- 43 | input: torch.Tensor, shape (B, d_in, N, K) 44 | 45 | Returns 46 | ------- 47 | torch.Tensor, shape (B, d_out, N, K) 48 | """ 49 | x = self.conv(input) 50 | if self.batch_norm: 51 | x = self.batch_norm(x) 52 | if self.activation_fn: 53 | x = self.activation_fn(x) 54 | return x 55 | 56 | 57 | class LocalSpatialEncoding(nn.Module): 58 | def __init__(self, d, num_neighbors ): 59 | super(LocalSpatialEncoding, self).__init__() 60 | 61 | self.num_neighbors = num_neighbors 62 | self.mlp = SharedMLP(10, d, bn=True, activation_fn=nn.ReLU()) 63 | 64 | 65 | def forward(self, coords, features, knn_output): 66 | r""" 67 | Forward pass 68 | 69 | Parameters 70 | ---------- 71 | coords: torch.Tensor, shape (B, N, 3) 72 | coordinates of the point cloud 73 | features: torch.Tensor, shape (B, d, N, 1) 74 | features of the point cloud 75 | neighbors: tuple 76 | 77 | Returns 78 | ------- 79 | torch.Tensor, shape (B, 2*d, N, K) 80 | """ 81 | # finding neighboring points 82 | idx, dist = knn_output 83 | B, N, K = idx.size() 84 | # idx(B, N, K), coords(B, N, 3) 85 | # neighbors[b, i, n, k] = coords[b, idx[b, n, k], i] = extended_coords[b, i, extended_idx[b, i, n, k], k] 86 | extended_idx = idx.unsqueeze(1).expand(B, 3, N, K) 87 | extended_coords = 
coords.transpose(-2,-1).unsqueeze(-1).expand(B, 3, N, K) 88 | neighbors = torch.gather(extended_coords, 2, extended_idx) # shape (B, 3, N, K) 89 | # if USE_CUDA: 90 | # neighbors = neighbors.cuda() 91 | 92 | # relative point position encoding 93 | concat = torch.cat(( 94 | extended_coords, 95 | neighbors, 96 | extended_coords - neighbors, 97 | dist.unsqueeze(-3) 98 | ), dim=-3) #.to(self.device) 99 | return torch.cat(( 100 | self.mlp(concat), 101 | features.expand(B, -1, N, K) 102 | ), dim=-3) 103 | 104 | 105 | 106 | class AttentivePooling(nn.Module): 107 | def __init__(self, in_channels, out_channels): 108 | super(AttentivePooling, self).__init__() 109 | 110 | self.score_fn = nn.Sequential( 111 | nn.Linear(in_channels, in_channels, bias=False), 112 | nn.Softmax(dim=-2) 113 | ) 114 | self.mlp = SharedMLP(in_channels, out_channels, bn=True, activation_fn=nn.ReLU()) 115 | 116 | def forward(self, x): 117 | r""" 118 | Forward pass 119 | 120 | Parameters 121 | ---------- 122 | x: torch.Tensor, shape (B, d_in, N, K) 123 | 124 | Returns 125 | ------- 126 | torch.Tensor, shape (B, d_out, N, 1) 127 | """ 128 | # computing attention scores 129 | scores = self.score_fn(x.permute(0,2,3,1)).permute(0,3,1,2) 130 | 131 | # sum over the neighbors 132 | features = torch.sum(scores * x, dim=-1, keepdim=True) # shape (B, d_in, N, 1) 133 | 134 | return self.mlp(features) 135 | 136 | 137 | 138 | class LocalFeatureAggregation(nn.Module): 139 | def __init__(self, d_in, d_out, num_neighbors): 140 | super(LocalFeatureAggregation, self).__init__() 141 | 142 | self.num_neighbors = num_neighbors 143 | 144 | self.mlp1 = SharedMLP(d_in, d_out//2, activation_fn=nn.LeakyReLU(0.2)) 145 | self.mlp2 = SharedMLP(d_out, 2*d_out) 146 | self.shortcut = SharedMLP(d_in, 2*d_out, bn=True) 147 | 148 | self.lse1 = LocalSpatialEncoding(d_out//2, num_neighbors) 149 | self.lse2 = LocalSpatialEncoding(d_out//2, num_neighbors) 150 | 151 | self.pool1 = AttentivePooling(d_out, d_out//2) 152 | self.pool2 = AttentivePooling(d_out, d_out) 153 | 154 | self.lrelu = nn.LeakyReLU() 155 | 156 | def forward(self, coords, features): 157 | r""" 158 | Forward pass 159 | 160 | Parameters 161 | ---------- 162 | coords: torch.Tensor, shape (B, N, 3) 163 | coordinates of the point cloud 164 | features: torch.Tensor, shape (B, d_in, N, 1) 165 | features of the point cloud 166 | 167 | Returns 168 | ------- 169 | torch.Tensor, shape (B, 2*d_out, N, 1) 170 | """ 171 | 172 | idx, dist = knn(coords.cpu().contiguous(), coords.cpu().contiguous(), self.num_neighbors) 173 | knn_output = idx.to(coords.device) , dist.to(coords.device) 174 | 175 | x = self.mlp1(features) 176 | 177 | x = self.lse1(coords, x, knn_output) 178 | x = self.pool1(x) 179 | 180 | x = self.lse2(coords, x, knn_output) 181 | x = self.pool2(x) 182 | 183 | return self.lrelu(self.mlp2(x) + self.shortcut(features)) 184 | 185 | 186 | 187 | class RandLANet(nn.Module): 188 | def __init__(self, d_in, num_classes, num_neighbors=16, decimation=4): 189 | super(RandLANet, self).__init__() 190 | self.num_neighbors = num_neighbors 191 | self.decimation = decimation 192 | 193 | self.fc_start = nn.Linear(d_in, 8) 194 | self.bn_start = nn.Sequential( 195 | nn.BatchNorm2d(8, eps=1e-6, momentum=0.99), 196 | nn.LeakyReLU(0.2) 197 | ) 198 | 199 | # encoding layers 200 | self.encoder = nn.ModuleList([ 201 | LocalFeatureAggregation(8, 16, num_neighbors), 202 | LocalFeatureAggregation(32, 64, num_neighbors), 203 | LocalFeatureAggregation(128, 128, num_neighbors), 204 | LocalFeatureAggregation(256, 256, num_neighbors) 
205 | ]) 206 | 207 | self.mlp = SharedMLP(512, 512, activation_fn=nn.ReLU()) 208 | 209 | # decoding layers 210 | decoder_kwargs = dict( 211 | transpose=True, 212 | bn=True, 213 | activation_fn=nn.ReLU() 214 | ) 215 | self.decoder = nn.ModuleList([ 216 | SharedMLP(1024, 256, **decoder_kwargs), 217 | SharedMLP(512, 128, **decoder_kwargs), 218 | SharedMLP(256, 32, **decoder_kwargs), 219 | SharedMLP(64, 8, **decoder_kwargs) 220 | ]) 221 | 222 | # final semantic prediction 223 | self.fc_end = nn.Sequential( 224 | SharedMLP(8, 64, bn=True, activation_fn=nn.ReLU()), 225 | SharedMLP(64, 32, bn=True, activation_fn=nn.ReLU()), 226 | nn.Dropout(), 227 | SharedMLP(32, num_classes) 228 | ) 229 | #self.device = device 230 | 231 | #self = self.to(device) 232 | 233 | def forward(self, input): 234 | r""" 235 | Forward pass 236 | 237 | Parameters 238 | ---------- 239 | input: torch.Tensor, shape (B, N, d_in) 240 | input points 241 | 242 | Returns 243 | ------- 244 | torch.Tensor, shape (B, num_classes, N) 245 | segmentation scores for each point 246 | """ 247 | N = input.size(1) 248 | d = self.decimation 249 | 250 | coords = input[...,:3] 251 | x = self.fc_start(input).transpose(-2,-1).unsqueeze(-1) 252 | x = self.bn_start(x) # shape (B, d, N, 1) 253 | 254 | decimation_ratio = 1 255 | 256 | # <<<<<<<<<< ENCODER 257 | x_stack = [] 258 | 259 | permutation = torch.randperm(N).to( input.device ) 260 | coords = coords[:,permutation] 261 | x = x[:,:,permutation] 262 | 263 | for lfa in self.encoder: 264 | # at iteration i, x.shape = (B, N//(d**i), d_in) 265 | x = lfa(coords[:,:N//decimation_ratio], x) 266 | x_stack.append(x.clone()) 267 | decimation_ratio *= d 268 | x = x[:,:,:N//decimation_ratio] 269 | 270 | 271 | # # >>>>>>>>>> ENCODER 272 | 273 | x = self.mlp(x) 274 | 275 | # <<<<<<<<<< DECODER 276 | for mlp in self.decoder: 277 | neighbors, _ = knn( 278 | coords[:,:N//decimation_ratio].cpu().contiguous(), # original set 279 | coords[:,:d*N//decimation_ratio].cpu().contiguous(), # upsampled set 280 | 1 281 | ) # shape (B, N, 1) 282 | neighbors = neighbors.to(input.device) 283 | 284 | extended_neighbors = neighbors.unsqueeze(1).expand(-1, x.size(1), -1, 1) 285 | 286 | x_neighbors = torch.gather(x, -2, extended_neighbors) 287 | 288 | x = torch.cat((x_neighbors, x_stack.pop()), dim=1) 289 | 290 | x = mlp(x) 291 | 292 | decimation_ratio //= d 293 | 294 | # >>>>>>>>>> DECODER 295 | # inverse permutation 296 | x = x[:,:,torch.argsort(permutation)] 297 | 298 | scores = self.fc_end(x) 299 | 300 | return scores.squeeze(-1) 301 | 302 | 303 | if __name__ == '__main__': 304 | import time 305 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 306 | 307 | d_in = 7 308 | cloud = 1000*torch.randn(1, 2**16, d_in).to(device) 309 | model = RandLANet(d_in, 6, 16, 4, device) 310 | # model.load_state_dict(torch.load('checkpoints/checkpoint_100.pth')) 311 | model.eval() 312 | 313 | t0 = time.time() 314 | pred = model(cloud) 315 | t1 = time.time() 316 | # print(pred) 317 | print(t1-t0) 318 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | 5 | def MLP(channels: list, do_bn=True): 6 | """ Multi-layer perceptron """ 7 | n = len(channels) 8 | layers = [] 9 | for i in range(1, n): 10 | layers.append( 11 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 12 | if i < (n-1): 13 | if do_bn: 14 
| layers.append(nn.InstanceNorm1d(channels[i])) 15 | layers.append(nn.ReLU()) 16 | return nn.Sequential(*layers) 17 | 18 | 19 | def attention(query, key, value): 20 | dim = query.shape[1] 21 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 22 | prob = torch.nn.functional.softmax(scores, dim=-1) 23 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 24 | 25 | 26 | class MultiHeadedAttention(nn.Module): 27 | """ Multi-head attention to increase model expressivitiy """ 28 | def __init__(self, num_heads: int, d_model: int): 29 | super().__init__() 30 | assert d_model % num_heads == 0 31 | self.dim = d_model // num_heads 32 | self.num_heads = num_heads 33 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 34 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 35 | 36 | def forward(self, query, key, value): 37 | batch_dim = query.size(0) 38 | query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) 39 | for l, x in zip(self.proj, (query, key, value))] 40 | x, _ = attention(query, key, value) 41 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) 42 | 43 | class AttentionalPropagation(nn.Module): 44 | def __init__(self, feature_dim: int, num_heads: int): 45 | super().__init__() 46 | self.attn = MultiHeadedAttention(num_heads, feature_dim) 47 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) 48 | nn.init.constant_(self.mlp[-1].bias, 0.0) 49 | 50 | def forward(self, x, source): 51 | message = self.attn(x, source, source) 52 | return self.mlp(torch.cat([x, message], dim=1)) 53 | 54 | if __name__ == '__main__': 55 | q, k, v = torch.randn((8, 32, 1024)), torch.randn((8, 32, 1024)), \ 56 | torch.randn((8, 32, 1024)) 57 | AP = AttentionalPropagation( 32, 4 ) 58 | x = AP( q, k ) 59 | print(x.shape) 60 | -------------------------------------------------------------------------------- /models/compute_rigid_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sympy.matrices import Matrix, GramSchmidt 3 | import torch 4 | import torch.nn.functional as F 5 | import torch 6 | import torch.nn as nn 7 | 8 | _EPS = 1e-5 # To prevent division by zero 9 | 10 | def svd(src: torch.Tensor, ref: torch.Tensor, permutation: torch.Tensor): 11 | """Compute rigid transforms between two point sets 12 | 13 | Args: 14 | src (torch.Tensor): (B, M, 3) points 15 | ref (torch.Tensor): (B, N, 3) points 16 | permutation (torch.Tensor): (B, M, N) 17 | 18 | Returns: 19 | Transform T (B, 3, 4) to get from src to ref, i.e. 
T*src = ref 20 | """ 21 | 22 | ref_perm = torch.bmm( permutation, ref ) 23 | center_src = torch.mean( src, dim=1, keepdim=True ) 24 | center_ref = torch.mean( ref_perm, dim=1, keepdim=True ) 25 | src_c = src - center_src 26 | ref_c = ref_perm - center_ref 27 | 28 | H = torch.bmm( src_c.transpose( 1, 2 ), ref_c ) 29 | u, s, v = torch.svd( H, some=False ) 30 | R_pos = torch.bmm( v, u.transpose(1, 2) ) 31 | v_neg = v.clone() 32 | v_neg[:, :, 2] *=-1 33 | R_neg = torch.bmm( v_neg, u.transpose(1, 2) ) 34 | R = torch.where( torch.det(R_pos)[:, None, None] > 0, R_pos, R_neg ) 35 | assert torch.all(torch.det(R) > 0) 36 | 37 | T = center_ref.transpose(1,2) - torch.bmm( R, center_src.transpose(1,2) ) 38 | transform = torch.cat((R, T), dim=2) 39 | return transform 40 | 41 | def weighted_svd( src: torch.Tensor, ref: torch.Tensor, 42 | weights: torch.Tensor, permutation: torch.Tensor = None ): 43 | sum_weights = torch.sum(weights,dim=1,keepdim=True) + _EPS 44 | weights = weights/sum_weights 45 | weights = weights.unsqueeze(2) 46 | 47 | ref_perm = ref 48 | if permutation != None: 49 | ref_perm = torch.bmm( permutation, ref_perm ) 50 | 51 | src_mean = torch.matmul(weights.transpose(1,2),src)/(torch.sum(weights,dim=1).unsqueeze(1)+_EPS) 52 | src_corres_mean = torch.matmul(weights.transpose(1,2),ref_perm)/(torch.sum(weights,dim=1).unsqueeze(1)+_EPS) 53 | src_centered = src - src_mean # [B,N,3] 54 | src_corres_centered = ref_perm - src_corres_mean # [B,N,3] 55 | weight_matrix = torch.diag_embed(weights.squeeze(2)) 56 | 57 | cov_mat = torch.matmul(src_centered.transpose(1,2),torch.matmul(weight_matrix,src_corres_centered)) 58 | try: 59 | u, s, v = torch.svd(cov_mat) 60 | except Exception as e: 61 | r = torch.eye(3).cuda() 62 | r = r.repeat(src_mean.shape[0],1,1) 63 | t = torch.zeros((src_mean.shape[0],3,1)).cuda() 64 | #t = t.view(t.shape[0], 3) 65 | transform = torch.cat((r, t), dim=2) 66 | return transform 67 | 68 | tm_determinant = torch.det(torch.matmul(v.transpose(1,2), u.transpose(1,2))) 69 | 70 | determinant_matrix = torch.diag_embed(torch.cat((torch.ones((tm_determinant.shape[0], 2)).cuda(),tm_determinant.unsqueeze(1)), 1)) 71 | r = torch.matmul(v, torch.matmul(determinant_matrix, u.transpose(1,2))) 72 | t = src_corres_mean.transpose(1,2) - torch.matmul(r, src_mean.transpose(1,2)) 73 | #t = t.view(t.shape[0], 3) 74 | 75 | transform = torch.cat((r, t), dim=2) 76 | return transform 77 | 78 | 79 | def orthogo_tensor(x): 80 | m, n = x.size() 81 | x_np = x.t().numpy() 82 | matrix = [Matrix(col) for col in x_np.T] 83 | gram = GramSchmidt(matrix) 84 | ort_list = [] 85 | for i in range(m): 86 | vector = [] 87 | for j in range(n): 88 | vector.append(float(gram[i][j])) 89 | ort_list.append(vector) 90 | ort_list = np.mat(ort_list) 91 | ort_list = torch.from_numpy(ort_list) 92 | ort_list = F.normalize(ort_list,dim=1) 93 | return ort_list 94 | 95 | if __name__ == '__main__': 96 | src = torch.randn((8, 128, 3)) 97 | R_r = torch.randn((8,3,3)) 98 | for i in range(8): 99 | R_r[i,:,:] = orthogo_tensor(R_r[i,:,:]) 100 | T_r = torch.randn((8,3,1)) 101 | ref = torch.bmm( R_r, src.transpose(1,2) ) + T_r 102 | ref = ref.transpose(1,2).contiguous() 103 | perm = torch.eye( 128 ).reshape((1, 128, 128)).repeat(8, 1, 1) 104 | transform = svd( src, ref, perm ) 105 | for i in range(8): 106 | print('----%d-----' % i) 107 | print( torch.cat((R_r, T_r), dim=2)[i] ) 108 | print( transform[i] ) 109 | 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- 
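weighted_svd above is the solver actually invoked by the registration head in models/model.py, but only the unweighted svd has a self-test under __main__. The following is a minimal usage sketch (not part of the repository) under stated assumptions: uniform per-point weights, an identity soft-permutation, and a pure translation between the two clouds. It must run on a CUDA device, since weighted_svd allocates CUDA tensors internally; the tensor names are illustrative only.

import torch
from models.compute_rigid_transform import weighted_svd

device = torch.device('cuda:0')
B, N = 2, 256
src = torch.randn(B, N, 3, device=device)
t = torch.tensor([1.0, 0.5, -0.2], device=device)
ref = (src + t).contiguous()                                     # translate every point by t

weights = torch.ones(B, N, device=device)                        # uniform per-point weights
perm = torch.eye(N, device=device).unsqueeze(0).repeat(B, 1, 1)  # identity permutation matrix
T = weighted_svd(src, ref, weights, perm)                        # (B, 3, 4) = [R | t] mapping src -> ref
print(T[0])                                                      # R ~ identity, t ~ [1.0, 0.5, -0.2]

This only checks the happy path; during training the weights and the soft permutation come from PermutationWeights in models/model.py instead of being fixed.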
/models/key_point_dectector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | from models.utils import furthest_point_sample, weighted_furthest_point_sample, gather_operation 8 | 9 | 10 | class KeypointDetector(nn.Module): 11 | def __init__(self, nsample, sem_num, sample_type = 'fps'): 12 | super(KeypointDetector, self).__init__() 13 | self.nsample = nsample 14 | self.sample_type = sample_type 15 | self.semantic_classes_num = sem_num 16 | 17 | 18 | def forward( self, xyz, seg_feature, seg_label = None, weights = None ): 19 | # Use FPS or random sampling 20 | B, N, C = xyz.shape 21 | sample_seg_feature, sample_seg_label = None, None 22 | seg_weights = None 23 | if seg_label != None: 24 | sem_one_hot = F.one_hot( seg_label.long(), num_classes=self.semantic_classes_num ) 25 | count = torch.sum( sem_one_hot, dim= 1 ) 26 | seg_weights = torch.gather( count, dim= 1, index= seg_label.long() ).float() 27 | if weights != None: 28 | assert( seg_weights.shape == weights.shape ) 29 | seg_weights = seg_weights * weights 30 | if self.sample_type == 'fps': 31 | # Use WFPS 32 | idx = weighted_furthest_point_sample(xyz, seg_weights, self.nsample) 33 | sampled_xyz = gather_operation(xyz.permute(0,2,1).contiguous(), idx).permute(0,2,1).contiguous() 34 | else: 35 | idx = torch.multinomial( seg_weights, self.nsample ) 36 | sampled_xyz = torch.gather( xyz, dim= 1, index= idx.unsqueeze(-1).repeat(1,1,C).long() ) 37 | else: 38 | if self.sample_type == 'fps': 39 | idx = furthest_point_sample(xyz, self.nsample) 40 | sampled_xyz = gather_operation(xyz.permute(0,2,1).contiguous(), idx).permute(0,2,1).contiguous() 41 | else: 42 | idx = torch.randperm(N)[:self.nsample] 43 | sampled_xyz = xyz[:,idx,:] 44 | idx = idx.unsqueeze(0).repeat(B, 1) 45 | sample_seg_feature = torch.gather( seg_feature, dim=-1, index=idx.unsqueeze(1).repeat(1, 46 | seg_feature.shape[1], 1).long()) 47 | if seg_label != None: 48 | sample_seg_label = torch.gather( seg_label, dim=1, index=idx.long()) 49 | 50 | return sampled_xyz, sample_seg_feature, sample_seg_label 51 | 52 | if __name__ == '__main__': 53 | _device = torch.device("cuda:0") 54 | xyz = Variable( torch.rand((2, 8, 3))) 55 | sem = Variable( torch.randint(0 , 20, (2, 8)) ) 56 | detector = KeypointDetector( 4, 20, sample_type = 'fps' ) 57 | sampled_xyz, idx = detector( xyz, sem ) 58 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | 5 | from typing import List 6 | from common.math_torch import se3 7 | 8 | #model 9 | from models.model import SARNet 10 | 11 | class Loss(nn.Module): 12 | def __init__(self, args: argparse.Namespace): 13 | super(Loss, self).__init__() 14 | self.model = SARNet( args ) 15 | self.trans_loss_type = args.trans_loss_type 16 | self.seg_sigma = nn.Parameter(torch.Tensor(1).uniform_(0.2, 1), requires_grad=True) 17 | self.reg_sigma = nn.Parameter(torch.Tensor(1).uniform_(0.2, 1), requires_grad=True) 18 | 19 | def forward(self, data: dict ): 20 | losses = {} 21 | predict = self.model( data['points_src'], data['points_ref'], 22 | data['seg_src'], data['seg_ref'], 23 | data['intersect_src'], data['intersect_ref'] ) 24 | losses_trans_iter = self.compute_trans_loss( data, predict['pred_transforms'], 
self.trans_loss_type ) 25 | discount_factor = 0.5 # Early iterations will be discounted 26 | iter_num = len( losses_trans_iter ) 27 | for i in range(iter_num): 28 | discount = discount_factor ** ( iter_num-i-1 ) 29 | losses_trans_iter[i] *= discount 30 | losses['trans'] = torch.sum(torch.stack(losses_trans_iter)) 31 | losses['semantic'] = self.compute_semantic_loss( predict ) 32 | factor_reg = 1.0 / (self.reg_sigma**2) 33 | factor_seg = 1.0 / (self.seg_sigma**2) 34 | losses['total'] = factor_reg*losses['trans'] + factor_seg*losses['semantic'] + \ 35 | 2 * torch.log(self.reg_sigma) + 2 * torch.log(self.seg_sigma) 36 | return predict, losses 37 | 38 | 39 | def compute_trans_loss(self, data: dict, pred_transforms: List, 40 | loss_type: str = 'mse', reduction: str= 'mean' ): 41 | # Compute losses 42 | losses = [] 43 | iter_num = len( pred_transforms ) 44 | gt_src_transformed = se3.transform(data['transform_gt'], data['points_src'][..., :3]) 45 | if loss_type == 'mse': 46 | # MSE loss to the groundtruth (does not take into account possible symmetries) 47 | criterion = nn.MSELoss( reduction= reduction ) 48 | for i in range( iter_num ): 49 | pred_src_transformed = se3.transform( pred_transforms[i], data['points_src'][..., :3] ) 50 | losses.append(criterion(pred_src_transformed, gt_src_transformed)) 51 | elif loss_type == 'mae': 52 | # MAE loss to the groundtruth (does not take into account possible symmetries) 53 | criterion = nn.L1Loss( reduction= reduction ) 54 | for i in range( iter_num ): 55 | pred_src_transformed = se3.transform( pred_transforms[i], data['points_src'][..., :3] ) 56 | losses.append(criterion(pred_src_transformed, gt_src_transformed)) 57 | else: 58 | raise NotImplementedError 59 | return losses 60 | 61 | def compute_semantic_loss(self, predict: dict ): 62 | criterion = nn.CrossEntropyLoss() 63 | seg_src = predict['seg_src'] 64 | seg_ref = predict['seg_ref'] 65 | sem_label_src = predict['seg_label_src'] 66 | sem_label_ref = predict['seg_label_ref'] 67 | loss = criterion( seg_src, sem_label_src.long() ) + criterion( seg_ref, sem_label_ref.long() ) 68 | return loss 69 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from common.math_torch import se3 7 | 8 | #from models.socnn import SOCNN 9 | from models.RandLA_Net import RandLANet 10 | from models.key_point_dectector import KeypointDetector 11 | from models.semanticCNN import SemanticCNN 12 | from models.attention import AttentionalPropagation 13 | from models.compute_rigid_transform import weighted_svd 14 | 15 | class PermutationWeights(nn.Module): 16 | def __init__( self, in_dim ): 17 | super( PermutationWeights, self ).__init__() 18 | self.conv = nn.Sequential(nn.Conv1d(in_dim*2, in_dim*2, kernel_size=1), 19 | nn.BatchNorm1d(in_dim*2), 20 | nn.ReLU(), 21 | nn.Conv1d(in_dim*2, in_dim, kernel_size=1), 22 | nn.BatchNorm1d(in_dim), 23 | nn.ReLU(), 24 | nn.Conv1d(in_dim, 1, kernel_size=1) 25 | ) 26 | 27 | def forward( self, features_src: torch.Tensor, featrues_ref: torch.Tensor, 28 | part_match_matrix: torch.Tensor = None ): 29 | d_k = features_src.size(1) 30 | scores = torch.matmul( features_src.transpose(1,2), featrues_ref )/math.sqrt(d_k) 31 | 32 | if part_match_matrix != None: 33 | assert( scores.shape == part_match_matrix.shape ) 34 | scores[part_match_matrix==0] = -1e10 35 | 
permutation = torch.softmax( scores, dim=-1 ) 36 | featrues_ref_perm = torch.bmm( permutation, featrues_ref.transpose(1,2) ).transpose(1,2).contiguous() 37 | weights = self.conv(torch.cat( (features_src, featrues_ref_perm), dim=1 )) 38 | weights_s = torch.sigmoid( weights.squeeze(1) ) 39 | return permutation, weights_s 40 | 41 | class SARNet(nn.Module): 42 | def __init__(self, args: argparse.Namespace): 43 | super( SARNet, self ).__init__() 44 | self.iter_num = args.iter_num 45 | self.classes_num = args.semantic_classes_num 46 | self.segcnn = RandLANet( 47 | args.init_dims, 48 | args.semantic_classes_num, 49 | num_neighbors=args.nb 50 | ) 51 | self.detector = KeypointDetector( args.nsample, args.semantic_classes_num ) 52 | self.regcnn = nn.ModuleList() 53 | self.attention_fea= nn.ModuleList() 54 | self.semconv, self.conv = nn.ModuleList(), nn.ModuleList() 55 | self.perm_weights = nn.ModuleList() 56 | for i in range(self.iter_num): 57 | self.regcnn += [SemanticCNN( args.init_dims, args.emb_dims, args.nb )] 58 | self.attention_fea += [AttentionalPropagation( args.emb_dims, args.attention_head_num )] 59 | 60 | self.conv += [nn.Sequential( 61 | nn.Conv1d( args.emb_dims + args.semantic_classes_num, args.emb_dims, kernel_size=1 ), 62 | nn.BatchNorm1d( args.emb_dims ), 63 | nn.ReLU(), 64 | nn.Conv1d( args.emb_dims, args.emb_dims, kernel_size=1 ) 65 | )] 66 | 67 | self.perm_weights += [PermutationWeights( args.emb_dims )] 68 | 69 | def forward( self, *input ): 70 | points_src, points_ref = input[0], input[1] 71 | seg_label_src, seg_label_ref = None, None 72 | weights_src, weights_ref = None, None 73 | if len(input) == 4: 74 | seg_label_src, seg_label_ref = input[2], input[3] 75 | if len(input) == 6: 76 | seg_label_src, seg_label_ref = input[2], input[3] 77 | weights_src, weights_ref = input[4], input[5] 78 | 79 | 80 | 81 | seg_src = self.segcnn( points_src ) 82 | seg_ref = self.segcnn( points_ref ) 83 | seg_src = torch.distributions.utils.probs_to_logits(seg_src, is_binary=False) 84 | seg_ref = torch.distributions.utils.probs_to_logits(seg_ref, is_binary=False) 85 | seg_src_detach = seg_src.detach() 86 | seg_ref_detach = seg_ref.detach() 87 | points_src, sample_seg_src, sample_seg_label_src = self.detector( 88 | points_src, seg_src_detach, seg_label = seg_label_src, weights = weights_src ) 89 | points_ref, sample_seg_ref, sample_seg_label_ref = self.detector( 90 | points_ref, seg_ref_detach, seg_label = seg_label_ref, weights = weights_ref ) 91 | onehot_src, onehot_ref = None, None 92 | if sample_seg_label_src != None and sample_seg_label_ref != None: 93 | onehot_src = F.one_hot( sample_seg_label_src.long(), num_classes = self.classes_num ).float() 94 | onehot_ref = F.one_hot( sample_seg_label_ref.long(), num_classes = self.classes_num ).float() 95 | else: 96 | onehot_src = F.one_hot( sample_seg_src.max(dim=1)[1].long(), num_classes = self.classes_num ).float() 97 | onehot_ref = F.one_hot( sample_seg_ref.max(dim=1)[1].long(), num_classes = self.classes_num ).float() 98 | 99 | part_match_matrix = torch.bmm( onehot_src, onehot_ref.transpose(1,2) ).int() 100 | onehot_src = onehot_src.transpose(1,2).contiguous() 101 | onehot_ref = onehot_ref.transpose(1,2).contiguous() 102 | 103 | points_src_t = points_src 104 | transforms = [None]*self.iter_num 105 | for i in range( self.iter_num ): 106 | features_src = self.regcnn[i]( points_src_t ) 107 | features_ref = self.regcnn[i]( points_ref ) 108 | features_src = features_src + self.attention_fea[i]( features_src, features_ref ) 109 | features_ref = 
features_ref + self.attention_fea[i]( features_ref, features_src ) 110 | 111 | features_seg_src = self.conv[i](torch.cat(( features_src, onehot_src ), dim=1 )) 112 | features_seg_ref = self.conv[i](torch.cat(( features_ref, onehot_ref ), dim=1 )) 113 | 114 | permutation, weights = self.perm_weights[i]( features_seg_src, features_seg_ref, part_match_matrix ) 115 | transform = weighted_svd( points_src, points_ref, weights, permutation ) 116 | points_src_t = se3.transform( transform.detach(), points_src_t ) 117 | transforms[i] = transform 118 | 119 | predict = { 'pred_transforms': transforms, 120 | 'seg_src': seg_src, 121 | 'seg_ref': seg_ref, 122 | 'seg_label_src': seg_label_src, 123 | 'seg_label_ref': seg_label_ref 124 | } 125 | return predict 126 | 127 | if __name__ == '__main__': 128 | pass 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /models/semanticCNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | from torch.autograd import Variable 7 | 8 | 9 | def nearest_neighbor(src, dst): 10 | inner = -2 * torch.matmul(src.transpose(1, 0).contiguous(), dst) # src, dst (num_dims, num_points) 11 | distances = -torch.sum(src ** 2, dim=0, keepdim=True).transpose(1, 0).contiguous() - inner - torch.sum(dst ** 2, 12 | dim=0, 13 | keepdim=True) 14 | distances, indices = distances.topk(k=1, dim=-1) 15 | return distances, indices 16 | 17 | 18 | def knn(x, k): 19 | inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x) 20 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 21 | pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous() 22 | 23 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 24 | return idx 25 | 26 | def get_graph_feature(data, k=20): 27 | xyz = data 28 | # x = x.squeeze() 29 | idx = knn(xyz, k=k) # (batch_size, num_points, k) 30 | batch_size, num_points, _ = idx.size() 31 | # device = torch.device('cuda') 32 | 33 | idx_base = torch.arange(0, batch_size).to(xyz.device).view(-1, 1, 1) * num_points 34 | 35 | idx = idx + idx_base 36 | 37 | idx = idx.view(-1) 38 | 39 | _, num_dims, _ = xyz.size() 40 | 41 | xyz = xyz.transpose(2, 1).contiguous() 42 | # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) 43 | # batch_size * num_points * k + range(0, batch_size*num_points) 44 | 45 | # gxyz 46 | neighbor_gxyz = xyz.view(batch_size * num_points, -1)[idx, :] 47 | neighbor_gxyz = neighbor_gxyz.view(batch_size, num_points, k, num_dims) 48 | # xyz 49 | xyz = xyz.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 50 | #lxyz_norm 51 | neighbor_lxyz_norm = torch.norm(neighbor_gxyz - xyz, dim=3, keepdim=True) 52 | 53 | feature = torch.cat((xyz, neighbor_gxyz, neighbor_lxyz_norm), dim=3) 54 | 55 | feature = feature.permute(0, 3, 1, 2).contiguous() 56 | 57 | return feature 58 | 59 | 60 | class SemanticConv(nn.Module): 61 | def __init__(self, in_dim, out_dim, neighboursnum=16): 62 | super(SemanticConv, self).__init__() 63 | # 确定输入的点云信息 64 | self.neighboursnum = neighboursnum 65 | self.in_dim = in_dim 66 | 67 | self.localConv = nn.Sequential( 68 | nn.Conv2d(in_dim*2+1, out_dim, kernel_size=1, bias=False ), 69 | nn.BatchNorm2d(out_dim), 70 | nn.LeakyReLU() 71 | ) 72 | self.semConv = nn.Sequential( 73 | nn.Conv1d(in_dim, out_dim, kernel_size=1, bias=False ), 74 | nn.BatchNorm1d(out_dim), 75 | nn.LeakyReLU() 76 | ) 77 | 
self.fullConv = nn.Sequential( 78 | nn.Conv1d(out_dim*2, out_dim, kernel_size=1, bias=False ), 79 | nn.BatchNorm1d(out_dim), 80 | nn.LeakyReLU() 81 | ) 82 | self.semAtt = nn.Conv1d(in_dim, in_dim, kernel_size=1) 83 | self.proj = nn.ModuleList([deepcopy(self.semAtt) for _ in range(3)]) 84 | #self.conv_a = nn.Conv1d( in_dim, in_dim, kernel_size=1 ) 85 | 86 | 87 | def forward( self, f_in ): # f_in:(B, C, N) 88 | neighbor_f_in = get_graph_feature(f_in, self.neighboursnum) # (B, C, N, n) 89 | Intra_channal = self.localConv( neighbor_f_in ) 90 | Intra_channal = Intra_channal.max(dim=-1, keepdim=False)[0] 91 | q, k, v = [ l(f_in) for l in self.proj ] 92 | scores = torch.einsum('bdm,bdn->bmn', q, k) / self.in_dim**.5 93 | scores = torch.softmax(scores, dim=-1) 94 | fgt = torch.einsum('bmn,bdn->bdm', scores, v) 95 | #a = (self.conv_a( fgt ) + neighbor_mean)/2 96 | Inter_channal = self.semConv(fgt) 97 | feature = self.fullConv( torch.cat( (Intra_channal, Inter_channal), dim= 1 ) ) 98 | return feature #, Inter_channal 99 | 100 | 101 | class SemanticCNN(nn.Module): 102 | def __init__(self, raw_dim, emb_dim, neighboursnum=16): 103 | super(SemanticCNN, self).__init__() 104 | self.conv1 = nn.Sequential( 105 | nn.Conv1d( raw_dim, emb_dim//16, kernel_size=1 ), 106 | nn.BatchNorm1d(emb_dim//16), 107 | nn.ReLU() 108 | ) 109 | self.conv2 = nn.Sequential( 110 | nn.Conv1d( emb_dim, emb_dim, kernel_size=1 ), 111 | nn.BatchNorm1d(emb_dim), 112 | nn.ReLU(), 113 | nn.Conv1d( emb_dim, emb_dim, kernel_size=1 ) 114 | ) 115 | #self.conv3 = nn.Sequential( 116 | # nn.Conv1d( emb_dim*2, emb_dim, kernel_size=1 ), 117 | # nn.BatchNorm1d(emb_dim), 118 | # nn.ReLU(), 119 | # nn.Conv1d( emb_dim, emb_dim, kernel_size=1 ) 120 | #) 121 | self.sem1 = SemanticConv( emb_dim//16, emb_dim//16, neighboursnum ) 122 | self.sem2 = SemanticConv( emb_dim//16, emb_dim//8, neighboursnum ) 123 | self.sem3 = SemanticConv( emb_dim//8, emb_dim//4, neighboursnum ) 124 | self.sem4 = SemanticConv( emb_dim//4, emb_dim//2, neighboursnum ) 125 | 126 | def forward(self, xyz): 127 | xyz = xyz.permute(0, 2, 1).contiguous() #(B, 3, N) 128 | #points_num = xyz.shape[2] 129 | x0 = self.conv1( xyz ) 130 | x1 = self.sem1(x0) 131 | x2 = self.sem2(x1) 132 | x3 = self.sem3(x2) 133 | x4 = self.sem4(x3) 134 | 135 | x = torch.cat((x0, x1, x2, x3, x4), dim=1) 136 | #gx = torch.cat((x0, gx1, gx2, gx3, gx4), dim=1) 137 | #gx_m = torch.max( gx, dim=2, keepdim=True )[0] 138 | #gx_f = torch.cat((gx_m.repeat( 1, 1, points_num), gx), dim = 1) 139 | feature = self.conv2( x ) 140 | #semantic = self.conv3( gx_f ) 141 | return feature 142 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from torch.autograd import Function 5 | import torch.nn.functional as F 6 | 7 | import point_utils_cuda 8 | from pytorch3d.loss import chamfer_distance 9 | from pytorch3d.ops import knn_points, knn_gather 10 | from scipy.spatial.transform import Rotation 11 | import random 12 | 13 | class FurthestPointSampling(Function): 14 | @staticmethod 15 | def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor: 16 | ''' 17 | ctx: 18 | xyz: [B,N,3] 19 | npoint: int 20 | ''' 21 | assert xyz.is_contiguous() 22 | 23 | B, N, _ = xyz.size() 24 | output = torch.cuda.IntTensor(B, npoint) 25 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 26 | 27 | 
point_utils_cuda.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output) 28 | return output 29 | 30 | @staticmethod 31 | def backward(xyz, a=None): 32 | return None, None 33 | 34 | furthest_point_sample = FurthestPointSampling.apply 35 | 36 | class WeightedFurthestPointSampling(Function): 37 | @staticmethod 38 | def forward(ctx, xyz: torch.Tensor, weights: torch.Tensor, npoint: int) -> torch.Tensor: 39 | ''' 40 | ctx: 41 | xyz: [B,N,3] 42 | weights: [B,N] 43 | npoint: int 44 | ''' 45 | assert xyz.is_contiguous() 46 | assert weights.is_contiguous() 47 | B, N, _ = xyz.size() 48 | output = torch.cuda.IntTensor(B, npoint) 49 | temp = torch.cuda.FloatTensor(B, N).fill_(1e10) 50 | 51 | point_utils_cuda.weighted_furthest_point_sampling_wrapper(B, N, npoint, xyz, weights, temp, output); 52 | return output 53 | 54 | @staticmethod 55 | def backward(xyz, a=None): 56 | return None, None 57 | 58 | weighted_furthest_point_sample = WeightedFurthestPointSampling.apply 59 | 60 | class GatherOperation(Function): 61 | @staticmethod 62 | def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: 63 | ''' 64 | ctx 65 | features: [B,C,N] 66 | idx: [B,npoint] 67 | ''' 68 | assert features.is_contiguous() 69 | assert idx.is_contiguous() 70 | 71 | B, npoint = idx.size() 72 | _, C, N = features.size() 73 | output = torch.cuda.FloatTensor(B, C, npoint) 74 | 75 | point_utils_cuda.gather_points_wrapper(B, C, N, npoint, features, idx, output) 76 | 77 | ctx.for_backwards = (idx, C, N) 78 | return output 79 | 80 | @staticmethod 81 | def backward(ctx, grad_out): 82 | idx, C, N = ctx.for_backwards 83 | B, npoint = idx.size() 84 | grad_features = Variable(torch.cuda.FloatTensor(B,C,N).zero_()) 85 | grad_out_data = grad_out.data.contiguous() 86 | point_utils_cuda.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data) 87 | return grad_features, None 88 | 89 | gather_operation = GatherOperation.apply 90 | 91 | def generate_rand_rotm(x_lim=5.0, y_lim=5.0, z_lim=180.0): 92 | ''' 93 | Input: 94 | x_lim 95 | y_lim 96 | z_lim 97 | return: 98 | rotm: [3,3] 99 | ''' 100 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 101 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 102 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 103 | 104 | rand_eul = np.array([rand_z, rand_y, rand_x]) 105 | r = Rotation.from_euler('zyx', rand_eul, degrees=True) 106 | rotm = r.as_matrix() 107 | return rotm 108 | 109 | def generate_rand_trans(x_lim=10.0, y_lim=1.0, z_lim=0.1): 110 | ''' 111 | Input: 112 | x_lim 113 | y_lim 114 | z_lim 115 | return: 116 | trans [3] 117 | ''' 118 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 119 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 120 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 121 | 122 | rand_trans = np.array([rand_x, rand_y, rand_z]) 123 | 124 | return rand_trans 125 | 126 | def apply_transform(pts, trans): 127 | R = trans[:3, :3] 128 | T = trans[:3, 3] 129 | pts = pts @ R.T + T 130 | return pts 131 | 132 | def calc_error_np(pred_R, pred_t, gt_R, gt_t): 133 | tmp = (np.trace(pred_R.transpose().dot(gt_R))-1)/2 134 | if np.abs(tmp) > 1.0: 135 | tmp = 1.0 136 | L_rot = np.arccos(tmp) 137 | L_rot = 180 * L_rot / np.pi 138 | L_trans = np.linalg.norm(pred_t - gt_t) 139 | return L_rot, L_trans 140 | 141 | def set_seed(seed): 142 | ''' 143 | Set random seed for torch, numpy and python 144 | ''' 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | torch.manual_seed(seed) 148 | if torch.cuda.is_available(): 149 | 
torch.cuda.manual_seed(seed) 150 | torch.cuda.manual_seed_all(seed) 151 | 152 | torch.backends.cudnn.benchmark=False 153 | torch.backends.cudnn.deterministic=True -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | coloredlogs==15.0 2 | GitPython==3.1.27 3 | numpy==1.19.2 4 | nuscenes_devkit==1.1.9 5 | open3d==0.14.1 6 | pandas==1.2.4 7 | pytorch3d==0.5.0 8 | PyYAML==6.0 9 | scipy==1.5.4 10 | sympy==1.8 11 | tensorboardX==2.5.1 12 | torch_points_kernels==0.7.0 13 | torchvision==0.9.0 14 | tqdm==4.56.0 15 | -------------------------------------------------------------------------------- /script/test_kitti.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python ../test.py; -------------------------------------------------------------------------------- /script/test_nuscenes.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python ../test.py --dataset NuScenes --semantic_classes_num 14 --sample_voxel_size 0.3 \ 2 | --sample_point_num 8000 --boundingbox_diagonal 80; -------------------------------------------------------------------------------- /script/train_kitti.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1,2 python -m torch.distributed.launch --nproc_per_node=2 --master_port=12353 ../train.py; 2 | 3 | -------------------------------------------------------------------------------- /script/train_nuscenes.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port=12363 \ 2 | ../train.py --dataset NuScenes --semantic_classes_num 14 --sample_voxel_size 0.3 \ 3 | --sample_point_num 8000 --boundingbox_diagonal 80; 4 | CUDA_VISIBLE_DEVICES=1 python ../test.py --dataset NuScenes --semantic_classes_num 14 --sample_voxel_size 0.3 \ 5 | --sample_point_num 8000 --boundingbox_diagonal 80; -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import json 6 | import pandas as pd 7 | import time 8 | 9 | import torch 10 | import torchvision 11 | import torch.utils.data 12 | 13 | 14 | #ArgumentParser 15 | from arguments import test_arguments 16 | 17 | #dataloader 18 | from dataloader.transforms import get_transforms 19 | from dataloader.NuScenesDataLoader import NuScenesDataSet 20 | from dataloader.SemanticKITTYDataLoader import KittiDataset 21 | 22 | #model 23 | from models.loss import Loss 24 | 25 | #metrics 26 | from metrics import compute_metrics, summarize_metrics, print_metrics 27 | 28 | #common 29 | from tqdm import tqdm 30 | from common.torch import CheckPointManager, dict_all_to_device, to_numpy 31 | from common.math_torch import se3 32 | from common.misc import prepare_logger 33 | 34 | parser = test_arguments() 35 | _args = parser.parse_args() 36 | _device = torch.device("cuda:0") 37 | 38 | 39 | def test( pred_transforms, data_loader: torch.utils.data.dataloader.DataLoader): 40 | """ Test the computed transforms against the groundtruth 41 | 42 | Args: 43 | pred_transforms: Predicted transforms (N, B, 3/4, 4) 44 | data_loader: Loader for 
dataset. 45 | 46 | Returns: 47 | Computed metrics (List of dicts), and summary metrics (only for last iter) 48 | """ 49 | 50 | _logger.info('Testing transforms...') 51 | all_metrics_np = defaultdict(list) 52 | 53 | num_processed = 0 54 | for data in tqdm(data_loader, leave=False): 55 | dict_all_to_device(data, _device) 56 | batch_size = data['points_src'].shape[0] 57 | metrics = compute_metrics(data, pred_transforms[num_processed:num_processed+batch_size, :, :] ) 58 | num_processed += batch_size 59 | for k in metrics: 60 | all_metrics_np[k].append( metrics[k] ) 61 | all_metrics_np = {k: np.concatenate(all_metrics_np[k]) for k in all_metrics_np} 62 | summary_metrics = summarize_metrics(all_metrics_np, _args.RRE_thresholds, _args.RTE_thresholds ) 63 | print_metrics(_logger, summary_metrics, title='Evaluation result') 64 | 65 | return all_metrics_np, summary_metrics 66 | 67 | 68 | def inference(data_loader, model: torch.nn.Module): 69 | """Runs inference over entire dataset 70 | 71 | Args: 72 | data_loader (torch.utils.data.DataLoader): Dataset loader 73 | model (model.nn.Module): Network model to evaluate 74 | 75 | Returns: 76 | pred_transforms_all: predicted transforms (N, B, 3, 4) where N is total number of instances 77 | endpoints_out (Dict): Network endpoints 78 | """ 79 | 80 | _logger.info('Starting inference...') 81 | model.eval() 82 | 83 | pred_transforms_all = [] 84 | total_time = 0.0 85 | total_rotation = [] 86 | 87 | with torch.no_grad(): 88 | 89 | for test_data in tqdm(data_loader): 90 | rot_trace = test_data['transform_gt'][:, 0, 0] + test_data['transform_gt'][:, 1, 1] + \ 91 | test_data['transform_gt'][:, 2, 2] 92 | rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi 93 | total_rotation.append(np.abs(to_numpy(rotdeg))) 94 | 95 | dict_all_to_device(test_data, _device) 96 | time_before = time.time() 97 | pred = model( test_data['points_src'], test_data['points_ref'], 98 | test_data['seg_src'], test_data['seg_ref'], 99 | test_data['intersect_src'], test_data['intersect_ref'] ) 100 | pred_transforms = pred['pred_transforms'] 101 | total_time += time.time() - time_before 102 | pred_transforms_all.append(to_numpy(pred_transforms[-1])) 103 | 104 | _logger.info('Total inference time: {}s'.format(total_time)) 105 | total_rotation = np.concatenate(total_rotation, axis=0) 106 | _logger.info('Rotation range in data: {}(avg), {}(max)'.format(np.mean(total_rotation), np.max(total_rotation))) 107 | pred_transforms_all = np.concatenate(pred_transforms_all, axis=0) 108 | 109 | return pred_transforms_all 110 | 111 | def save_eval_data(pred_transforms, metrics, summary_metrics, save_path): 112 | """Saves out the computed transforms 113 | """ 114 | 115 | # Save transforms 116 | np.save(os.path.join(save_path, 'pred_transforms.npy'), pred_transforms) 117 | 118 | # Save metrics 119 | writer = pd.ExcelWriter(os.path.join(save_path, 'metrics.xlsx')) 120 | metrics.pop('err_r_deg') 121 | metrics.pop('err_t') 122 | metrics_df = pd.DataFrame.from_dict(metrics) 123 | metrics_df.to_excel(writer, sheet_name='metrics') 124 | writer.close() 125 | 126 | # Save summary metrics 127 | summary_metrics_float = {k: float(summary_metrics[k]) for k in summary_metrics} 128 | with open(os.path.join(save_path, 'summary_metrics.json'), 'w') as json_out: 129 | json.dump(summary_metrics_float, json_out) 130 | 131 | _logger.info('Saved evaluation results to {}'.format(save_path)) 132 | 133 | def get_model(): 134 | criteria = Loss( _args ) 135 | criteria.to( _device ) 136 | save_path = 
os.path.join( _args.checkpoints_path, 'ckpt') 137 | saver = CheckPointManager( save_path ) 138 | load_path = os.path.join( _args.checkpoints_path, 'ckpt-best.pth' ) 139 | global_step = 0 140 | if os.path.exists(load_path): 141 | global_step = saver.load( load_path, criteria, distributed=True ) 142 | print( "global_step:", global_step ) 143 | model = criteria.model 144 | model.eval() 145 | return model 146 | 147 | 148 | def main(): 149 | #dataloader 150 | test_set, test_loader = None, None 151 | _, test_trainsform = get_transforms( noise_type = _args.noise_type, 152 | rot_mag = _args.rot_mag, trans_mag = _args.trans_mag, voxel_size= _args.sample_voxel_size, 153 | num = _args.sample_point_num, diagonal= _args.boundingbox_diagonal, partial_p_keep = _args.partial_p_keep 154 | ) 155 | _logger.info('Test transforms: {}'.format(', '.join([type(t).__name__ for t in test_trainsform]))) 156 | test_trainsform = torchvision.transforms.Compose(test_trainsform) 157 | if _args.dataset == 'NuScenes': 158 | test_set = NuScenesDataSet( root = _args.nuscenes_root, split='test', 159 | transform = test_trainsform, ignore_label= _args.nuscenes_ignore_label, augment= _args.augment ) 160 | test_loader = torch.utils.data.DataLoader( test_set, batch_size=_args.test_batch_size, shuffle=False, num_workers=_args.num_workers) 161 | elif _args.dataset == 'SemanticKitti': 162 | test_set = KittiDataset( root = _args.kitty_root, split='test', 163 | transform = test_trainsform, ignore_label= _args.kitti_ignore_label, augment= _args.augment ) 164 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=_args.test_batch_size, shuffle=False, num_workers= _args.num_workers) 165 | 166 | #model 167 | if _args.transform_file is not None: 168 | _logger.info('Loading from precomputed transforms: {}'.format(_args.transform_file)) 169 | pred_transforms = np.load(_args.transform_file) 170 | else: 171 | model = get_model() 172 | pred_transforms = inference(test_loader, model) # Feedforward transforms 173 | 174 | eval_metrics, summary_metrics = test( torch.from_numpy(pred_transforms).to(_device), data_loader=test_loader ) 175 | save_eval_data( pred_transforms, eval_metrics, summary_metrics, _args.eval_save_path ) 176 | 177 | if __name__ == '__main__': 178 | _logger, _log_path = prepare_logger(_args, log_path=_args.eval_save_path ) 179 | #_args.transform_file = osp.join( _args.eval_save_path, 'pred_transforms.npy') 180 | main() 181 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import os.path as osp 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch._C import device 8 | import torchvision 9 | from tensorboardX import SummaryWriter 10 | 11 | #ArgumentParser 12 | from arguments import train_arguments 13 | 14 | #dataloader 15 | from dataloader.transforms import get_transforms 16 | from dataloader.NuScenesDataLoader import NuScenesDataSet 17 | from dataloader.SemanticKITTYDataLoader import KittiDataset 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | #loss 21 | from models.loss import Loss 22 | 23 | #common 24 | from common.torch import CheckPointManager, dict_all_to_device, to_numpy 25 | from common.misc import prepare_logger 26 | 27 | #metrics 28 | from metrics import compute_metrics, summarize_metrics, print_metrics 29 | 30 | #others 31 | from tqdm import tqdm 32 | from typing import Dict 33 | 
from collections import defaultdict 34 | 35 | 36 | parser = train_arguments() 37 | _args = parser.parse_args() 38 | #initialize 39 | torch.distributed.init_process_group(backend="nccl", world_size=2) 40 | #get gpu 41 | _local_rank = torch.distributed.get_rank() 42 | torch.cuda.set_device(_local_rank) 43 | _device = torch.device("cuda", _local_rank) 44 | 45 | def reduce_tensor(tensor: torch.Tensor): 46 | rt = tensor.clone() 47 | torch.distributed.all_reduce( rt, op=torch.distributed.ReduceOp.SUM) 48 | rt /= torch.distributed.get_world_size() 49 | return rt 50 | 51 | def validate_gradient(model): 52 | """ 53 | Confirm all the gradients are non-nan and non-inf 54 | """ 55 | for name, param in model.named_parameters(): 56 | if param.grad is not None: 57 | if torch.any(torch.isnan(param.grad)): 58 | return False 59 | if torch.any(torch.isinf(param.grad)): 60 | return False 61 | return True 62 | 63 | 64 | def save_summaries( writer: SummaryWriter, losses: Dict = None, metrics: Dict = None, step: int = 0): 65 | """Save tensorboard summaries""" 66 | with torch.no_grad(): 67 | if losses is not None: 68 | for l in losses: 69 | writer.add_scalar( 'losses/{}'.format(l), losses[l], step ) 70 | if metrics is not None: 71 | for m in metrics: 72 | writer.add_scalar( 'metrics/{}'.format(m), metrics[m], step ) 73 | writer.flush() 74 | 75 | 76 | def validate( data_loader, criteria: nn.Module, writer: SummaryWriter, step: int = 0 ): 77 | """Perform a single validation run""" 78 | 79 | with torch.no_grad(): 80 | val_losses_np = defaultdict(list) 81 | val_metrics_np = defaultdict(list) 82 | for data in data_loader: 83 | dict_all_to_device( data, _device ) 84 | predict, losses = criteria( data ) 85 | metrics = compute_metrics( data, predict['pred_transforms'][-1] ) 86 | for k in metrics: 87 | val_metrics_np[k].append( metrics[k] ) 88 | for k in losses: 89 | val_losses_np[k].append( to_numpy(losses[k]) ) 90 | val_losses_np = { k : np.mean( val_losses_np[k] ) for k in val_losses_np } 91 | val_metrics_np = { k : np.concatenate( val_metrics_np[k] ) for k in val_metrics_np } 92 | summary_metrics = summarize_metrics( val_metrics_np, _args.RRE_thresholds, _args.RTE_thresholds ) 93 | print_metrics( _logger, summary_metrics ) 94 | 95 | score = -val_losses_np['trans'] 96 | 97 | save_summaries( writer, val_losses_np, summary_metrics, step ) 98 | return score 99 | 100 | def main(): 101 | #dataloader 102 | train_loader, val_loader = None, None 103 | train_transform, val_trainsform = get_transforms( noise_type = _args.noise_type, 104 | rot_mag = _args.rot_mag, trans_mag = _args.trans_mag, voxel_size= _args.sample_voxel_size, 105 | num = _args.sample_point_num, diagonal= _args.boundingbox_diagonal, partial_p_keep = _args.partial_p_keep 106 | ) 107 | train_transform = torchvision.transforms.Compose( train_transform ) 108 | val_trainsform = torchvision.transforms.Compose( val_trainsform ) 109 | if _args.dataset == 'NuScenes': 110 | train_set = NuScenesDataSet( root = _args.nuscenes_root, split='train', 111 | transform = train_transform, ignore_label= _args.nuscenes_ignore_label, augment= _args.augment ) 112 | train_loader = torch.utils.data.DataLoader( train_set, batch_size=_args.train_batch_size, num_workers= _args.num_workers, sampler=DistributedSampler(train_set) ) 113 | 114 | val_set = NuScenesDataSet( root = _args.nuscenes_root, split='val', 115 | transform = val_trainsform, ignore_label= _args.nuscenes_ignore_label, augment= _args.augment ) 116 | val_loader = torch.utils.data.DataLoader( val_set, 
batch_size=_args.val_batch_size, num_workers=_args.num_workers ) 117 | elif _args.dataset == 'SemanticKitti': 118 | train_set = KittiDataset( root = _args.kitty_root, split='train', 119 | transform = train_transform, ignore_label= _args.kitti_ignore_label, augment= _args.augment ) 120 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=_args.train_batch_size, num_workers= _args.num_workers, sampler=DistributedSampler(train_set) ) 121 | 122 | val_set = KittiDataset( root = _args.kitty_root, split='val', 123 | transform = val_trainsform, ignore_label= _args.kitti_ignore_label, augment= _args.augment 124 | ) 125 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=_args.train_batch_size, num_workers= _args.num_workers) #, sampler=DistributedSampler(val_set) 126 | 127 | 128 | #SummaryWriter 129 | if _local_rank == 0: 130 | train_writer = SummaryWriter(osp.join(_log_path, 'train'), flush_secs=60) 131 | val_writer = SummaryWriter(osp.join(_log_path, 'val'), flush_secs=60) 132 | 133 | #model 134 | criteria = Loss( _args ) 135 | criteria.to( _device ) 136 | 137 | if torch.cuda.device_count() > 1: 138 | criteria = torch.nn.parallel.DistributedDataParallel(criteria, 139 | device_ids=[_local_rank], 140 | output_device=_local_rank, 141 | find_unused_parameters=True) 142 | 143 | #optimizer 144 | optimizer = torch.optim.Adam( criteria.parameters(), lr= _args.lr ) 145 | scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=_args.scheduler_step_size, gamma=_args.scheduler_gamma ) 146 | 147 | #checkpoints 148 | global_step = 0 149 | saver = CheckPointManager( _args.save_checkpoints_path, max_to_keep = 1, keep_checkpoint_every_n_hours = 0.1 ) 150 | 151 | if osp.exists( _args.load_checkpoints_path ): 152 | global_step = saver.load( _args.load_checkpoints_path, criteria, optimizer ) 153 | 154 | if _local_rank == 0: 155 | steps_per_epoch = len(train_loader) 156 | if _args.validate_every < 0: 157 | _args.validate_every = abs(_args.validate_every) * steps_per_epoch 158 | if _args.summary_every < 0: 159 | _args.summary_every = abs(_args.summary_every) * steps_per_epoch 160 | 161 | #model training 162 | criteria.train() 163 | for epoch in range(_args.epoch_num): 164 | if _local_rank == 0: 165 | tbar = tqdm(total=len(train_loader), ncols=100) 166 | for data in train_loader: 167 | optimizer.zero_grad() 168 | dict_all_to_device( data, _device ) 169 | 170 | _, losses = criteria( data ) 171 | losses['total'].backward() 172 | if validate_gradient( criteria ): 173 | optimizer.step() 174 | else: 175 | print("gradient not valid") 176 | 177 | global_step += 1 178 | avg_total_loss = reduce_tensor(losses['total']).item() 179 | if _local_rank == 0: 180 | tbar.set_description('Epoch:{:.4g} Loss:{:.4g}'.format(epoch, avg_total_loss )) 181 | tbar.update(1) 182 | 183 | if global_step % _args.validate_every == 0: 184 | criteria.eval() 185 | score = validate( val_loader, criteria, val_writer, global_step ) 186 | saver.save( criteria, optimizer, global_step, score = score ) 187 | criteria.train() 188 | if global_step % _args.summary_every == 0: 189 | save_summaries( train_writer, losses, step = global_step ) 190 | scheduler.step() 191 | if _local_rank == 0: 192 | tbar.close() 193 | 194 | if __name__ == '__main__': 195 | if _local_rank == 0: 196 | _logger, _log_path = prepare_logger(_args) 197 | main() 198 | 199 | --------------------------------------------------------------------------------
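As a quick sanity check of the compiled PointUtils extension, the sketch below (not part of the repository) downsamples a random cloud with furthest point sampling and gathers the selected coordinates, mirroring how KeypointDetector uses these wrappers in models/key_point_dectector.py. It assumes the extension has been built from models/PointUtils (e.g. via its setup.py) and that a GPU is available; the tensor names and sizes are illustrative.

import torch
from models.utils import furthest_point_sample, gather_operation

device = torch.device('cuda:0')
xyz = torch.rand(2, 4096, 3, device=device).contiguous()             # (B, N, 3) point cloud
idx = furthest_point_sample(xyz, 1024)                               # (B, 1024) int32 sample indices
# gather_operation expects channel-first features, so permute to (B, 3, N) first
sampled = gather_operation(xyz.permute(0, 2, 1).contiguous(), idx)   # (B, 3, 1024)
sampled = sampled.permute(0, 2, 1).contiguous()                      # back to (B, 1024, 3)
print(sampled.shape)                                                 # torch.Size([2, 1024, 3])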