├── .github ├── approach.jpg └── decoder_detections.jpg ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── criterion.py ├── datasets ├── __init__.py ├── scannet.py └── sunrgbd.py ├── engine.py ├── main.py ├── models ├── __init__.py ├── helpers.py ├── model_3detr.py ├── position_embedding.py └── transformer.py ├── optimizer.py ├── scripts ├── scannet_ep1080.sh ├── scannet_masked_ep1080.sh ├── scannet_masked_ep1080_color.sh ├── scannet_quick.sh ├── sunrgbd_ep1080.sh ├── sunrgbd_masked_ep1080.sh └── sunrgbd_quick.sh ├── third_party └── pointnet2 │ ├── _ext_src │ ├── include │ │ ├── ball_query.h │ │ ├── cuda_utils.h │ │ ├── group_points.h │ │ ├── interpolate.h │ │ ├── sampling.h │ │ └── utils.h │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── bindings.cpp │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── sampling.cpp │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ ├── pointnet2_test.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py └── utils ├── ap_calculator.py ├── box_intersection.pyx ├── box_ops3d.py ├── box_util.py ├── cython_compile.py ├── cython_compile.sh ├── dist.py ├── download_weights.py ├── eval_det.py ├── io.py ├── logger.py ├── misc.py ├── nms.py ├── pc_util.py └── random_cuboid.py /.github/approach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/3detr/a5d1aadfdde245d68fc1bde9e67a1d344d4ecc47/.github/approach.jpg -------------------------------------------------------------------------------- /.github/decoder_detections.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/3detr/a5d1aadfdde245d68fc1bde9e67a1d344d4ecc47/.github/decoder_detections.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | third_party/* 4 | 5 | # generated files 6 | utils/box_intersection.c 7 | utils/*so 8 | utils/build/* 9 | 10 | # outputs 11 | outputs/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to 3DETR 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 4 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) 36 | 37 | ## License 38 | By contributing to 3DETR, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3DETR: An End-to-End Transformer Model for 3D Object Detection 2 | 3 | PyTorch implementation and models for **3DETR**. 4 | 5 | **3DETR** (**3D** **DE**tection **TR**ansformer) is a simpler alternative to complex hand-crafted 3D detection pipelines. 6 | It does not rely on 3D backbones such as PointNet++ and uses few 3D-specific operators. 7 | 3DETR obtains comparable or better performance than 3D detection methods such as VoteNet. 8 | The encoder can also be used for other 3D tasks such as shape classification. 9 | More details in the paper ["An End-to-End Transformer Model for 3D Object Detection"](http://arxiv.org/abs/2109.08141). 10 | 11 | [[`website`](https://facebookresearch.github.io/3detr)] [[`arXiv`](http://arxiv.org/abs/2109.08141)] [[`bibtex`](#Citation)] 12 | 13 | **Code description.** Our code is based on prior work such as DETR and VoteNet and we aim for simplicity in our implementation. We hope it can ease research in 3D detection. 14 | 15 | ![3DETR Approach](.github/approach.jpg) 16 | ![Decoder Detections](.github/decoder_detections.jpg) 17 | 18 | # Pretrained Models 19 | 20 | We provide the pretrained model weights and the corresponding metrics on the val set (per class APs, Recalls). 21 | We provide a Python script [`utils/download_weights.py`](utils/download_weights.py) to easily download the weights/metrics files. 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 |
| Arch    | Dataset   | Epochs | AP25 | AP50 | Model weights | Eval metrics |
|---------|-----------|--------|------|------|---------------|--------------|
| 3DETR-m | SUN RGB-D | 1080   | 59.1 | 30.3 | weights       | metrics      |
| 3DETR   | SUN RGB-D | 1080   | 58.0 | 30.3 | weights       | metrics      |
| 3DETR-m | ScanNet   | 1080   | 65.0 | 47.0 | weights       | metrics      |
| 3DETR   | ScanNet   | 1080   | 62.1 | 37.9 | weights       | metrics      |
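The downloaded weight files should be loadable as regular PyTorch checkpoints. As a minimal sketch (not part of the codebase), you can inspect one with `torch.load` before passing it to `--test_ckpt`; the filename below is a placeholder and the `"model"` key is an assumption about the checkpoint layout:

```
import torch

# Minimal sketch: inspect a downloaded 3DETR checkpoint.
# The filename is a placeholder and the "model" key is an assumption;
# compare the printed keys with what main.py expects for --test_ckpt.
ckpt = torch.load("scannet_masked_ep1080.pth", map_location="cpu")
print(ckpt.keys())
if "model" in ckpt:
    num_params = sum(v.numel() for v in ckpt["model"].values())
    print(f"saved state dict has {num_params / 1e6:.1f}M parameters")
```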
70 | 71 | ## Model Zoo 72 | 73 | For convenience, we provide model weights for 3DETR trained for different number of epochs. 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 |
| Arch    | Dataset   | Epochs | AP25 | AP50 | Model weights | Eval metrics |
|---------|-----------|--------|------|------|---------------|--------------|
| 3DETR-m | SUN RGB-D | 90     | 51.0 | 22.0 | weights       | metrics      |
| 3DETR-m | SUN RGB-D | 180    | 55.6 | 27.5 | weights       | metrics      |
| 3DETR-m | SUN RGB-D | 360    | 58.2 | 30.6 | weights       | metrics      |
| 3DETR-m | SUN RGB-D | 720    | 58.1 | 30.4 | weights       | metrics      |
| 3DETR   | SUN RGB-D | 90     | 43.7 | 16.2 | weights       | metrics      |
| 3DETR   | SUN RGB-D | 180    | 52.1 | 25.8 | weights       | metrics      |
| 3DETR   | SUN RGB-D | 360    | 56.3 | 29.6 | weights       | metrics      |
| 3DETR   | SUN RGB-D | 720    | 56.0 | 27.8 | weights       | metrics      |
| 3DETR-m | ScanNet   | 90     | 47.1 | 19.5 | weights       | metrics      |
| 3DETR-m | ScanNet   | 180    | 58.7 | 33.6 | weights       | metrics      |
| 3DETR-m | ScanNet   | 360    | 62.4 | 37.7 | weights       | metrics      |
| 3DETR-m | ScanNet   | 720    | 63.7 | 44.5 | weights       | metrics      |
| 3DETR   | ScanNet   | 90     | 42.8 | 15.3 | weights       | metrics      |
| 3DETR   | ScanNet   | 180    | 54.5 | 28.8 | weights       | metrics      |
| 3DETR   | ScanNet   | 360    | 59.0 | 35.4 | weights       | metrics      |
| 3DETR   | ScanNet   | 720    | 61.1 | 40.2 | weights       | metrics      |
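In both tables, AP25 and AP50 denote average precision at box IoU thresholds of 0.25 and 0.5, matching the `ap_iou_thresh=[0.25, 0.5]` passed to the AP calculator in `engine.py`. The sketch below is illustrative only (the codebase's own box/IoU utilities live under `utils/` and handle oriented boxes); it just shows what such a threshold means for axis-aligned boxes:

```
import numpy as np

def axis_aligned_iou3d(box_a, box_b):
    # Boxes are (xmin, ymin, zmin, xmax, ymax, zmax).
    # Illustrative only; not the repo's evaluation code.
    a, b = np.asarray(box_a, dtype=float), np.asarray(box_b, dtype=float)
    overlap = np.maximum(0.0, np.minimum(a[3:], b[3:]) - np.maximum(a[:3], b[:3]))
    inter = overlap.prod()
    union = (a[3:] - a[:3]).prod() + (b[3:] - b[:3]).prod() - inter
    return inter / union

# Two unit cubes offset by half a side overlap with IoU = 1/3,
# i.e. a match at the 0.25 threshold but not at 0.5.
print(axis_aligned_iou3d([0, 0, 0, 1, 1, 1], [0.5, 0, 0, 1.5, 1, 1]))
```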
230 | 231 | 232 | # Running 3DETR 233 | 234 | ## Installation 235 | Our code is tested with PyTorch 1.9.0, CUDA 10.2 and Python 3.6. It may work with other versions. 236 | 237 | You will need to install `pointnet2` layers by running 238 | 239 | ``` 240 | cd third_party/pointnet2 && python setup.py install 241 | ``` 242 | 243 | You will also need Python dependencies (either `conda install` or `pip install`) 244 | 245 | ``` 246 | matplotlib 247 | opencv-python 248 | plyfile 249 | 'trimesh>=2.35.39,<2.35.40' 250 | 'networkx>=2.2,<2.3' 251 | scipy 252 | ``` 253 | 254 | Some users have experienced issues using CUDA 11 or higher. Please try using CUDA 10.2 if you run into CUDA issues. 255 | 256 | **Optionally**, you can install a Cythonized implementation of gIOU for faster training. 257 | ``` 258 | conda install cython 259 | cd utils && python cython_compile.py build_ext --inplace 260 | ``` 261 | 262 | 263 | # Benchmarking 264 | 265 | ## Dataset preparation 266 | 267 | We follow the VoteNet codebase for preprocessing our data. 268 | The instructions for preprocessing SUN RGB-D are [here](https://github.com/facebookresearch/votenet/tree/main/sunrgbd) and ScanNet are [here](https://github.com/facebookresearch/votenet/tree/main/scannet). 269 | 270 | You can edit the dataset paths in [`datasets/sunrgbd.py`](datasets/sunrgbd.py#L36) and [`datasets/scannet.py`](datasets/scannet.py#L23-L24) or choose to specify at runtime. 271 | 272 | ## Testing 273 | 274 | Once you have the datasets prepared, you can test pretrained models as 275 | 276 | ``` 277 | python main.py --dataset_name --nqueries --test_ckpt --test_only [--enc_type masked] 278 | ``` 279 | 280 | We use 128 queries for the SUN RGB-D dataset and 256 queries for the ScanNet dataset. 281 | You will need to add the flag `--enc_type masked` when testing the 3DETR-m checkpoints. 282 | Please note that the testing process is stochastic (due to randomness in point cloud sampling and sampling the queries) and so results can vary within 1% AP25 across runs. 283 | This stochastic nature of the inference process is also common for methods such as VoteNet. 284 | 285 | If you have not edited the dataset paths for the files in the `datasets` folder, you can pass the path to the datasets using the `--dataset_root_dir` flag. 286 | 287 | ## Training 288 | 289 | The model can be simply trained by running `main.py`. 290 | ``` 291 | python main.py --dataset_name --checkpoint_dir 292 | ``` 293 | 294 | To reproduce the results in the paper, we provide the arguments in the [`scripts`](scripts/) folder. 295 | A variance of 1% AP25 across different training runs can be expected. 296 | 297 | You can quickly verify your installation by training a 3DETR model for 90 epochs on ScanNet following the file `scripts/scannet_quick.sh` and compare it to the pretrained checkpoint from the Model Zoo. 298 | 299 | 300 | ## License 301 | The majority of 3DETR is licensed under the Apache 2.0 license as found in the [LICENSE](LICENSE) file, however portions of the project are available under separate license terms: licensing information for pointnet2 is available at https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/UNLICENSE 302 | 303 | ## Contributing 304 | We welcome your pull requests! Please see [CONTRIBUTING](CONTRIBUTING.md) and [CODE_OF_CONDUCT](CODE_OF_CONDUCT.md) for more info. 
305 | 306 | ## Citation 307 | If you find this repository useful, please consider starring :star: us and citing 308 | 309 | ``` 310 | @inproceedings{misra2021-3detr, 311 | title={{An End-to-End Transformer Model for 3D Object Detection}}, 312 | author={Misra, Ishan and Girdhar, Rohit and Joulin, Armand}, 313 | booktitle={{ICCV}}, 314 | year={2021}, 315 | } 316 | ``` 317 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .scannet import ScannetDetectionDataset, ScannetDatasetConfig 3 | from .sunrgbd import SunrgbdDetectionDataset, SunrgbdDatasetConfig 4 | 5 | 6 | DATASET_FUNCTIONS = { 7 | "scannet": [ScannetDetectionDataset, ScannetDatasetConfig], 8 | "sunrgbd": [SunrgbdDetectionDataset, SunrgbdDatasetConfig], 9 | } 10 | 11 | 12 | def build_dataset(args): 13 | dataset_builder = DATASET_FUNCTIONS[args.dataset_name][0] 14 | dataset_config = DATASET_FUNCTIONS[args.dataset_name][1]() 15 | 16 | dataset_dict = { 17 | "train": dataset_builder( 18 | dataset_config, 19 | split_set="train", 20 | root_dir=args.dataset_root_dir, 21 | meta_data_dir=args.meta_data_dir, 22 | use_color=args.use_color, 23 | augment=True 24 | ), 25 | "test": dataset_builder( 26 | dataset_config, 27 | split_set="val", 28 | root_dir=args.dataset_root_dir, 29 | use_color=args.use_color, 30 | augment=False 31 | ), 32 | } 33 | return dataset_dict, dataset_config 34 | -------------------------------------------------------------------------------- /datasets/scannet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | """ 4 | Modified from https://github.com/facebookresearch/votenet 5 | Dataset for object bounding box regression. 6 | An axis aligned bounding box is parameterized by (cx,cy,cz) and (dx,dy,dz) 7 | where (cx,cy,cz) is the center point of the box, dx is the x-axis length of the box. 
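For example (illustrative), a box with center (cx,cy,cz) = (1.0, 2.0, 0.5) and
lengths (dx,dy,dz) = (2.0, 1.0, 1.0) spans x in [0.0, 2.0], y in [1.5, 2.5] and z in [0.0, 1.0].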
8 | """ 9 | import os 10 | import sys 11 | 12 | import numpy as np 13 | import torch 14 | import utils.pc_util as pc_util 15 | from torch.utils.data import Dataset 16 | from utils.box_util import (flip_axis_to_camera_np, flip_axis_to_camera_tensor, 17 | get_3d_box_batch_np, get_3d_box_batch_tensor) 18 | from utils.pc_util import scale_points, shift_scale_points 19 | from utils.random_cuboid import RandomCuboid 20 | 21 | IGNORE_LABEL = -100 22 | MEAN_COLOR_RGB = np.array([109.8, 97.2, 83.8]) 23 | DATASET_ROOT_DIR = "" ## Replace with path to dataset 24 | DATASET_METADATA_DIR = "" ## Replace with path to dataset 25 | 26 | 27 | class ScannetDatasetConfig(object): 28 | def __init__(self): 29 | self.num_semcls = 18 30 | self.num_angle_bin = 1 31 | self.max_num_obj = 64 32 | 33 | self.type2class = { 34 | "cabinet": 0, 35 | "bed": 1, 36 | "chair": 2, 37 | "sofa": 3, 38 | "table": 4, 39 | "door": 5, 40 | "window": 6, 41 | "bookshelf": 7, 42 | "picture": 8, 43 | "counter": 9, 44 | "desk": 10, 45 | "curtain": 11, 46 | "refrigerator": 12, 47 | "showercurtrain": 13, 48 | "toilet": 14, 49 | "sink": 15, 50 | "bathtub": 16, 51 | "garbagebin": 17, 52 | } 53 | self.class2type = {self.type2class[t]: t for t in self.type2class} 54 | self.nyu40ids = np.array( 55 | [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] 56 | ) 57 | self.nyu40id2class = { 58 | nyu40id: i for i, nyu40id in enumerate(list(self.nyu40ids)) 59 | } 60 | 61 | # Semantic Segmentation Classes. Not used in 3DETR 62 | self.num_class_semseg = 20 63 | self.type2class_semseg = { 64 | "wall": 0, 65 | "floor": 1, 66 | "cabinet": 2, 67 | "bed": 3, 68 | "chair": 4, 69 | "sofa": 5, 70 | "table": 6, 71 | "door": 7, 72 | "window": 8, 73 | "bookshelf": 9, 74 | "picture": 10, 75 | "counter": 11, 76 | "desk": 12, 77 | "curtain": 13, 78 | "refrigerator": 14, 79 | "showercurtrain": 15, 80 | "toilet": 16, 81 | "sink": 17, 82 | "bathtub": 18, 83 | "garbagebin": 19, 84 | } 85 | self.class2type_semseg = { 86 | self.type2class_semseg[t]: t for t in self.type2class_semseg 87 | } 88 | self.nyu40ids_semseg = np.array( 89 | [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] 90 | ) 91 | self.nyu40id2class_semseg = { 92 | nyu40id: i for i, nyu40id in enumerate(list(self.nyu40ids_semseg)) 93 | } 94 | 95 | def angle2class(self, angle): 96 | raise ValueError("ScanNet does not have rotated bounding boxes.") 97 | 98 | def class2anglebatch_tensor(self, pred_cls, residual, to_label_format=True): 99 | zero_angle = torch.zeros( 100 | (pred_cls.shape[0], pred_cls.shape[1]), 101 | dtype=torch.float32, 102 | device=pred_cls.device, 103 | ) 104 | return zero_angle 105 | 106 | def class2anglebatch(self, pred_cls, residual, to_label_format=True): 107 | zero_angle = np.zeros(pred_cls.shape[0], dtype=np.float32) 108 | return zero_angle 109 | 110 | def param2obb( 111 | self, 112 | center, 113 | heading_class, 114 | heading_residual, 115 | size_class, 116 | size_residual, 117 | box_size=None, 118 | ): 119 | heading_angle = self.class2angle(heading_class, heading_residual) 120 | if box_size is None: 121 | box_size = self.class2size(int(size_class), size_residual) 122 | obb = np.zeros((7,)) 123 | obb[0:3] = center 124 | obb[3:6] = box_size 125 | obb[6] = heading_angle * -1 126 | return obb 127 | 128 | def box_parametrization_to_corners(self, box_center_unnorm, box_size, box_angle): 129 | box_center_upright = flip_axis_to_camera_tensor(box_center_unnorm) 130 | boxes = get_3d_box_batch_tensor(box_size, box_angle, box_center_upright) 131 | return boxes 132 
| 133 | def box_parametrization_to_corners_np(self, box_center_unnorm, box_size, box_angle): 134 | box_center_upright = flip_axis_to_camera_np(box_center_unnorm) 135 | boxes = get_3d_box_batch_np(box_size, box_angle, box_center_upright) 136 | return boxes 137 | 138 | @staticmethod 139 | def rotate_aligned_boxes(input_boxes, rot_mat): 140 | centers, lengths = input_boxes[:, 0:3], input_boxes[:, 3:6] 141 | new_centers = np.dot(centers, np.transpose(rot_mat)) 142 | 143 | dx, dy = lengths[:, 0] / 2.0, lengths[:, 1] / 2.0 144 | new_x = np.zeros((dx.shape[0], 4)) 145 | new_y = np.zeros((dx.shape[0], 4)) 146 | 147 | for i, crnr in enumerate([(-1, -1), (1, -1), (1, 1), (-1, 1)]): 148 | crnrs = np.zeros((dx.shape[0], 3)) 149 | crnrs[:, 0] = crnr[0] * dx 150 | crnrs[:, 1] = crnr[1] * dy 151 | crnrs = np.dot(crnrs, np.transpose(rot_mat)) 152 | new_x[:, i] = crnrs[:, 0] 153 | new_y[:, i] = crnrs[:, 1] 154 | 155 | new_dx = 2.0 * np.max(new_x, 1) 156 | new_dy = 2.0 * np.max(new_y, 1) 157 | new_lengths = np.stack((new_dx, new_dy, lengths[:, 2]), axis=1) 158 | 159 | return np.concatenate([new_centers, new_lengths], axis=1) 160 | 161 | 162 | class ScannetDetectionDataset(Dataset): 163 | def __init__( 164 | self, 165 | dataset_config, 166 | split_set="train", 167 | root_dir=None, 168 | meta_data_dir=None, 169 | num_points=40000, 170 | use_color=False, 171 | use_height=False, 172 | augment=False, 173 | use_random_cuboid=True, 174 | random_cuboid_min_points=30000, 175 | ): 176 | 177 | self.dataset_config = dataset_config 178 | assert split_set in ["train", "val"] 179 | if root_dir is None: 180 | root_dir = DATASET_ROOT_DIR 181 | 182 | if meta_data_dir is None: 183 | meta_data_dir = DATASET_METADATA_DIR 184 | 185 | self.data_path = root_dir 186 | all_scan_names = list( 187 | set( 188 | [ 189 | os.path.basename(x)[0:12] 190 | for x in os.listdir(self.data_path) 191 | if x.startswith("scene") 192 | ] 193 | ) 194 | ) 195 | if split_set == "all": 196 | self.scan_names = all_scan_names 197 | elif split_set in ["train", "val", "test"]: 198 | split_filenames = os.path.join(meta_data_dir, f"scannetv2_{split_set}.txt") 199 | with open(split_filenames, "r") as f: 200 | self.scan_names = f.read().splitlines() 201 | # remove unavailiable scans 202 | num_scans = len(self.scan_names) 203 | self.scan_names = [ 204 | sname for sname in self.scan_names if sname in all_scan_names 205 | ] 206 | print(f"kept {len(self.scan_names)} scans out of {num_scans}") 207 | else: 208 | raise ValueError(f"Unknown split name {split_set}") 209 | 210 | self.num_points = num_points 211 | self.use_color = use_color 212 | self.use_height = use_height 213 | self.augment = augment 214 | self.use_random_cuboid = use_random_cuboid 215 | self.random_cuboid_augmentor = RandomCuboid(min_points=random_cuboid_min_points) 216 | self.center_normalizing_range = [ 217 | np.zeros((1, 3), dtype=np.float32), 218 | np.ones((1, 3), dtype=np.float32), 219 | ] 220 | 221 | def __len__(self): 222 | return len(self.scan_names) 223 | 224 | def __getitem__(self, idx): 225 | scan_name = self.scan_names[idx] 226 | mesh_vertices = np.load(os.path.join(self.data_path, scan_name) + "_vert.npy") 227 | instance_labels = np.load( 228 | os.path.join(self.data_path, scan_name) + "_ins_label.npy" 229 | ) 230 | semantic_labels = np.load( 231 | os.path.join(self.data_path, scan_name) + "_sem_label.npy" 232 | ) 233 | instance_bboxes = np.load(os.path.join(self.data_path, scan_name) + "_bbox.npy") 234 | 235 | if not self.use_color: 236 | point_cloud = mesh_vertices[:, 0:3] # do not use 
color for now 237 | pcl_color = mesh_vertices[:, 3:6] 238 | else: 239 | point_cloud = mesh_vertices[:, 0:6] 240 | point_cloud[:, 3:] = (point_cloud[:, 3:] - MEAN_COLOR_RGB) / 256.0 241 | pcl_color = point_cloud[:, 3:] 242 | 243 | if self.use_height: 244 | floor_height = np.percentile(point_cloud[:, 2], 0.99) 245 | height = point_cloud[:, 2] - floor_height 246 | point_cloud = np.concatenate([point_cloud, np.expand_dims(height, 1)], 1) 247 | 248 | # ------------------------------- LABELS ------------------------------ 249 | MAX_NUM_OBJ = self.dataset_config.max_num_obj 250 | target_bboxes = np.zeros((MAX_NUM_OBJ, 6), dtype=np.float32) 251 | target_bboxes_mask = np.zeros((MAX_NUM_OBJ), dtype=np.float32) 252 | angle_classes = np.zeros((MAX_NUM_OBJ,), dtype=np.int64) 253 | angle_residuals = np.zeros((MAX_NUM_OBJ,), dtype=np.float32) 254 | raw_sizes = np.zeros((MAX_NUM_OBJ, 3), dtype=np.float32) 255 | raw_angles = np.zeros((MAX_NUM_OBJ,), dtype=np.float32) 256 | 257 | if self.augment and self.use_random_cuboid: 258 | ( 259 | point_cloud, 260 | instance_bboxes, 261 | per_point_labels, 262 | ) = self.random_cuboid_augmentor( 263 | point_cloud, instance_bboxes, [instance_labels, semantic_labels] 264 | ) 265 | instance_labels = per_point_labels[0] 266 | semantic_labels = per_point_labels[1] 267 | 268 | point_cloud, choices = pc_util.random_sampling( 269 | point_cloud, self.num_points, return_choices=True 270 | ) 271 | instance_labels = instance_labels[choices] 272 | semantic_labels = semantic_labels[choices] 273 | 274 | sem_seg_labels = np.ones_like(semantic_labels) * IGNORE_LABEL 275 | 276 | for _c in self.dataset_config.nyu40ids_semseg: 277 | sem_seg_labels[ 278 | semantic_labels == _c 279 | ] = self.dataset_config.nyu40id2class_semseg[_c] 280 | 281 | pcl_color = pcl_color[choices] 282 | 283 | target_bboxes_mask[0 : instance_bboxes.shape[0]] = 1 284 | target_bboxes[0 : instance_bboxes.shape[0], :] = instance_bboxes[:, 0:6] 285 | 286 | # ------------------------------- DATA AUGMENTATION ------------------------------ 287 | if self.augment: 288 | 289 | if np.random.random() > 0.5: 290 | # Flipping along the YZ plane 291 | point_cloud[:, 0] = -1 * point_cloud[:, 0] 292 | target_bboxes[:, 0] = -1 * target_bboxes[:, 0] 293 | 294 | if np.random.random() > 0.5: 295 | # Flipping along the XZ plane 296 | point_cloud[:, 1] = -1 * point_cloud[:, 1] 297 | target_bboxes[:, 1] = -1 * target_bboxes[:, 1] 298 | 299 | # Rotation along up-axis/Z-axis 300 | rot_angle = (np.random.random() * np.pi / 18) - np.pi / 36 # -5 ~ +5 degree 301 | rot_mat = pc_util.rotz(rot_angle) 302 | point_cloud[:, 0:3] = np.dot(point_cloud[:, 0:3], np.transpose(rot_mat)) 303 | target_bboxes = self.dataset_config.rotate_aligned_boxes( 304 | target_bboxes, rot_mat 305 | ) 306 | 307 | raw_sizes = target_bboxes[:, 3:6] 308 | point_cloud_dims_min = point_cloud.min(axis=0)[:3] 309 | point_cloud_dims_max = point_cloud.max(axis=0)[:3] 310 | 311 | box_centers = target_bboxes.astype(np.float32)[:, 0:3] 312 | box_centers_normalized = shift_scale_points( 313 | box_centers[None, ...], 314 | src_range=[ 315 | point_cloud_dims_min[None, ...], 316 | point_cloud_dims_max[None, ...], 317 | ], 318 | dst_range=self.center_normalizing_range, 319 | ) 320 | box_centers_normalized = box_centers_normalized.squeeze(0) 321 | box_centers_normalized = box_centers_normalized * target_bboxes_mask[..., None] 322 | mult_factor = point_cloud_dims_max - point_cloud_dims_min 323 | box_sizes_normalized = scale_points( 324 | raw_sizes.astype(np.float32)[None, ...], 325 | 
mult_factor=1.0 / mult_factor[None, ...], 326 | ) 327 | box_sizes_normalized = box_sizes_normalized.squeeze(0) 328 | 329 | box_corners = self.dataset_config.box_parametrization_to_corners_np( 330 | box_centers[None, ...], 331 | raw_sizes.astype(np.float32)[None, ...], 332 | raw_angles.astype(np.float32)[None, ...], 333 | ) 334 | box_corners = box_corners.squeeze(0) 335 | 336 | ret_dict = {} 337 | ret_dict["point_clouds"] = point_cloud.astype(np.float32) 338 | ret_dict["gt_box_corners"] = box_corners.astype(np.float32) 339 | ret_dict["gt_box_centers"] = box_centers.astype(np.float32) 340 | ret_dict["gt_box_centers_normalized"] = box_centers_normalized.astype( 341 | np.float32 342 | ) 343 | ret_dict["gt_angle_class_label"] = angle_classes.astype(np.int64) 344 | ret_dict["gt_angle_residual_label"] = angle_residuals.astype(np.float32) 345 | target_bboxes_semcls = np.zeros((MAX_NUM_OBJ)) 346 | target_bboxes_semcls[0 : instance_bboxes.shape[0]] = [ 347 | self.dataset_config.nyu40id2class[int(x)] 348 | for x in instance_bboxes[:, -1][0 : instance_bboxes.shape[0]] 349 | ] 350 | ret_dict["gt_box_sem_cls_label"] = target_bboxes_semcls.astype(np.int64) 351 | ret_dict["gt_box_present"] = target_bboxes_mask.astype(np.float32) 352 | ret_dict["scan_idx"] = np.array(idx).astype(np.int64) 353 | ret_dict["pcl_color"] = pcl_color 354 | ret_dict["gt_box_sizes"] = raw_sizes.astype(np.float32) 355 | ret_dict["gt_box_sizes_normalized"] = box_sizes_normalized.astype(np.float32) 356 | ret_dict["gt_box_angles"] = raw_angles.astype(np.float32) 357 | ret_dict["point_cloud_dims_min"] = point_cloud_dims_min.astype(np.float32) 358 | ret_dict["point_cloud_dims_max"] = point_cloud_dims_max.astype(np.float32) 359 | return ret_dict 360 | -------------------------------------------------------------------------------- /datasets/sunrgbd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | """ 5 | Modified from https://github.com/facebookresearch/votenet 6 | Dataset for 3D object detection on SUN RGB-D (with support of vote supervision). 7 | 8 | A sunrgbd oriented bounding box is parameterized by (cx,cy,cz), (l,w,h) -- (dx,dy,dz) in upright depth coord 9 | (Z is up, Y is forward, X is right ward), heading angle (from +X rotating to -Y) and semantic class 10 | 11 | Point clouds are in **upright_depth coordinate (X right, Y forward, Z upward)** 12 | Return heading class, heading residual, size class and size residual for 3D bounding boxes. 13 | Oriented bounding box is parameterized by (cx,cy,cz), (l,w,h), heading_angle and semantic class label. 14 | (cx,cy,cz) is in upright depth coordinate 15 | (l,h,w) are *half length* of the object sizes 16 | The heading angle is a rotation rad from +X rotating towards -Y. (+X is 0, -Y is pi/2) 17 | 18 | Author: Charles R. 
Qi 19 | Date: 2019 20 | 21 | """ 22 | import os 23 | import sys 24 | import numpy as np 25 | from torch.utils.data import Dataset 26 | import scipy.io as sio # to load .mat files for depth points 27 | 28 | import utils.pc_util as pc_util 29 | from utils.random_cuboid import RandomCuboid 30 | from utils.pc_util import shift_scale_points, scale_points 31 | from utils.box_util import ( 32 | flip_axis_to_camera_tensor, 33 | get_3d_box_batch_tensor, 34 | flip_axis_to_camera_np, 35 | get_3d_box_batch_np, 36 | ) 37 | 38 | 39 | MEAN_COLOR_RGB = np.array([0.5, 0.5, 0.5]) # sunrgbd color is in 0~1 40 | DATA_PATH_V1 = "" ## Replace with path to dataset 41 | DATA_PATH_V2 = "" ## Not used in the codebase. 42 | 43 | 44 | class SunrgbdDatasetConfig(object): 45 | def __init__(self): 46 | self.num_semcls = 10 47 | self.num_angle_bin = 12 48 | self.max_num_obj = 64 49 | self.type2class = { 50 | "bed": 0, 51 | "table": 1, 52 | "sofa": 2, 53 | "chair": 3, 54 | "toilet": 4, 55 | "desk": 5, 56 | "dresser": 6, 57 | "night_stand": 7, 58 | "bookshelf": 8, 59 | "bathtub": 9, 60 | } 61 | self.class2type = {self.type2class[t]: t for t in self.type2class} 62 | self.type2onehotclass = { 63 | "bed": 0, 64 | "table": 1, 65 | "sofa": 2, 66 | "chair": 3, 67 | "toilet": 4, 68 | "desk": 5, 69 | "dresser": 6, 70 | "night_stand": 7, 71 | "bookshelf": 8, 72 | "bathtub": 9, 73 | } 74 | 75 | def angle2class(self, angle): 76 | """Convert continuous angle to discrete class 77 | [optinal] also small regression number from 78 | class center angle to current angle. 79 | 80 | angle is from 0-2pi (or -pi~pi), class center at 0, 1*(2pi/N), 2*(2pi/N) ... (N-1)*(2pi/N) 81 | returns class [0,1,...,N-1] and a residual number such that 82 | class*(2pi/N) + number = angle 83 | """ 84 | num_class = self.num_angle_bin 85 | angle = angle % (2 * np.pi) 86 | assert angle >= 0 and angle <= 2 * np.pi 87 | angle_per_class = 2 * np.pi / float(num_class) 88 | shifted_angle = (angle + angle_per_class / 2) % (2 * np.pi) 89 | class_id = int(shifted_angle / angle_per_class) 90 | residual_angle = shifted_angle - ( 91 | class_id * angle_per_class + angle_per_class / 2 92 | ) 93 | return class_id, residual_angle 94 | 95 | def class2angle(self, pred_cls, residual, to_label_format=True): 96 | """Inverse function to angle2class""" 97 | num_class = self.num_angle_bin 98 | angle_per_class = 2 * np.pi / float(num_class) 99 | angle_center = pred_cls * angle_per_class 100 | angle = angle_center + residual 101 | if to_label_format and angle > np.pi: 102 | angle = angle - 2 * np.pi 103 | return angle 104 | 105 | def class2angle_batch(self, pred_cls, residual, to_label_format=True): 106 | num_class = self.num_angle_bin 107 | angle_per_class = 2 * np.pi / float(num_class) 108 | angle_center = pred_cls * angle_per_class 109 | angle = angle_center + residual 110 | if to_label_format: 111 | mask = angle > np.pi 112 | angle[mask] = angle[mask] - 2 * np.pi 113 | return angle 114 | 115 | def class2anglebatch_tensor(self, pred_cls, residual, to_label_format=True): 116 | return self.class2angle_batch(pred_cls, residual, to_label_format) 117 | 118 | def box_parametrization_to_corners(self, box_center_unnorm, box_size, box_angle): 119 | box_center_upright = flip_axis_to_camera_tensor(box_center_unnorm) 120 | boxes = get_3d_box_batch_tensor(box_size, box_angle, box_center_upright) 121 | return boxes 122 | 123 | def box_parametrization_to_corners_np(self, box_center_unnorm, box_size, box_angle): 124 | box_center_upright = flip_axis_to_camera_np(box_center_unnorm) 125 | boxes = 
get_3d_box_batch_np(box_size, box_angle, box_center_upright) 126 | return boxes 127 | 128 | def my_compute_box_3d(self, center, size, heading_angle): 129 | R = pc_util.rotz(-1 * heading_angle) 130 | l, w, h = size 131 | x_corners = [-l, l, l, -l, -l, l, l, -l] 132 | y_corners = [w, w, -w, -w, w, w, -w, -w] 133 | z_corners = [h, h, h, h, -h, -h, -h, -h] 134 | corners_3d = np.dot(R, np.vstack([x_corners, y_corners, z_corners])) 135 | corners_3d[0, :] += center[0] 136 | corners_3d[1, :] += center[1] 137 | corners_3d[2, :] += center[2] 138 | return np.transpose(corners_3d) 139 | 140 | 141 | class SunrgbdDetectionDataset(Dataset): 142 | def __init__( 143 | self, 144 | dataset_config, 145 | split_set="train", 146 | root_dir=None, 147 | num_points=20000, 148 | use_color=False, 149 | use_height=False, 150 | use_v1=True, 151 | augment=False, 152 | use_random_cuboid=True, 153 | random_cuboid_min_points=30000, 154 | ): 155 | assert num_points <= 50000 156 | assert split_set in ["train", "val", "trainval"] 157 | self.dataset_config = dataset_config 158 | self.use_v1 = use_v1 159 | 160 | if root_dir is None: 161 | root_dir = DATA_PATH_V1 if use_v1 else DATA_PATH_V2 162 | 163 | self.data_path = root_dir + "_%s" % (split_set) 164 | 165 | if split_set in ["train", "val"]: 166 | self.scan_names = sorted( 167 | list( 168 | set([os.path.basename(x)[0:6] for x in os.listdir(self.data_path)]) 169 | ) 170 | ) 171 | elif split_set in ["trainval"]: 172 | # combine names from both 173 | sub_splits = ["train", "val"] 174 | all_paths = [] 175 | for sub_split in sub_splits: 176 | data_path = self.data_path.replace("trainval", sub_split) 177 | basenames = sorted( 178 | list(set([os.path.basename(x)[0:6] for x in os.listdir(data_path)])) 179 | ) 180 | basenames = [os.path.join(data_path, x) for x in basenames] 181 | all_paths.extend(basenames) 182 | all_paths.sort() 183 | self.scan_names = all_paths 184 | 185 | self.num_points = num_points 186 | self.augment = augment 187 | self.use_color = use_color 188 | self.use_height = use_height 189 | self.use_random_cuboid = use_random_cuboid 190 | self.random_cuboid_augmentor = RandomCuboid( 191 | min_points=random_cuboid_min_points, 192 | aspect=0.75, 193 | min_crop=0.75, 194 | max_crop=1.0, 195 | ) 196 | self.center_normalizing_range = [ 197 | np.zeros((1, 3), dtype=np.float32), 198 | np.ones((1, 3), dtype=np.float32), 199 | ] 200 | self.max_num_obj = 64 201 | 202 | def __len__(self): 203 | return len(self.scan_names) 204 | 205 | def __getitem__(self, idx): 206 | scan_name = self.scan_names[idx] 207 | if scan_name.startswith("/"): 208 | scan_path = scan_name 209 | else: 210 | scan_path = os.path.join(self.data_path, scan_name) 211 | point_cloud = np.load(scan_path + "_pc.npz")["pc"] # Nx6 212 | bboxes = np.load(scan_path + "_bbox.npy") # K,8 213 | 214 | if not self.use_color: 215 | point_cloud = point_cloud[:, 0:3] 216 | else: 217 | assert point_cloud.shape[1] == 6 218 | point_cloud = point_cloud[:, 0:6] 219 | point_cloud[:, 3:] = point_cloud[:, 3:] - MEAN_COLOR_RGB 220 | 221 | if self.use_height: 222 | floor_height = np.percentile(point_cloud[:, 2], 0.99) 223 | height = point_cloud[:, 2] - floor_height 224 | point_cloud = np.concatenate( 225 | [point_cloud, np.expand_dims(height, 1)], 1 226 | ) # (N,4) or (N,7) 227 | 228 | # ------------------------------- DATA AUGMENTATION ------------------------------ 229 | if self.augment: 230 | if np.random.random() > 0.5: 231 | # Flipping along the YZ plane 232 | point_cloud[:, 0] = -1 * point_cloud[:, 0] 233 | bboxes[:, 0] = -1 * 
bboxes[:, 0] 234 | bboxes[:, 6] = np.pi - bboxes[:, 6] 235 | 236 | # Rotation along up-axis/Z-axis 237 | rot_angle = (np.random.random() * np.pi / 3) - np.pi / 6 # -30 ~ +30 degree 238 | rot_mat = pc_util.rotz(rot_angle) 239 | 240 | point_cloud[:, 0:3] = np.dot(point_cloud[:, 0:3], np.transpose(rot_mat)) 241 | bboxes[:, 0:3] = np.dot(bboxes[:, 0:3], np.transpose(rot_mat)) 242 | bboxes[:, 6] -= rot_angle 243 | 244 | # Augment RGB color 245 | if self.use_color: 246 | rgb_color = point_cloud[:, 3:6] + MEAN_COLOR_RGB 247 | rgb_color *= ( 248 | 1 + 0.4 * np.random.random(3) - 0.2 249 | ) # brightness change for each channel 250 | rgb_color += ( 251 | 0.1 * np.random.random(3) - 0.05 252 | ) # color shift for each channel 253 | rgb_color += np.expand_dims( 254 | (0.05 * np.random.random(point_cloud.shape[0]) - 0.025), -1 255 | ) # jittering on each pixel 256 | rgb_color = np.clip(rgb_color, 0, 1) 257 | # randomly drop out 30% of the points' colors 258 | rgb_color *= np.expand_dims( 259 | np.random.random(point_cloud.shape[0]) > 0.3, -1 260 | ) 261 | point_cloud[:, 3:6] = rgb_color - MEAN_COLOR_RGB 262 | 263 | # Augment point cloud scale: 0.85x-1.15x 264 | scale_ratio = np.random.random() * 0.3 + 0.85 265 | scale_ratio = np.expand_dims(np.tile(scale_ratio, 3), 0) 266 | point_cloud[:, 0:3] *= scale_ratio 267 | bboxes[:, 0:3] *= scale_ratio 268 | bboxes[:, 3:6] *= scale_ratio 269 | 270 | if self.use_height: 271 | point_cloud[:, -1] *= scale_ratio[0, 0] 272 | 273 | if self.use_random_cuboid: 274 | point_cloud, bboxes, _ = self.random_cuboid_augmentor( 275 | point_cloud, bboxes 276 | ) 277 | 278 | # ------------------------------- LABELS ------------------------------ 279 | angle_classes = np.zeros((self.max_num_obj,), dtype=np.float32) 280 | angle_residuals = np.zeros((self.max_num_obj,), dtype=np.float32) 281 | raw_angles = np.zeros((self.max_num_obj,), dtype=np.float32) 282 | raw_sizes = np.zeros((self.max_num_obj, 3), dtype=np.float32) 283 | label_mask = np.zeros((self.max_num_obj)) 284 | label_mask[0 : bboxes.shape[0]] = 1 285 | max_bboxes = np.zeros((self.max_num_obj, 8)) 286 | max_bboxes[0 : bboxes.shape[0], :] = bboxes 287 | 288 | target_bboxes_mask = label_mask 289 | target_bboxes = np.zeros((self.max_num_obj, 6)) 290 | 291 | for i in range(bboxes.shape[0]): 292 | bbox = bboxes[i] 293 | semantic_class = bbox[7] 294 | raw_angles[i] = bbox[6] % 2 * np.pi 295 | box3d_size = bbox[3:6] * 2 296 | raw_sizes[i, :] = box3d_size 297 | angle_class, angle_residual = self.dataset_config.angle2class(bbox[6]) 298 | angle_classes[i] = angle_class 299 | angle_residuals[i] = angle_residual 300 | corners_3d = self.dataset_config.my_compute_box_3d( 301 | bbox[0:3], bbox[3:6], bbox[6] 302 | ) 303 | # compute axis aligned box 304 | xmin = np.min(corners_3d[:, 0]) 305 | ymin = np.min(corners_3d[:, 1]) 306 | zmin = np.min(corners_3d[:, 2]) 307 | xmax = np.max(corners_3d[:, 0]) 308 | ymax = np.max(corners_3d[:, 1]) 309 | zmax = np.max(corners_3d[:, 2]) 310 | target_bbox = np.array( 311 | [ 312 | (xmin + xmax) / 2, 313 | (ymin + ymax) / 2, 314 | (zmin + zmax) / 2, 315 | xmax - xmin, 316 | ymax - ymin, 317 | zmax - zmin, 318 | ] 319 | ) 320 | target_bboxes[i, :] = target_bbox 321 | 322 | point_cloud, choices = pc_util.random_sampling( 323 | point_cloud, self.num_points, return_choices=True 324 | ) 325 | 326 | point_cloud_dims_min = point_cloud.min(axis=0) 327 | point_cloud_dims_max = point_cloud.max(axis=0) 328 | 329 | mult_factor = point_cloud_dims_max - point_cloud_dims_min 330 | box_sizes_normalized = scale_points( 
331 | raw_sizes.astype(np.float32)[None, ...], 332 | mult_factor=1.0 / mult_factor[None, ...], 333 | ) 334 | box_sizes_normalized = box_sizes_normalized.squeeze(0) 335 | 336 | box_centers = target_bboxes.astype(np.float32)[:, 0:3] 337 | box_centers_normalized = shift_scale_points( 338 | box_centers[None, ...], 339 | src_range=[ 340 | point_cloud_dims_min[None, ...], 341 | point_cloud_dims_max[None, ...], 342 | ], 343 | dst_range=self.center_normalizing_range, 344 | ) 345 | box_centers_normalized = box_centers_normalized.squeeze(0) 346 | box_centers_normalized = box_centers_normalized * target_bboxes_mask[..., None] 347 | 348 | # re-encode angles to be consistent with VoteNet eval 349 | angle_classes = angle_classes.astype(np.int64) 350 | angle_residuals = angle_residuals.astype(np.float32) 351 | raw_angles = self.dataset_config.class2angle_batch( 352 | angle_classes, angle_residuals 353 | ) 354 | 355 | box_corners = self.dataset_config.box_parametrization_to_corners_np( 356 | box_centers[None, ...], 357 | raw_sizes.astype(np.float32)[None, ...], 358 | raw_angles.astype(np.float32)[None, ...], 359 | ) 360 | box_corners = box_corners.squeeze(0) 361 | 362 | ret_dict = {} 363 | ret_dict["point_clouds"] = point_cloud.astype(np.float32) 364 | ret_dict["gt_box_corners"] = box_corners.astype(np.float32) 365 | ret_dict["gt_box_centers"] = box_centers.astype(np.float32) 366 | ret_dict["gt_box_centers_normalized"] = box_centers_normalized.astype( 367 | np.float32 368 | ) 369 | target_bboxes_semcls = np.zeros((self.max_num_obj)) 370 | target_bboxes_semcls[0 : bboxes.shape[0]] = bboxes[:, -1] # from 0 to 9 371 | ret_dict["gt_box_sem_cls_label"] = target_bboxes_semcls.astype(np.int64) 372 | ret_dict["gt_box_present"] = target_bboxes_mask.astype(np.float32) 373 | ret_dict["scan_idx"] = np.array(idx).astype(np.int64) 374 | ret_dict["gt_box_sizes"] = raw_sizes.astype(np.float32) 375 | ret_dict["gt_box_sizes_normalized"] = box_sizes_normalized.astype(np.float32) 376 | ret_dict["gt_box_angles"] = raw_angles.astype(np.float32) 377 | ret_dict["gt_angle_class_label"] = angle_classes 378 | ret_dict["gt_angle_residual_label"] = angle_residuals 379 | ret_dict["point_cloud_dims_min"] = point_cloud_dims_min 380 | ret_dict["point_cloud_dims_max"] = point_cloud_dims_max 381 | return ret_dict 382 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import torch 3 | import datetime 4 | import logging 5 | import math 6 | import time 7 | import sys 8 | 9 | from torch.distributed.distributed_c10d import reduce 10 | from utils.ap_calculator import APCalculator 11 | from utils.misc import SmoothedValue 12 | from utils.dist import ( 13 | all_gather_dict, 14 | all_reduce_average, 15 | is_primary, 16 | reduce_dict, 17 | barrier, 18 | ) 19 | 20 | 21 | def compute_learning_rate(args, curr_epoch_normalized): 22 | assert curr_epoch_normalized <= 1.0 and curr_epoch_normalized >= 0.0 23 | if ( 24 | curr_epoch_normalized <= (args.warm_lr_epochs / args.max_epoch) 25 | and args.warm_lr_epochs > 0 26 | ): 27 | # Linear Warmup 28 | curr_lr = args.warm_lr + curr_epoch_normalized * args.max_epoch * ( 29 | (args.base_lr - args.warm_lr) / args.warm_lr_epochs 30 | ) 31 | else: 32 | # Cosine Learning Rate Schedule 33 | curr_lr = args.final_lr + 0.5 * (args.base_lr - args.final_lr) * ( 34 | 1 + math.cos(math.pi * curr_epoch_normalized) 35 | ) 36 | return curr_lr 37 | 38 | 39 | def adjust_learning_rate(args, optimizer, curr_epoch): 40 | curr_lr = compute_learning_rate(args, curr_epoch) 41 | for param_group in optimizer.param_groups: 42 | param_group["lr"] = curr_lr 43 | return curr_lr 44 | 45 | 46 | def train_one_epoch( 47 | args, 48 | curr_epoch, 49 | model, 50 | optimizer, 51 | criterion, 52 | dataset_config, 53 | dataset_loader, 54 | logger, 55 | ): 56 | 57 | ap_calculator = APCalculator( 58 | dataset_config=dataset_config, 59 | ap_iou_thresh=[0.25, 0.5], 60 | class2type_map=dataset_config.class2type, 61 | exact_eval=False, 62 | ) 63 | 64 | curr_iter = curr_epoch * len(dataset_loader) 65 | max_iters = args.max_epoch * len(dataset_loader) 66 | net_device = next(model.parameters()).device 67 | 68 | time_delta = SmoothedValue(window_size=10) 69 | loss_avg = SmoothedValue(window_size=10) 70 | 71 | model.train() 72 | barrier() 73 | 74 | for batch_idx, batch_data_label in enumerate(dataset_loader): 75 | curr_time = time.time() 76 | curr_lr = adjust_learning_rate(args, optimizer, curr_iter / max_iters) 77 | for key in batch_data_label: 78 | batch_data_label[key] = batch_data_label[key].to(net_device) 79 | 80 | # Forward pass 81 | optimizer.zero_grad() 82 | inputs = { 83 | "point_clouds": batch_data_label["point_clouds"], 84 | "point_cloud_dims_min": batch_data_label["point_cloud_dims_min"], 85 | "point_cloud_dims_max": batch_data_label["point_cloud_dims_max"], 86 | } 87 | outputs = model(inputs) 88 | 89 | # Compute loss 90 | loss, loss_dict = criterion(outputs, batch_data_label) 91 | 92 | loss_reduced = all_reduce_average(loss) 93 | loss_dict_reduced = reduce_dict(loss_dict) 94 | 95 | if not math.isfinite(loss_reduced.item()): 96 | logging.info(f"Loss in not finite. Training will be stopped.") 97 | sys.exit(1) 98 | 99 | loss.backward() 100 | if args.clip_gradient > 0: 101 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_gradient) 102 | optimizer.step() 103 | 104 | if curr_iter % args.log_metrics_every == 0: 105 | # This step is slow. AP is computed approximately and locally during training. 106 | # It will gather outputs and ground truth across all ranks. 107 | # It is memory intensive as point_cloud ground truth is a large tensor. 108 | # If GPU memory is not an issue, uncomment the following lines. 
109 | # outputs["outputs"] = all_gather_dict(outputs["outputs"]) 110 | # batch_data_label = all_gather_dict(batch_data_label) 111 | ap_calculator.step_meter(outputs, batch_data_label) 112 | 113 | time_delta.update(time.time() - curr_time) 114 | loss_avg.update(loss_reduced.item()) 115 | 116 | # logging 117 | if is_primary() and curr_iter % args.log_every == 0: 118 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2) 119 | eta_seconds = (max_iters - curr_iter) * time_delta.avg 120 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 121 | print( 122 | f"Epoch [{curr_epoch}/{args.max_epoch}]; Iter [{curr_iter}/{max_iters}]; Loss {loss_avg.avg:0.2f}; LR {curr_lr:0.2e}; Iter time {time_delta.avg:0.2f}; ETA {eta_str}; Mem {mem_mb:0.2f}MB" 123 | ) 124 | logger.log_scalars(loss_dict_reduced, curr_iter, prefix="Train_details/") 125 | 126 | train_dict = {} 127 | train_dict["lr"] = curr_lr 128 | train_dict["memory"] = mem_mb 129 | train_dict["loss"] = loss_avg.avg 130 | train_dict["batch_time"] = time_delta.avg 131 | logger.log_scalars(train_dict, curr_iter, prefix="Train/") 132 | 133 | curr_iter += 1 134 | barrier() 135 | 136 | return ap_calculator 137 | 138 | 139 | @torch.no_grad() 140 | def evaluate( 141 | args, 142 | curr_epoch, 143 | model, 144 | criterion, 145 | dataset_config, 146 | dataset_loader, 147 | logger, 148 | curr_train_iter, 149 | ): 150 | 151 | # ap calculator is exact for evaluation. This is slower than the ap calculator used during training. 152 | ap_calculator = APCalculator( 153 | dataset_config=dataset_config, 154 | ap_iou_thresh=[0.25, 0.5], 155 | class2type_map=dataset_config.class2type, 156 | exact_eval=True, 157 | ) 158 | 159 | curr_iter = 0 160 | net_device = next(model.parameters()).device 161 | num_batches = len(dataset_loader) 162 | 163 | time_delta = SmoothedValue(window_size=10) 164 | loss_avg = SmoothedValue(window_size=10) 165 | model.eval() 166 | barrier() 167 | epoch_str = f"[{curr_epoch}/{args.max_epoch}]" if curr_epoch > 0 else "" 168 | 169 | for batch_idx, batch_data_label in enumerate(dataset_loader): 170 | curr_time = time.time() 171 | for key in batch_data_label: 172 | batch_data_label[key] = batch_data_label[key].to(net_device) 173 | 174 | inputs = { 175 | "point_clouds": batch_data_label["point_clouds"], 176 | "point_cloud_dims_min": batch_data_label["point_cloud_dims_min"], 177 | "point_cloud_dims_max": batch_data_label["point_cloud_dims_max"], 178 | } 179 | outputs = model(inputs) 180 | 181 | # Compute loss 182 | loss_str = "" 183 | if criterion is not None: 184 | loss, loss_dict = criterion(outputs, batch_data_label) 185 | 186 | loss_reduced = all_reduce_average(loss) 187 | loss_dict_reduced = reduce_dict(loss_dict) 188 | loss_avg.update(loss_reduced.item()) 189 | loss_str = f"Loss {loss_avg.avg:0.2f};" 190 | 191 | # Memory intensive as it gathers point cloud GT tensor across all ranks 192 | outputs["outputs"] = all_gather_dict(outputs["outputs"]) 193 | batch_data_label = all_gather_dict(batch_data_label) 194 | ap_calculator.step_meter(outputs, batch_data_label) 195 | time_delta.update(time.time() - curr_time) 196 | if is_primary() and curr_iter % args.log_every == 0: 197 | mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2) 198 | print( 199 | f"Evaluate {epoch_str}; Batch [{curr_iter}/{num_batches}]; {loss_str} Iter time {time_delta.avg:0.2f}; Mem {mem_mb:0.2f}MB" 200 | ) 201 | 202 | test_dict = {} 203 | test_dict["memory"] = mem_mb 204 | test_dict["batch_time"] = time_delta.avg 205 | if criterion is not None: 206 | 
test_dict["loss"] = loss_avg.avg 207 | curr_iter += 1 208 | barrier() 209 | if is_primary(): 210 | if criterion is not None: 211 | logger.log_scalars( 212 | loss_dict_reduced, curr_train_iter, prefix="Test_details/" 213 | ) 214 | logger.log_scalars(test_dict, curr_train_iter, prefix="Test/") 215 | 216 | return ap_calculator 217 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import pickle 7 | 8 | import numpy as np 9 | import torch 10 | from torch.multiprocessing import set_start_method 11 | from torch.utils.data import DataLoader, DistributedSampler 12 | 13 | # 3DETR codebase specific imports 14 | from datasets import build_dataset 15 | from engine import evaluate, train_one_epoch 16 | from models import build_model 17 | from optimizer import build_optimizer 18 | from criterion import build_criterion 19 | from utils.dist import init_distributed, is_distributed, is_primary, get_rank, barrier 20 | from utils.misc import my_worker_init_fn 21 | from utils.io import save_checkpoint, resume_if_possible 22 | from utils.logger import Logger 23 | 24 | 25 | def make_args_parser(): 26 | parser = argparse.ArgumentParser("3D Detection Using Transformers", add_help=False) 27 | 28 | ##### Optimizer ##### 29 | parser.add_argument("--base_lr", default=5e-4, type=float) 30 | parser.add_argument("--warm_lr", default=1e-6, type=float) 31 | parser.add_argument("--warm_lr_epochs", default=9, type=int) 32 | parser.add_argument("--final_lr", default=1e-6, type=float) 33 | parser.add_argument("--lr_scheduler", default="cosine", type=str) 34 | parser.add_argument("--weight_decay", default=0.1, type=float) 35 | parser.add_argument("--filter_biases_wd", default=False, action="store_true") 36 | parser.add_argument( 37 | "--clip_gradient", default=0.1, type=float, help="Max L2 norm of the gradient" 38 | ) 39 | 40 | ##### Model ##### 41 | parser.add_argument( 42 | "--model_name", 43 | default="3detr", 44 | type=str, 45 | help="Name of the model", 46 | choices=["3detr"], 47 | ) 48 | ### Encoder 49 | parser.add_argument( 50 | "--enc_type", default="vanilla", choices=["masked", "maskedv2", "vanilla"] 51 | ) 52 | # Below options are only valid for vanilla encoder 53 | parser.add_argument("--enc_nlayers", default=3, type=int) 54 | parser.add_argument("--enc_dim", default=256, type=int) 55 | parser.add_argument("--enc_ffn_dim", default=128, type=int) 56 | parser.add_argument("--enc_dropout", default=0.1, type=float) 57 | parser.add_argument("--enc_nhead", default=4, type=int) 58 | parser.add_argument("--enc_pos_embed", default=None, type=str) 59 | parser.add_argument("--enc_activation", default="relu", type=str) 60 | 61 | ### Decoder 62 | parser.add_argument("--dec_nlayers", default=8, type=int) 63 | parser.add_argument("--dec_dim", default=256, type=int) 64 | parser.add_argument("--dec_ffn_dim", default=256, type=int) 65 | parser.add_argument("--dec_dropout", default=0.1, type=float) 66 | parser.add_argument("--dec_nhead", default=4, type=int) 67 | 68 | ### MLP heads for predicting bounding boxes 69 | parser.add_argument("--mlp_dropout", default=0.3, type=float) 70 | parser.add_argument( 71 | "--nsemcls", 72 | default=-1, 73 | type=int, 74 | help="Number of semantic object classes. 
Can be inferred from dataset", 75 | ) 76 | 77 | ### Other model params 78 | parser.add_argument("--preenc_npoints", default=2048, type=int) 79 | parser.add_argument( 80 | "--pos_embed", default="fourier", type=str, choices=["fourier", "sine"] 81 | ) 82 | parser.add_argument("--nqueries", default=256, type=int) 83 | parser.add_argument("--use_color", default=False, action="store_true") 84 | 85 | ##### Set Loss ##### 86 | ### Matcher 87 | parser.add_argument("--matcher_giou_cost", default=2, type=float) 88 | parser.add_argument("--matcher_cls_cost", default=1, type=float) 89 | parser.add_argument("--matcher_center_cost", default=0, type=float) 90 | parser.add_argument("--matcher_objectness_cost", default=0, type=float) 91 | 92 | ### Loss Weights 93 | parser.add_argument("--loss_giou_weight", default=0, type=float) 94 | parser.add_argument("--loss_sem_cls_weight", default=1, type=float) 95 | parser.add_argument( 96 | "--loss_no_object_weight", default=0.2, type=float 97 | ) # "no object" or "background" class for detection 98 | parser.add_argument("--loss_angle_cls_weight", default=0.1, type=float) 99 | parser.add_argument("--loss_angle_reg_weight", default=0.5, type=float) 100 | parser.add_argument("--loss_center_weight", default=5.0, type=float) 101 | parser.add_argument("--loss_size_weight", default=1.0, type=float) 102 | 103 | ##### Dataset ##### 104 | parser.add_argument( 105 | "--dataset_name", required=True, type=str, choices=["scannet", "sunrgbd"] 106 | ) 107 | parser.add_argument( 108 | "--dataset_root_dir", 109 | type=str, 110 | default=None, 111 | help="Root directory containing the dataset files. \ 112 | If None, default values from scannet.py/sunrgbd.py are used", 113 | ) 114 | parser.add_argument( 115 | "--meta_data_dir", 116 | type=str, 117 | default=None, 118 | help="Root directory containing the metadata files. \ 119 | If None, default values from scannet.py/sunrgbd.py are used", 120 | ) 121 | parser.add_argument("--dataset_num_workers", default=4, type=int) 122 | parser.add_argument("--batchsize_per_gpu", default=8, type=int) 123 | 124 | ##### Training ##### 125 | parser.add_argument("--start_epoch", default=-1, type=int) 126 | parser.add_argument("--max_epoch", default=720, type=int) 127 | parser.add_argument("--eval_every_epoch", default=10, type=int) 128 | parser.add_argument("--seed", default=0, type=int) 129 | 130 | ##### Testing ##### 131 | parser.add_argument("--test_only", default=False, action="store_true") 132 | parser.add_argument("--test_ckpt", default=None, type=str) 133 | 134 | ##### I/O ##### 135 | parser.add_argument("--checkpoint_dir", default=None, type=str) 136 | parser.add_argument("--log_every", default=10, type=int) 137 | parser.add_argument("--log_metrics_every", default=20, type=int) 138 | parser.add_argument("--save_separate_checkpoint_every_epoch", default=100, type=int) 139 | 140 | ##### Distributed Training ##### 141 | parser.add_argument("--ngpus", default=1, type=int) 142 | parser.add_argument("--dist_url", default="tcp://localhost:12345", type=str) 143 | 144 | return parser 145 | 146 | 147 | def do_train( 148 | args, 149 | model, 150 | model_no_ddp, 151 | optimizer, 152 | criterion, 153 | dataset_config, 154 | dataloaders, 155 | best_val_metrics, 156 | ): 157 | """ 158 | Main training loop. 159 | This trains the model for `args.max_epoch` epochs and tests the model after every `args.eval_every_epoch`. 160 | We always evaluate the final checkpoint and report both the final AP and best AP on the val set. 
161 | """ 162 | 163 | num_iters_per_epoch = len(dataloaders["train"]) 164 | num_iters_per_eval_epoch = len(dataloaders["test"]) 165 | print(f"Model is {model}") 166 | print(f"Training started at epoch {args.start_epoch} until {args.max_epoch}.") 167 | print(f"One training epoch = {num_iters_per_epoch} iters.") 168 | print(f"One eval epoch = {num_iters_per_eval_epoch} iters.") 169 | 170 | final_eval = os.path.join(args.checkpoint_dir, "final_eval.txt") 171 | final_eval_pkl = os.path.join(args.checkpoint_dir, "final_eval.pkl") 172 | 173 | if os.path.isfile(final_eval): 174 | print(f"Found final eval file {final_eval}. Skipping training.") 175 | return 176 | 177 | logger = Logger(args.checkpoint_dir) 178 | 179 | for epoch in range(args.start_epoch, args.max_epoch): 180 | if is_distributed(): 181 | dataloaders["train_sampler"].set_epoch(epoch) 182 | 183 | aps = train_one_epoch( 184 | args, 185 | epoch, 186 | model, 187 | optimizer, 188 | criterion, 189 | dataset_config, 190 | dataloaders["train"], 191 | logger, 192 | ) 193 | 194 | # latest checkpoint is always stored in checkpoint.pth 195 | save_checkpoint( 196 | args.checkpoint_dir, 197 | model_no_ddp, 198 | optimizer, 199 | epoch, 200 | args, 201 | best_val_metrics, 202 | filename="checkpoint.pth", 203 | ) 204 | 205 | metrics = aps.compute_metrics() 206 | metric_str = aps.metrics_to_str(metrics, per_class=False) 207 | metrics_dict = aps.metrics_to_dict(metrics) 208 | curr_iter = epoch * len(dataloaders["train"]) 209 | if is_primary(): 210 | print("==" * 10) 211 | print(f"Epoch [{epoch}/{args.max_epoch}]; Metrics {metric_str}") 212 | print("==" * 10) 213 | logger.log_scalars(metrics_dict, curr_iter, prefix="Train/") 214 | 215 | if ( 216 | epoch > 0 217 | and args.save_separate_checkpoint_every_epoch > 0 218 | and epoch % args.save_separate_checkpoint_every_epoch == 0 219 | ): 220 | # separate checkpoints are stored as checkpoint_{epoch}.pth 221 | save_checkpoint( 222 | args.checkpoint_dir, 223 | model_no_ddp, 224 | optimizer, 225 | epoch, 226 | args, 227 | best_val_metrics, 228 | ) 229 | 230 | if epoch % args.eval_every_epoch == 0 or epoch == (args.max_epoch - 1): 231 | ap_calculator = evaluate( 232 | args, 233 | epoch, 234 | model, 235 | criterion, 236 | dataset_config, 237 | dataloaders["test"], 238 | logger, 239 | curr_iter, 240 | ) 241 | metrics = ap_calculator.compute_metrics() 242 | ap25 = metrics[0.25]["mAP"] 243 | metric_str = ap_calculator.metrics_to_str(metrics, per_class=True) 244 | metrics_dict = ap_calculator.metrics_to_dict(metrics) 245 | if is_primary(): 246 | print("==" * 10) 247 | print(f"Evaluate Epoch [{epoch}/{args.max_epoch}]; Metrics {metric_str}") 248 | print("==" * 10) 249 | logger.log_scalars(metrics_dict, curr_iter, prefix="Test/") 250 | 251 | if is_primary() and ( 252 | len(best_val_metrics) == 0 or best_val_metrics[0.25]["mAP"] < ap25 253 | ): 254 | best_val_metrics = metrics 255 | filename = "checkpoint_best.pth" 256 | save_checkpoint( 257 | args.checkpoint_dir, 258 | model_no_ddp, 259 | optimizer, 260 | epoch, 261 | args, 262 | best_val_metrics, 263 | filename=filename, 264 | ) 265 | print( 266 | f"Epoch [{epoch}/{args.max_epoch}] saved current best val checkpoint at {filename}; ap25 {ap25}" 267 | ) 268 | 269 | # always evaluate last checkpoint 270 | epoch = args.max_epoch - 1 271 | curr_iter = epoch * len(dataloaders["train"]) 272 | ap_calculator = evaluate( 273 | args, 274 | epoch, 275 | model, 276 | criterion, 277 | dataset_config, 278 | dataloaders["test"], 279 | logger, 280 | curr_iter, 281 | ) 282 | metrics 
= ap_calculator.compute_metrics() 283 | metric_str = ap_calculator.metrics_to_str(metrics) 284 | if is_primary(): 285 | print("==" * 10) 286 | print(f"Evaluate Final [{epoch}/{args.max_epoch}]; Metrics {metric_str}") 287 | print("==" * 10) 288 | 289 | with open(final_eval, "w") as fh: 290 | fh.write("Training Finished.\n") 291 | fh.write("==" * 10) 292 | fh.write("Final Eval Numbers.\n") 293 | fh.write(metric_str) 294 | fh.write("\n") 295 | fh.write("==" * 10) 296 | fh.write("Best Eval Numbers.\n") 297 | fh.write(ap_calculator.metrics_to_str(best_val_metrics)) 298 | fh.write("\n") 299 | 300 | with open(final_eval_pkl, "wb") as fh: 301 | pickle.dump(metrics, fh) 302 | 303 | 304 | def test_model(args, model, model_no_ddp, criterion, dataset_config, dataloaders): 305 | if args.test_ckpt is None or not os.path.isfile(args.test_ckpt): 306 | f"Please specify a test checkpoint using --test_ckpt. Found invalid value {args.test_ckpt}" 307 | sys.exit(1) 308 | 309 | sd = torch.load(args.test_ckpt, map_location=torch.device("cpu")) 310 | model_no_ddp.load_state_dict(sd["model"]) 311 | logger = Logger() 312 | criterion = None # do not compute loss for speed-up; Comment out to see test loss 313 | epoch = -1 314 | curr_iter = 0 315 | ap_calculator = evaluate( 316 | args, 317 | epoch, 318 | model, 319 | criterion, 320 | dataset_config, 321 | dataloaders["test"], 322 | logger, 323 | curr_iter, 324 | ) 325 | metrics = ap_calculator.compute_metrics() 326 | metric_str = ap_calculator.metrics_to_str(metrics) 327 | if is_primary(): 328 | print("==" * 10) 329 | print(f"Test model; Metrics {metric_str}") 330 | print("==" * 10) 331 | 332 | 333 | def main(local_rank, args): 334 | if args.ngpus > 1: 335 | print( 336 | "Initializing Distributed Training. This is in BETA mode and hasn't been tested thoroughly. 
Use at your own risk :)" 337 | ) 338 | print("To get the maximum speed-up consider reducing evaluations on val set by setting --eval_every_epoch to greater than 50") 339 | init_distributed( 340 | local_rank, 341 | global_rank=local_rank, 342 | world_size=args.ngpus, 343 | dist_url=args.dist_url, 344 | dist_backend="nccl", 345 | ) 346 | 347 | print(f"Called with args: {args}") 348 | torch.cuda.set_device(local_rank) 349 | np.random.seed(args.seed + get_rank()) 350 | torch.manual_seed(args.seed + get_rank()) 351 | if torch.cuda.is_available(): 352 | torch.cuda.manual_seed_all(args.seed + get_rank()) 353 | 354 | datasets, dataset_config = build_dataset(args) 355 | model, _ = build_model(args, dataset_config) 356 | model = model.cuda(local_rank) 357 | model_no_ddp = model 358 | 359 | if is_distributed(): 360 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 361 | model = torch.nn.parallel.DistributedDataParallel( 362 | model, device_ids=[local_rank] 363 | ) 364 | criterion = build_criterion(args, dataset_config) 365 | criterion = criterion.cuda(local_rank) 366 | 367 | dataloaders = {} 368 | if args.test_only: 369 | dataset_splits = ["test"] 370 | else: 371 | dataset_splits = ["train", "test"] 372 | for split in dataset_splits: 373 | if split == "train": 374 | shuffle = True 375 | else: 376 | shuffle = False 377 | if is_distributed(): 378 | sampler = DistributedSampler(datasets[split], shuffle=shuffle) 379 | elif shuffle: 380 | sampler = torch.utils.data.RandomSampler(datasets[split]) 381 | else: 382 | sampler = torch.utils.data.SequentialSampler(datasets[split]) 383 | 384 | dataloaders[split] = DataLoader( 385 | datasets[split], 386 | sampler=sampler, 387 | batch_size=args.batchsize_per_gpu, 388 | num_workers=args.dataset_num_workers, 389 | worker_init_fn=my_worker_init_fn, 390 | ) 391 | dataloaders[split + "_sampler"] = sampler 392 | 393 | if args.test_only: 394 | criterion = None # faster evaluation 395 | test_model(args, model, model_no_ddp, criterion, dataset_config, dataloaders) 396 | else: 397 | assert ( 398 | args.checkpoint_dir is not None 399 | ), f"Please specify a checkpoint dir using --checkpoint_dir" 400 | if is_primary() and not os.path.isdir(args.checkpoint_dir): 401 | os.makedirs(args.checkpoint_dir, exist_ok=True) 402 | optimizer = build_optimizer(args, model_no_ddp) 403 | loaded_epoch, best_val_metrics = resume_if_possible( 404 | args.checkpoint_dir, model_no_ddp, optimizer 405 | ) 406 | args.start_epoch = loaded_epoch + 1 407 | do_train( 408 | args, 409 | model, 410 | model_no_ddp, 411 | optimizer, 412 | criterion, 413 | dataset_config, 414 | dataloaders, 415 | best_val_metrics, 416 | ) 417 | 418 | 419 | def launch_distributed(args): 420 | world_size = args.ngpus 421 | if world_size == 1: 422 | main(local_rank=0, args=args) 423 | else: 424 | torch.multiprocessing.spawn(main, nprocs=world_size, args=(args,)) 425 | 426 | 427 | if __name__ == "__main__": 428 | parser = make_args_parser() 429 | args = parser.parse_args() 430 | try: 431 | set_start_method("spawn") 432 | except RuntimeError: 433 | pass 434 | launch_distributed(args) 435 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
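# build_model() below looks up the builder registered for args.model_name in the
# MODEL_FUNCS dict. A hedged sketch of how another architecture could be plugged in;
# the module, name and builder here are purely illustrative, not part of this repo:
#
#   from .model_my_detector import build_my_detector   # hypothetical module/function
#   MODEL_FUNCS["my_detector"] = build_my_detector
#   # then pass --model_name my_detector (after adding it to the argparse choices)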
2 | from .model_3detr import build_3detr 3 | 4 | MODEL_FUNCS = { 5 | "3detr": build_3detr, 6 | } 7 | 8 | def build_model(args, dataset_config): 9 | model, processor = MODEL_FUNCS[args.model_name](args, dataset_config) 10 | return model, processor -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch.nn as nn 3 | from functools import partial 4 | import copy 5 | 6 | 7 | class BatchNormDim1Swap(nn.BatchNorm1d): 8 | """ 9 | Used for nn.Transformer that uses a HW x N x C rep 10 | """ 11 | 12 | def forward(self, x): 13 | """ 14 | x: HW x N x C 15 | permute to N x C x HW 16 | Apply BN on C 17 | permute back 18 | """ 19 | hw, n, c = x.shape 20 | x = x.permute(1, 2, 0) 21 | x = super(BatchNormDim1Swap, self).forward(x) 22 | # x: n x c x hw -> hw x n x c 23 | x = x.permute(2, 0, 1) 24 | return x 25 | 26 | 27 | NORM_DICT = { 28 | "bn": BatchNormDim1Swap, 29 | "bn1d": nn.BatchNorm1d, 30 | "id": nn.Identity, 31 | "ln": nn.LayerNorm, 32 | } 33 | 34 | ACTIVATION_DICT = { 35 | "relu": nn.ReLU, 36 | "gelu": nn.GELU, 37 | "leakyrelu": partial(nn.LeakyReLU, negative_slope=0.1), 38 | } 39 | 40 | WEIGHT_INIT_DICT = { 41 | "xavier_uniform": nn.init.xavier_uniform_, 42 | } 43 | 44 | 45 | class GenericMLP(nn.Module): 46 | def __init__( 47 | self, 48 | input_dim, 49 | hidden_dims, 50 | output_dim, 51 | norm_fn_name=None, 52 | activation="relu", 53 | use_conv=False, 54 | dropout=None, 55 | hidden_use_bias=False, 56 | output_use_bias=True, 57 | output_use_activation=False, 58 | output_use_norm=False, 59 | weight_init_name=None, 60 | ): 61 | super().__init__() 62 | activation = ACTIVATION_DICT[activation] 63 | norm = None 64 | if norm_fn_name is not None: 65 | norm = NORM_DICT[norm_fn_name] 66 | if norm_fn_name == "ln" and use_conv: 67 | norm = lambda x: nn.GroupNorm(1, x) # easier way to use LayerNorm 68 | 69 | if dropout is not None: 70 | if not isinstance(dropout, list): 71 | dropout = [dropout for _ in range(len(hidden_dims))] 72 | 73 | layers = [] 74 | prev_dim = input_dim 75 | for idx, x in enumerate(hidden_dims): 76 | if use_conv: 77 | layer = nn.Conv1d(prev_dim, x, 1, bias=hidden_use_bias) 78 | else: 79 | layer = nn.Linear(prev_dim, x, bias=hidden_use_bias) 80 | layers.append(layer) 81 | if norm: 82 | layers.append(norm(x)) 83 | layers.append(activation()) 84 | if dropout is not None: 85 | layers.append(nn.Dropout(p=dropout[idx])) 86 | prev_dim = x 87 | if use_conv: 88 | layer = nn.Conv1d(prev_dim, output_dim, 1, bias=output_use_bias) 89 | else: 90 | layer = nn.Linear(prev_dim, output_dim, bias=output_use_bias) 91 | layers.append(layer) 92 | 93 | if output_use_norm: 94 | layers.append(norm(output_dim)) 95 | 96 | if output_use_activation: 97 | layers.append(activation()) 98 | 99 | self.layers = nn.Sequential(*layers) 100 | 101 | if weight_init_name is not None: 102 | self.do_weight_init(weight_init_name) 103 | 104 | def do_weight_init(self, weight_init_name): 105 | func = WEIGHT_INIT_DICT[weight_init_name] 106 | for (_, param) in self.named_parameters(): 107 | if param.dim() > 1: # skips batchnorm/layernorm 108 | func(param) 109 | 110 | def forward(self, x): 111 | output = self.layers(x) 112 | return output 113 | 114 | 115 | def get_clones(module, N): 116 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 117 | -------------------------------------------------------------------------------- 
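GenericMLP above is the generic building block from which the prediction heads are assembled. Below is a minimal standalone usage sketch; the dimensions, dropout and norm choices are illustrative only, not the exact values the 3DETR heads use.

import torch
from models.helpers import GenericMLP

# 1x1-conv MLP over per-query features: (batch, channels, nqueries) -> (batch, 3, nqueries)
head = GenericMLP(
    input_dim=256,
    hidden_dims=[256, 256],
    output_dim=3,
    norm_fn_name="bn1d",
    activation="relu",
    use_conv=True,
    dropout=0.3,
)
query_feats = torch.randn(2, 256, 128)  # 2 scenes, 128 queries
out = head(query_feats)                 # -> torch.Size([2, 3, 128])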
/models/position_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | import numpy as np 9 | from utils.pc_util import shift_scale_points 10 | 11 | 12 | class PositionEmbeddingCoordsSine(nn.Module): 13 | def __init__( 14 | self, 15 | temperature=10000, 16 | normalize=False, 17 | scale=None, 18 | pos_type="fourier", 19 | d_pos=None, 20 | d_in=3, 21 | gauss_scale=1.0, 22 | ): 23 | super().__init__() 24 | self.temperature = temperature 25 | self.normalize = normalize 26 | if scale is not None and normalize is False: 27 | raise ValueError("normalize should be True if scale is passed") 28 | if scale is None: 29 | scale = 2 * math.pi 30 | assert pos_type in ["sine", "fourier"] 31 | self.pos_type = pos_type 32 | self.scale = scale 33 | if pos_type == "fourier": 34 | assert d_pos is not None 35 | assert d_pos % 2 == 0 36 | # define a gaussian matrix input_ch -> output_ch 37 | B = torch.empty((d_in, d_pos // 2)).normal_() 38 | B *= gauss_scale 39 | self.register_buffer("gauss_B", B) 40 | self.d_pos = d_pos 41 | 42 | def get_sine_embeddings(self, xyz, num_channels, input_range): 43 | # clone coords so that shift/scale operations do not affect original tensor 44 | orig_xyz = xyz 45 | xyz = orig_xyz.clone() 46 | 47 | ncoords = xyz.shape[1] 48 | if self.normalize: 49 | xyz = shift_scale_points(xyz, src_range=input_range) 50 | 51 | ndim = num_channels // xyz.shape[2] 52 | if ndim % 2 != 0: 53 | ndim -= 1 54 | # automatically handle remainder by assiging it to the first dim 55 | rems = num_channels - (ndim * xyz.shape[2]) 56 | 57 | assert ( 58 | ndim % 2 == 0 59 | ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}" 60 | 61 | final_embeds = [] 62 | prev_dim = 0 63 | 64 | for d in range(xyz.shape[2]): 65 | cdim = ndim 66 | if rems > 0: 67 | # add remainder in increments of two to maintain even size 68 | cdim += 2 69 | rems -= 2 70 | 71 | if cdim != prev_dim: 72 | dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device) 73 | dim_t = self.temperature ** (2 * (dim_t // 2) / cdim) 74 | 75 | # create batch x cdim x nccords embedding 76 | raw_pos = xyz[:, :, d] 77 | if self.scale: 78 | raw_pos *= self.scale 79 | pos = raw_pos[:, :, None] / dim_t 80 | pos = torch.stack( 81 | (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3 82 | ).flatten(2) 83 | final_embeds.append(pos) 84 | prev_dim = cdim 85 | 86 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 87 | return final_embeds 88 | 89 | def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None): 90 | # Follows - https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 91 | 92 | if num_channels is None: 93 | num_channels = self.gauss_B.shape[1] * 2 94 | 95 | bsize, npoints = xyz.shape[0], xyz.shape[1] 96 | assert num_channels > 0 and num_channels % 2 == 0 97 | d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1] 98 | d_out = num_channels // 2 99 | assert d_out <= max_d_out 100 | assert d_in == xyz.shape[-1] 101 | 102 | # clone coords so that shift/scale operations do not affect original tensor 103 | orig_xyz = xyz 104 | xyz = orig_xyz.clone() 105 | 106 | ncoords = xyz.shape[1] 107 | if self.normalize: 108 | xyz = shift_scale_points(xyz, src_range=input_range) 109 | 110 | xyz *= 2 * np.pi 111 | xyz_proj = 
torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view( 112 | bsize, npoints, d_out 113 | ) 114 | final_embeds = [xyz_proj.sin(), xyz_proj.cos()] 115 | 116 | # return batch x d_pos x npoints embedding 117 | final_embeds = torch.cat(final_embeds, dim=2).permute(0, 2, 1) 118 | return final_embeds 119 | 120 | def forward(self, xyz, num_channels=None, input_range=None): 121 | assert isinstance(xyz, torch.Tensor) 122 | assert xyz.ndim == 3 123 | # xyz is batch x npoints x 3 124 | if self.pos_type == "sine": 125 | with torch.no_grad(): 126 | return self.get_sine_embeddings(xyz, num_channels, input_range) 127 | elif self.pos_type == "fourier": 128 | with torch.no_grad(): 129 | return self.get_fourier_embeddings(xyz, num_channels, input_range) 130 | else: 131 | raise ValueError(f"Unknown {self.pos_type}") 132 | 133 | def extra_repr(self): 134 | st = f"type={self.pos_type}, scale={self.scale}, normalize={self.normalize}" 135 | if hasattr(self, "gauss_B"): 136 | st += ( 137 | f", gaussB={self.gauss_B.shape}, gaussBsum={self.gauss_B.sum().item()}" 138 | ) 139 | return st 140 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | 4 | 5 | def build_optimizer(args, model): 6 | 7 | params_with_decay = [] 8 | params_without_decay = [] 9 | for name, param in model.named_parameters(): 10 | if param.requires_grad is False: 11 | continue 12 | if args.filter_biases_wd and (len(param.shape) == 1 or name.endswith("bias")): 13 | params_without_decay.append(param) 14 | else: 15 | params_with_decay.append(param) 16 | 17 | if args.filter_biases_wd: 18 | param_groups = [ 19 | {"params": params_without_decay, "weight_decay": 0.0}, 20 | {"params": params_with_decay, "weight_decay": args.weight_decay}, 21 | ] 22 | else: 23 | param_groups = [ 24 | {"params": params_with_decay, "weight_decay": args.weight_decay}, 25 | ] 26 | optimizer = torch.optim.AdamW(param_groups, lr=args.base_lr) 27 | return optimizer 28 | -------------------------------------------------------------------------------- /scripts/scannet_ep1080.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | python main.py \ 5 | --dataset_name scannet \ 6 | --max_epoch 1080 \ 7 | --nqueries 256 \ 8 | --matcher_giou_cost 2 \ 9 | --matcher_cls_cost 1 \ 10 | --matcher_center_cost 0 \ 11 | --matcher_objectness_cost 0 \ 12 | --loss_giou_weight 1 \ 13 | --loss_no_object_weight 0.25 \ 14 | --save_separate_checkpoint_every_epoch -1 \ 15 | --checkpoint_dir outputs/scannet_ep1080 16 | -------------------------------------------------------------------------------- /scripts/scannet_masked_ep1080.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
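# Same ScanNet recipe as scannet_ep1080.sh, but with the masked encoder
# (--enc_type masked) and a higher encoder dropout; matcher and loss weights
# are unchanged.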
3 | 4 | python main.py \ 5 | --dataset_name scannet \ 6 | --max_epoch 1080 \ 7 | --enc_type masked \ 8 | --enc_dropout 0.3 \ 9 | --nqueries 256 \ 10 | --base_lr 5e-4 \ 11 | --matcher_giou_cost 2 \ 12 | --matcher_cls_cost 1 \ 13 | --matcher_center_cost 0 \ 14 | --matcher_objectness_cost 0 \ 15 | --loss_giou_weight 1 \ 16 | --loss_no_object_weight 0.25 \ 17 | --save_separate_checkpoint_every_epoch -1 \ 18 | --checkpoint_dir outputs/scannet_masked_ep1080 19 | -------------------------------------------------------------------------------- /scripts/scannet_masked_ep1080_color.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | python main.py \ 5 | --dataset_name scannet \ 6 | --max_epoch 1080 \ 7 | --enc_type masked \ 8 | --enc_dropout 0.3 \ 9 | --nqueries 256 \ 10 | --base_lr 5e-4 \ 11 | --matcher_giou_cost 2 \ 12 | --matcher_cls_cost 1 \ 13 | --matcher_center_cost 0 \ 14 | --matcher_objectness_cost 0 \ 15 | --loss_giou_weight 1 \ 16 | --loss_no_object_weight 0.25 \ 17 | --save_separate_checkpoint_every_epoch -1 \ 18 | --use_color \ 19 | --checkpoint_dir outputs/scannet_masked_ep1080_color 20 | -------------------------------------------------------------------------------- /scripts/scannet_quick.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | python main.py 5 | --dataset_name scannet \ 6 | --nqueries 256 \ 7 | --max_epoch 90 \ 8 | --matcher_giou_cost 2 \ 9 | --matcher_cls_cost 1 \ 10 | --matcher_center_cost 0 \ 11 | --matcher_objectness_cost 0 \ 12 | --loss_giou_weight 1 \ 13 | --loss_no_object_weight 0.25 \ 14 | --save_separate_checkpoint_every_epoch -1 \ 15 | --checkpoint_dir outputs/scannet_quick 16 | -------------------------------------------------------------------------------- /scripts/sunrgbd_ep1080.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | python main.py 5 | --dataset_name sunrgbd \ 6 | --max_epoch 1080 \ 7 | --nqueries 128 \ 8 | --base_lr 7e-4 \ 9 | --matcher_giou_cost 3 \ 10 | --matcher_cls_cost 1 \ 11 | --matcher_center_cost 5 \ 12 | --matcher_objectness_cost 5 \ 13 | --loss_giou_weight 0 \ 14 | --loss_no_object_weight 0.1 \ 15 | --save_separate_checkpoint_every_epoch -1 \ 16 | --checkpoint_dir outputs/sunrgbd_ep1080 17 | -------------------------------------------------------------------------------- /scripts/sunrgbd_masked_ep1080.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | python main.py \ 5 | --dataset_name sunrgbd \ 6 | --max_epoch 1080 \ 7 | --enc_type masked \ 8 | --nqueries 128 \ 9 | --base_lr 7e-4 \ 10 | --matcher_giou_cost 3 \ 11 | --matcher_cls_cost 1 \ 12 | --matcher_center_cost 5 \ 13 | --matcher_objectness_cost 5 \ 14 | --loss_giou_weight 0 \ 15 | --loss_sem_cls_weight 0.8 \ 16 | --loss_no_object_weight 0.1 \ 17 | --save_separate_checkpoint_every_epoch -1 \ 18 | --checkpoint_dir outputs/sunrgbd_masked_ep1080 19 | -------------------------------------------------------------------------------- /scripts/sunrgbd_quick.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
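# Short 90-epoch SUN RGB-D run for quick sanity checks; matcher and loss
# settings match the full sunrgbd_ep1080.sh schedule.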
3 | 4 | python main.py \ 5 | --dataset_name sunrgbd \ 6 | --max_epoch 90 \ 7 | --nqueries 128 \ 8 | --base_lr 7e-4 \ 9 | --matcher_giou_cost 3 \ 10 | --matcher_cls_cost 1 \ 11 | --matcher_center_cost 5 \ 12 | --matcher_objectness_cost 5 \ 13 | --loss_giou_weight 0 \ 14 | --loss_no_object_weight 0.1 \ 15 | --save_separate_checkpoint_every_epoch -1 \ 16 | --checkpoint_dir outputs/sunrgbd_quick 17 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | #include 5 | 6 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 7 | const int nsample); 8 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #ifndef _CUDA_UTILS_H 4 | #define _CUDA_UTILS_H 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #define TOTAL_THREADS 512 16 | 17 | inline int opt_n_threads(int work_size) { 18 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 19 | 20 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 21 | } 22 | 23 | inline dim3 opt_block_config(int x, int y) { 24 | const int x_threads = opt_n_threads(x); 25 | const int y_threads = 26 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 27 | dim3 block_config(x_threads, y_threads, 1); 28 | 29 | return block_config; 30 | } 31 | 32 | #define CUDA_CHECK_ERRORS() \ 33 | do { \ 34 | cudaError_t err = cudaGetLastError(); \ 35 | if (cudaSuccess != err) { \ 36 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 37 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 38 | __FILE__); \ 39 | exit(-1); \ 40 | } \ 41 | } while (0) 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | #pragma once 5 | #include 6 | 7 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 8 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 9 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #pragma once 4 | 5 | #include 6 | #include 7 | 8 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 9 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 10 | at::Tensor weight); 11 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 12 | at::Tensor weight, const int m); 13 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
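// Declarations for the CUDA sampling ops exposed to PyTorch. Indicative shapes,
// following the kernel comments in sampling_gpu.cu:
//   gather_points(points (B, C, N), idx (B, M))          -> (B, C, M)
//   gather_points_grad(grad_out (B, C, M), idx, N)       -> (B, C, N)
//   furthest_point_sampling(points (B, N, 3), nsamples)  -> int32 indices (B, nsamples)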
2 | 3 | 4 | #pragma once 5 | #include 6 | 7 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 8 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 9 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 10 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | #pragma once 5 | #include 6 | #include 7 | 8 | #define CHECK_CUDA(x) \ 9 | do { \ 10 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 11 | } while (0) 12 | 13 | #define CHECK_CONTIGUOUS(x) \ 14 | do { \ 15 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 16 | } while (0) 17 | 18 | #define CHECK_IS_INT(x) \ 19 | do { \ 20 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 21 | #x " must be an int tensor"); \ 22 | } while (0) 23 | 24 | #define CHECK_IS_FLOAT(x) \ 25 | do { \ 26 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 27 | #x " must be a float tensor"); \ 28 | } while (0) 29 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | #include "ball_query.h" 5 | #include "utils.h" 6 | 7 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 8 | int nsample, const float *new_xyz, 9 | const float *xyz, int *idx); 10 | 11 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 12 | const int nsample) { 13 | CHECK_CONTIGUOUS(new_xyz); 14 | CHECK_CONTIGUOUS(xyz); 15 | CHECK_IS_FLOAT(new_xyz); 16 | CHECK_IS_FLOAT(xyz); 17 | 18 | if (new_xyz.is_cuda()) { 19 | CHECK_CUDA(xyz); 20 | } 21 | 22 | at::Tensor idx = 23 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 24 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 25 | 26 | if (new_xyz.is_cuda()) { 27 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 28 | radius, nsample, new_xyz.data(), 29 | xyz.data(), idx.data()); 30 | } else { 31 | AT_ASSERT(false, "CPU not supported"); 32 | } 33 | 34 | return idx; 35 | } 36 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
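// ball_query fills idx with, for each query point in new_xyz, the indices of up
// to `nsample` points of xyz lying within `radius`. When fewer than nsample
// neighbours exist, the remaining slots repeat the first neighbour found
// (the cnt == 0 fill in the kernel below), so idx is always fully populated.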
2 | 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cuda_utils.h" 9 | 10 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 11 | // output: idx(b, m, nsample) 12 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 13 | int nsample, 14 | const float *__restrict__ new_xyz, 15 | const float *__restrict__ xyz, 16 | int *__restrict__ idx) { 17 | int batch_index = blockIdx.x; 18 | xyz += batch_index * n * 3; 19 | new_xyz += batch_index * m * 3; 20 | idx += m * nsample * batch_index; 21 | 22 | int index = threadIdx.x; 23 | int stride = blockDim.x; 24 | 25 | float radius2 = radius * radius; 26 | for (int j = index; j < m; j += stride) { 27 | float new_x = new_xyz[j * 3 + 0]; 28 | float new_y = new_xyz[j * 3 + 1]; 29 | float new_z = new_xyz[j * 3 + 2]; 30 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 31 | float x = xyz[k * 3 + 0]; 32 | float y = xyz[k * 3 + 1]; 33 | float z = xyz[k * 3 + 2]; 34 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 35 | (new_z - z) * (new_z - z); 36 | if (d2 < radius2) { 37 | if (cnt == 0) { 38 | for (int l = 0; l < nsample; ++l) { 39 | idx[j * nsample + l] = k; 40 | } 41 | } 42 | idx[j * nsample + cnt] = k; 43 | ++cnt; 44 | } 45 | } 46 | } 47 | } 48 | 49 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 50 | int nsample, const float *new_xyz, 51 | const float *xyz, int *idx) { 52 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 53 | query_ball_point_kernel<<>>( 54 | b, n, m, radius, nsample, new_xyz, xyz, idx); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | #include "ball_query.h" 5 | #include "group_points.h" 6 | #include "interpolate.h" 7 | #include "sampling.h" 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("gather_points", &gather_points); 11 | m.def("gather_points_grad", &gather_points_grad); 12 | m.def("furthest_point_sampling", &furthest_point_sampling); 13 | 14 | m.def("three_nn", &three_nn); 15 | m.def("three_interpolate", &three_interpolate); 16 | m.def("three_interpolate_grad", &three_interpolate_grad); 17 | 18 | m.def("ball_query", &ball_query); 19 | 20 | m.def("group_points", &group_points); 21 | m.def("group_points_grad", &group_points_grad); 22 | } 23 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
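// Host wrappers for the grouping ops. group_points gathers neighbourhood features:
// points (B, C, N) indexed by idx (B, npoints, nsample) -> out (B, C, npoints, nsample);
// group_points_grad scatters gradients back to (B, C, N). Shapes follow the kernel
// comments in group_points_gpu.cu.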
2 | 3 | 4 | #include "group_points.h" 5 | #include "utils.h" 6 | 7 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 8 | const float *points, const int *idx, 9 | float *out); 10 | 11 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 12 | int nsample, const float *grad_out, 13 | const int *idx, float *grad_points); 14 | 15 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), idx.size(2), points.data(), 32 | idx.data(), output.data()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 41 | CHECK_CONTIGUOUS(grad_out); 42 | CHECK_CONTIGUOUS(idx); 43 | CHECK_IS_FLOAT(grad_out); 44 | CHECK_IS_INT(idx); 45 | 46 | if (grad_out.is_cuda()) { 47 | CHECK_CUDA(idx); 48 | } 49 | 50 | at::Tensor output = 51 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 52 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 53 | 54 | if (grad_out.is_cuda()) { 55 | group_points_grad_kernel_wrapper( 56 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 57 | grad_out.data(), idx.data(), output.data()); 58 | } else { 59 | AT_ASSERT(false, "CPU not supported"); 60 | } 61 | 62 | return output; 63 | } 64 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | // input: points(b, c, n) idx(b, npoints, nsample) 10 | // output: out(b, c, npoints, nsample) 11 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 12 | int nsample, 13 | const float *__restrict__ points, 14 | const int *__restrict__ idx, 15 | float *__restrict__ out) { 16 | int batch_index = blockIdx.x; 17 | points += batch_index * n * c; 18 | idx += batch_index * npoints * nsample; 19 | out += batch_index * npoints * nsample * c; 20 | 21 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 22 | const int stride = blockDim.y * blockDim.x; 23 | for (int i = index; i < c * npoints; i += stride) { 24 | const int l = i / npoints; 25 | const int j = i % npoints; 26 | for (int k = 0; k < nsample; ++k) { 27 | int ii = idx[j * nsample + k]; 28 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 29 | } 30 | } 31 | } 32 | 33 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 34 | const float *points, const int *idx, 35 | float *out) { 36 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 37 | 38 | group_points_kernel<<>>( 39 | b, c, n, npoints, nsample, points, idx, out); 40 | 41 | CUDA_CHECK_ERRORS(); 42 | } 43 | 44 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 45 | // output: grad_points(b, c, n) 46 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 47 | int nsample, 48 | const float *__restrict__ grad_out, 49 | const int *__restrict__ idx, 50 | float *__restrict__ grad_points) { 51 | int batch_index = blockIdx.x; 52 | grad_out += batch_index * npoints * nsample * c; 53 | idx += batch_index * npoints * nsample; 54 | grad_points += batch_index * n * c; 55 | 56 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 57 | const int stride = blockDim.y * blockDim.x; 58 | for (int i = index; i < c * npoints; i += stride) { 59 | const int l = i / npoints; 60 | const int j = i % npoints; 61 | for (int k = 0; k < nsample; ++k) { 62 | int ii = idx[j * nsample + k]; 63 | atomicAdd(grad_points + l * n + ii, 64 | grad_out[(l * npoints + j) * nsample + k]); 65 | } 66 | } 67 | } 68 | 69 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 70 | int nsample, const float *grad_out, 71 | const int *idx, float *grad_points) { 72 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 73 | 74 | group_points_grad_kernel<<>>( 75 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 76 | 77 | CUDA_CHECK_ERRORS(); 78 | } 79 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
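// Host wrappers for three-nearest-neighbour interpolation. three_nn(unknowns (B, n, 3),
// knowns (B, m, 3)) returns squared distances and indices of the 3 closest known points,
// each (B, n, 3); three_interpolate(points (B, C, m), idx (B, n, 3), weight (B, n, 3))
// -> (B, C, n). Shapes follow the kernel comments in interpolate_gpu.cu.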
2 | 3 | #include "interpolate.h" 4 | #include "utils.h" 5 | 6 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 7 | const float *known, float *dist2, int *idx); 8 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 9 | const float *points, const int *idx, 10 | const float *weight, float *out); 11 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 12 | const float *grad_out, 13 | const int *idx, const float *weight, 14 | float *grad_points); 15 | 16 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 17 | CHECK_CONTIGUOUS(unknowns); 18 | CHECK_CONTIGUOUS(knows); 19 | CHECK_IS_FLOAT(unknowns); 20 | CHECK_IS_FLOAT(knows); 21 | 22 | if (unknowns.is_cuda()) { 23 | CHECK_CUDA(knows); 24 | } 25 | 26 | at::Tensor idx = 27 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 28 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 29 | at::Tensor dist2 = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 32 | 33 | if (unknowns.is_cuda()) { 34 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 35 | unknowns.data(), knows.data(), 36 | dist2.data(), idx.data()); 37 | } else { 38 | AT_ASSERT(false, "CPU not supported"); 39 | } 40 | 41 | return {dist2, idx}; 42 | } 43 | 44 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 45 | at::Tensor weight) { 46 | CHECK_CONTIGUOUS(points); 47 | CHECK_CONTIGUOUS(idx); 48 | CHECK_CONTIGUOUS(weight); 49 | CHECK_IS_FLOAT(points); 50 | CHECK_IS_INT(idx); 51 | CHECK_IS_FLOAT(weight); 52 | 53 | if (points.is_cuda()) { 54 | CHECK_CUDA(idx); 55 | CHECK_CUDA(weight); 56 | } 57 | 58 | at::Tensor output = 59 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 60 | at::device(points.device()).dtype(at::ScalarType::Float)); 61 | 62 | if (points.is_cuda()) { 63 | three_interpolate_kernel_wrapper( 64 | points.size(0), points.size(1), points.size(2), idx.size(1), 65 | points.data(), idx.data(), weight.data(), 66 | output.data()); 67 | } else { 68 | AT_ASSERT(false, "CPU not supported"); 69 | } 70 | 71 | return output; 72 | } 73 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 74 | at::Tensor weight, const int m) { 75 | CHECK_CONTIGUOUS(grad_out); 76 | CHECK_CONTIGUOUS(idx); 77 | CHECK_CONTIGUOUS(weight); 78 | CHECK_IS_FLOAT(grad_out); 79 | CHECK_IS_INT(idx); 80 | CHECK_IS_FLOAT(weight); 81 | 82 | if (grad_out.is_cuda()) { 83 | CHECK_CUDA(idx); 84 | CHECK_CUDA(weight); 85 | } 86 | 87 | at::Tensor output = 88 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 89 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 90 | 91 | if (grad_out.is_cuda()) { 92 | three_interpolate_grad_kernel_wrapper( 93 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 94 | grad_out.data(), idx.data(), weight.data(), 95 | output.data()); 96 | } else { 97 | AT_ASSERT(false, "CPU not supported"); 98 | } 99 | 100 | return output; 101 | } 102 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cuda_utils.h" 9 | 10 | // input: unknown(b, n, 3) known(b, m, 3) 11 | // output: dist2(b, n, 3), idx(b, n, 3) 12 | __global__ void three_nn_kernel(int b, int n, int m, 13 | const float *__restrict__ unknown, 14 | const float *__restrict__ known, 15 | float *__restrict__ dist2, 16 | int *__restrict__ idx) { 17 | int batch_index = blockIdx.x; 18 | unknown += batch_index * n * 3; 19 | known += batch_index * m * 3; 20 | dist2 += batch_index * n * 3; 21 | idx += batch_index * n * 3; 22 | 23 | int index = threadIdx.x; 24 | int stride = blockDim.x; 25 | for (int j = index; j < n; j += stride) { 26 | float ux = unknown[j * 3 + 0]; 27 | float uy = unknown[j * 3 + 1]; 28 | float uz = unknown[j * 3 + 2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; 39 | besti3 = besti2; 40 | best2 = best1; 41 | besti2 = besti1; 42 | best1 = d; 43 | besti1 = k; 44 | } else if (d < best2) { 45 | best3 = best2; 46 | besti3 = besti2; 47 | best2 = d; 48 | besti2 = k; 49 | } else if (d < best3) { 50 | best3 = d; 51 | besti3 = k; 52 | } 53 | } 54 | dist2[j * 3 + 0] = best1; 55 | dist2[j * 3 + 1] = best2; 56 | dist2[j * 3 + 2] = best3; 57 | 58 | idx[j * 3 + 0] = besti1; 59 | idx[j * 3 + 1] = besti2; 60 | idx[j * 3 + 2] = besti3; 61 | } 62 | } 63 | 64 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 65 | const float *known, float *dist2, int *idx) { 66 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 67 | three_nn_kernel<<>>(b, n, m, unknown, known, 68 | dist2, idx); 69 | 70 | CUDA_CHECK_ERRORS(); 71 | } 72 | 73 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 74 | // output: out(b, c, n) 75 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 76 | const float *__restrict__ points, 77 | const int *__restrict__ idx, 78 | const float *__restrict__ weight, 79 | float *__restrict__ out) { 80 | int batch_index = blockIdx.x; 81 | points += batch_index * m * c; 82 | 83 | idx += batch_index * n * 3; 84 | weight += batch_index * n * 3; 85 | 86 | out += batch_index * n * c; 87 | 88 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 89 | const int stride = blockDim.y * blockDim.x; 90 | for (int i = index; i < c * n; i += stride) { 91 | const int l = i / n; 92 | const int j = i % n; 93 | float w1 = weight[j * 3 + 0]; 94 | float w2 = weight[j * 3 + 1]; 95 | float w3 = weight[j * 3 + 2]; 96 | 97 | int i1 = idx[j * 3 + 0]; 98 | int i2 = idx[j * 3 + 1]; 99 | int i3 = idx[j * 3 + 2]; 100 | 101 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 102 | points[l * m + i3] * w3; 103 | } 104 | } 105 | 106 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 107 | const float *points, const int *idx, 108 | const float *weight, float *out) { 109 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 110 | three_interpolate_kernel<<>>( 111 | b, c, m, n, points, idx, weight, out); 112 | 113 | CUDA_CHECK_ERRORS(); 114 | } 115 | 116 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 117 | // output: grad_points(b, c, m) 118 | 119 | __global__ void three_interpolate_grad_kernel( 120 | int b, int c, int n, int m, const float *__restrict__ grad_out, 121 | const int 
*__restrict__ idx, const float *__restrict__ weight, 122 | float *__restrict__ grad_points) { 123 | int batch_index = blockIdx.x; 124 | grad_out += batch_index * n * c; 125 | idx += batch_index * n * 3; 126 | weight += batch_index * n * 3; 127 | grad_points += batch_index * m * c; 128 | 129 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 130 | const int stride = blockDim.y * blockDim.x; 131 | for (int i = index; i < c * n; i += stride) { 132 | const int l = i / n; 133 | const int j = i % n; 134 | float w1 = weight[j * 3 + 0]; 135 | float w2 = weight[j * 3 + 1]; 136 | float w3 = weight[j * 3 + 2]; 137 | 138 | int i1 = idx[j * 3 + 0]; 139 | int i2 = idx[j * 3 + 1]; 140 | int i3 = idx[j * 3 + 2]; 141 | 142 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 143 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 144 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 145 | } 146 | } 147 | 148 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 149 | const float *grad_out, 150 | const int *idx, const float *weight, 151 | float *grad_points) { 152 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 153 | three_interpolate_grad_kernel<<>>( 154 | b, c, n, m, grad_out, idx, weight, grad_points); 155 | 156 | CUDA_CHECK_ERRORS(); 157 | } 158 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | #include "sampling.h" 4 | #include "utils.h" 5 | 6 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 7 | const float *points, const int *idx, 8 | float *out); 9 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *grad_out, const int *idx, 11 | float *grad_points); 12 | 13 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 14 | const float *dataset, float *temp, 15 | int *idxs); 16 | 17 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.is_cuda()) { 32 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 43 | const int n) { 44 | CHECK_CONTIGUOUS(grad_out); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_IS_FLOAT(grad_out); 47 | CHECK_IS_INT(idx); 48 | 49 | if (grad_out.is_cuda()) { 50 | CHECK_CUDA(idx); 51 | } 52 | 53 | at::Tensor output = 54 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 55 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 56 | 57 | if (grad_out.is_cuda()) { 58 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 59 | idx.size(1), grad_out.data(), 60 | idx.data(), output.data()); 61 | } else { 62 | AT_ASSERT(false, "CPU not supported"); 63 | } 64 | 65 | return output; 66 | } 67 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 68 | CHECK_CONTIGUOUS(points); 69 
| CHECK_IS_FLOAT(points); 70 | 71 | at::Tensor output = 72 | torch::zeros({points.size(0), nsamples}, 73 | at::device(points.device()).dtype(at::ScalarType::Int)); 74 | 75 | at::Tensor tmp = 76 | torch::full({points.size(0), points.size(1)}, 1e10, 77 | at::device(points.device()).dtype(at::ScalarType::Float)); 78 | 79 | if (points.is_cuda()) { 80 | furthest_point_sampling_kernel_wrapper( 81 | points.size(0), points.size(1), nsamples, points.data(), 82 | tmp.data(), output.data()); 83 | } else { 84 | AT_ASSERT(false, "CPU not supported"); 85 | } 86 | 87 | return output; 88 | } 89 | -------------------------------------------------------------------------------- /third_party/pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | // input: points(b, c, n) idx(b, m) 10 | // output: out(b, c, m) 11 | __global__ void gather_points_kernel(int b, int c, int n, int m, 12 | const float *__restrict__ points, 13 | const int *__restrict__ idx, 14 | float *__restrict__ out) { 15 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 16 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 17 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 18 | int a = idx[i * m + j]; 19 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 20 | } 21 | } 22 | } 23 | } 24 | 25 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 26 | const float *points, const int *idx, 27 | float *out) { 28 | gather_points_kernel<<>>(b, c, n, npoints, 30 | points, idx, out); 31 | 32 | CUDA_CHECK_ERRORS(); 33 | } 34 | 35 | // input: grad_out(b, c, m) idx(b, m) 36 | // output: grad_points(b, c, n) 37 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 38 | const float *__restrict__ grad_out, 39 | const int *__restrict__ idx, 40 | float *__restrict__ grad_points) { 41 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 42 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 43 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 44 | int a = idx[i * m + j]; 45 | atomicAdd(grad_points + (i * c + l) * n + a, 46 | grad_out[(i * c + l) * m + j]); 47 | } 48 | } 49 | } 50 | } 51 | 52 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 53 | const float *grad_out, const int *idx, 54 | float *grad_points) { 55 | gather_points_grad_kernel<<>>( 57 | b, c, n, npoints, grad_out, idx, grad_points); 58 | 59 | CUDA_CHECK_ERRORS(); 60 | } 61 | 62 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 63 | int idx1, int idx2) { 64 | const float v1 = dists[idx1], v2 = dists[idx2]; 65 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 66 | dists[idx1] = max(v1, v2); 67 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 68 | } 69 | 70 | // Input dataset: (b, n, 3), tmp: (b, n) 71 | // Ouput idxs (b, m) 72 | template 73 | __global__ void furthest_point_sampling_kernel( 74 | int b, int n, int m, const float *__restrict__ dataset, 75 | float *__restrict__ temp, int *__restrict__ idxs) { 76 | if (m <= 0) return; 77 | __shared__ float dists[block_size]; 78 | __shared__ int dists_i[block_size]; 79 | 80 | int batch_index = blockIdx.x; 81 | dataset += batch_index * n * 3; 82 | temp += batch_index * n; 83 | idxs += batch_index * m; 84 | 85 | int tid = threadIdx.x; 86 | const int stride = block_size; 87 | 88 | int old = 0; 89 | if (threadIdx.x == 0) idxs[0] = old; 90 | 91 | __syncthreads(); 92 | for (int j = 1; j < m; j++) { 93 | int besti = 0; 94 | float best = -1; 95 | float x1 = dataset[old * 3 + 0]; 96 | float y1 = dataset[old * 3 + 1]; 97 | float z1 = dataset[old * 3 + 2]; 98 | for (int k = tid; k < n; k += stride) { 99 | float x2, y2, z2; 100 | x2 = dataset[k * 3 + 0]; 101 | y2 = dataset[k * 3 + 1]; 102 | z2 = dataset[k * 3 + 2]; 103 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 104 | if (mag <= 1e-3) continue; 105 | 106 | float d = 107 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 108 | 109 | float d2 = min(d, temp[k]); 110 | temp[k] = d2; 111 | besti = d2 > best ? k : besti; 112 | best = d2 > best ? d2 : best; 113 | } 114 | dists[tid] = best; 115 | dists_i[tid] = besti; 116 | __syncthreads(); 117 | 118 | if (block_size >= 512) { 119 | if (tid < 256) { 120 | __update(dists, dists_i, tid, tid + 256); 121 | } 122 | __syncthreads(); 123 | } 124 | if (block_size >= 256) { 125 | if (tid < 128) { 126 | __update(dists, dists_i, tid, tid + 128); 127 | } 128 | __syncthreads(); 129 | } 130 | if (block_size >= 128) { 131 | if (tid < 64) { 132 | __update(dists, dists_i, tid, tid + 64); 133 | } 134 | __syncthreads(); 135 | } 136 | if (block_size >= 64) { 137 | if (tid < 32) { 138 | __update(dists, dists_i, tid, tid + 32); 139 | } 140 | __syncthreads(); 141 | } 142 | if (block_size >= 32) { 143 | if (tid < 16) { 144 | __update(dists, dists_i, tid, tid + 16); 145 | } 146 | __syncthreads(); 147 | } 148 | if (block_size >= 16) { 149 | if (tid < 8) { 150 | __update(dists, dists_i, tid, tid + 8); 151 | } 152 | __syncthreads(); 153 | } 154 | if (block_size >= 8) { 155 | if (tid < 4) { 156 | __update(dists, dists_i, tid, tid + 4); 157 | } 158 | __syncthreads(); 159 | } 160 | if (block_size >= 4) { 161 | if (tid < 2) { 162 | __update(dists, dists_i, tid, tid + 2); 163 | } 164 | __syncthreads(); 165 | } 166 | if (block_size >= 2) { 167 | if (tid < 1) { 168 | __update(dists, dists_i, tid, tid + 1); 169 | } 170 | __syncthreads(); 171 | } 172 | 173 | old = dists_i[0]; 174 | if (tid == 0) idxs[j] = old; 175 | } 176 | } 177 | 178 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 179 | const float *dataset, float *temp, 180 | int *idxs) { 181 | unsigned int n_threads = opt_n_threads(n); 182 | 183 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 184 | 185 | switch (n_threads) { 186 | case 512: 187 | furthest_point_sampling_kernel<512> 188 | <<>>(b, n, m, dataset, temp, idxs); 189 | break; 190 | case 256: 191 | furthest_point_sampling_kernel<256> 192 | <<>>(b, n, m, dataset, temp, idxs); 193 | break; 194 | case 128: 195 | furthest_point_sampling_kernel<128> 196 | <<>>(b, n, m, dataset, temp, idxs); 197 | break; 198 | case 64: 199 | furthest_point_sampling_kernel<64> 200 | <<>>(b, n, m, dataset, temp, idxs); 201 | break; 202 | case 32: 203 | furthest_point_sampling_kernel<32> 
204 | <<>>(b, n, m, dataset, temp, idxs); 205 | break; 206 | case 16: 207 | furthest_point_sampling_kernel<16> 208 | <<>>(b, n, m, dataset, temp, idxs); 209 | break; 210 | case 8: 211 | furthest_point_sampling_kernel<8> 212 | <<>>(b, n, m, dataset, temp, idxs); 213 | break; 214 | case 4: 215 | furthest_point_sampling_kernel<4> 216 | <<>>(b, n, m, dataset, temp, idxs); 217 | break; 218 | case 2: 219 | furthest_point_sampling_kernel<2> 220 | <<>>(b, n, m, dataset, temp, idxs); 221 | break; 222 | case 1: 223 | furthest_point_sampling_kernel<1> 224 | <<>>(b, n, m, dataset, temp, idxs); 225 | break; 226 | default: 227 | furthest_point_sampling_kernel<512> 228 | <<>>(b, n, m, dataset, temp, idxs); 229 | } 230 | 231 | CUDA_CHECK_ERRORS(); 232 | } 233 | -------------------------------------------------------------------------------- /third_party/pointnet2/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | ''' Testing customized ops. ''' 4 | 5 | import torch 6 | from torch.autograd import gradcheck 7 | import numpy as np 8 | 9 | import os 10 | import sys 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | import pointnet2_utils 14 | 15 | def test_interpolation_grad(): 16 | batch_size = 1 17 | feat_dim = 2 18 | m = 4 19 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 20 | 21 | def interpolate_func(inputs): 22 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 23 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 24 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 25 | return interpolated_feats 26 | 27 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 28 | 29 | if __name__=='__main__': 30 | test_interpolation_grad() 31 | -------------------------------------------------------------------------------- /third_party/pointnet2/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
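The wrappers below expose the compiled CUDA ops as autograd Functions. A minimal usage sketch for the sampling path, assuming the pointnet2._ext extension has been built and a GPU is available (shapes follow the docstrings below):

import torch
import pointnet2_utils

xyz = torch.rand(2, 1024, 3).cuda()                        # (B, N, 3) point coordinates
feats = torch.rand(2, 32, 1024).cuda()                     # (B, C, N) per-point features

inds = pointnet2_utils.furthest_point_sample(xyz, 256)     # (B, 256) indices of FPS centroids
new_xyz = pointnet2_utils.gather_operation(
    xyz.transpose(1, 2).contiguous(), inds                 # gather expects channels-first (B, 3, N)
).transpose(1, 2).contiguous()                             # (B, 256, 3) sampled coordinates
new_feats = pointnet2_utils.gather_operation(feats, inds)  # (B, 32, 256) gathered features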
2 | 3 | ''' Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 4 | from __future__ import ( 5 | division, 6 | absolute_import, 7 | with_statement, 8 | print_function, 9 | unicode_literals, 10 | ) 11 | import torch 12 | from torch.autograd import Function 13 | import torch.nn as nn 14 | import pytorch_utils as pt_utils 15 | import sys 16 | 17 | try: 18 | import builtins 19 | except: 20 | import __builtin__ as builtins 21 | 22 | try: 23 | import pointnet2._ext as _ext 24 | except ImportError: 25 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 26 | raise ImportError( 27 | "Could not import _ext module.\n" 28 | "Please see the setup instructions in the README: " 29 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 30 | ) 31 | 32 | if False: 33 | # Workaround for type hints without depending on the `typing` module 34 | from typing import * 35 | 36 | 37 | class RandomDropout(nn.Module): 38 | def __init__(self, p=0.5, inplace=False): 39 | super(RandomDropout, self).__init__() 40 | self.p = p 41 | self.inplace = inplace 42 | 43 | def forward(self, X): 44 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 45 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 46 | 47 | 48 | class FurthestPointSampling(Function): 49 | @staticmethod 50 | def forward(ctx, xyz, npoint): 51 | # type: (Any, torch.Tensor, int) -> torch.Tensor 52 | r""" 53 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 54 | minimum distance 55 | 56 | Parameters 57 | ---------- 58 | xyz : torch.Tensor 59 | (B, N, 3) tensor where N > npoint 60 | npoint : int32 61 | number of features in the sampled set 62 | 63 | Returns 64 | ------- 65 | torch.Tensor 66 | (B, npoint) tensor containing the set 67 | """ 68 | fps_inds = _ext.furthest_point_sampling(xyz, npoint) 69 | ctx.mark_non_differentiable(fps_inds) 70 | return fps_inds 71 | 72 | @staticmethod 73 | def backward(xyz, a=None): 74 | return None, None 75 | 76 | 77 | furthest_point_sample = FurthestPointSampling.apply 78 | 79 | 80 | class GatherOperation(Function): 81 | @staticmethod 82 | def forward(ctx, features, idx): 83 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 84 | r""" 85 | 86 | Parameters 87 | ---------- 88 | features : torch.Tensor 89 | (B, C, N) tensor 90 | 91 | idx : torch.Tensor 92 | (B, npoint) tensor of the features to gather 93 | 94 | Returns 95 | ------- 96 | torch.Tensor 97 | (B, C, npoint) tensor 98 | """ 99 | 100 | _, C, N = features.size() 101 | 102 | ctx.for_backwards = (idx, C, N) 103 | 104 | return _ext.gather_points(features, idx) 105 | 106 | @staticmethod 107 | def backward(ctx, grad_out): 108 | idx, C, N = ctx.for_backwards 109 | 110 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 111 | return grad_features, None 112 | 113 | 114 | gather_operation = GatherOperation.apply 115 | 116 | 117 | class ThreeNN(Function): 118 | @staticmethod 119 | def forward(ctx, unknown, known): 120 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 121 | r""" 122 | Find the three nearest neighbors of unknown in known 123 | Parameters 124 | ---------- 125 | unknown : torch.Tensor 126 | (B, n, 3) tensor of known features 127 | known : torch.Tensor 128 | (B, m, 3) tensor of unknown features 129 | 130 | Returns 131 | ------- 132 | dist : torch.Tensor 133 | (B, n, 3) l2 distance to the three nearest neighbors 134 | idx : torch.Tensor 135 | (B, n, 3) index of 3 nearest neighbors 136 | """ 137 | 
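        # _ext.three_nn returns, for each point in `unknown`, the squared distances to and
        # indices of its three nearest neighbours in `known`; the sqrt is taken just below.
        # Callers (e.g. a PointNet++ feature-propagation step) typically turn the returned
        # distances into normalized inverse-distance weights, roughly:
        #   dist, idx = three_nn(unknown, known)
        #   w = 1.0 / (dist + 1e-8); w = w / w.sum(dim=2, keepdim=True)
        #   upsampled = three_interpolate(known_feats, idx, w)
        # (illustrative only -- the exact weighting lives in the calling model code)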
dist2, idx = _ext.three_nn(unknown, known) 138 | 139 | return torch.sqrt(dist2), idx 140 | 141 | @staticmethod 142 | def backward(ctx, a=None, b=None): 143 | return None, None 144 | 145 | 146 | three_nn = ThreeNN.apply 147 | 148 | 149 | class ThreeInterpolate(Function): 150 | @staticmethod 151 | def forward(ctx, features, idx, weight): 152 | # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 153 | r""" 154 | Performs weight linear interpolation on 3 features 155 | Parameters 156 | ---------- 157 | features : torch.Tensor 158 | (B, c, m) Features descriptors to be interpolated from 159 | idx : torch.Tensor 160 | (B, n, 3) three nearest neighbors of the target features in features 161 | weight : torch.Tensor 162 | (B, n, 3) weights 163 | 164 | Returns 165 | ------- 166 | torch.Tensor 167 | (B, c, n) tensor of the interpolated features 168 | """ 169 | B, c, m = features.size() 170 | n = idx.size(1) 171 | 172 | ctx.three_interpolate_for_backward = (idx, weight, m) 173 | 174 | return _ext.three_interpolate(features, idx, weight) 175 | 176 | @staticmethod 177 | def backward(ctx, grad_out): 178 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 179 | r""" 180 | Parameters 181 | ---------- 182 | grad_out : torch.Tensor 183 | (B, c, n) tensor with gradients of ouputs 184 | 185 | Returns 186 | ------- 187 | grad_features : torch.Tensor 188 | (B, c, m) tensor with gradients of features 189 | 190 | None 191 | 192 | None 193 | """ 194 | idx, weight, m = ctx.three_interpolate_for_backward 195 | 196 | grad_features = _ext.three_interpolate_grad( 197 | grad_out.contiguous(), idx, weight, m 198 | ) 199 | 200 | return grad_features, None, None 201 | 202 | 203 | three_interpolate = ThreeInterpolate.apply 204 | 205 | 206 | class GroupingOperation(Function): 207 | @staticmethod 208 | def forward(ctx, features, idx): 209 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 210 | r""" 211 | 212 | Parameters 213 | ---------- 214 | features : torch.Tensor 215 | (B, C, N) tensor of features to group 216 | idx : torch.Tensor 217 | (B, npoint, nsample) tensor containing the indicies of features to group with 218 | 219 | Returns 220 | ------- 221 | torch.Tensor 222 | (B, C, npoint, nsample) tensor 223 | """ 224 | B, nfeatures, nsample = idx.size() 225 | _, C, N = features.size() 226 | 227 | ctx.for_backwards = (idx, N) 228 | 229 | return _ext.group_points(features, idx) 230 | 231 | @staticmethod 232 | def backward(ctx, grad_out): 233 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 234 | r""" 235 | 236 | Parameters 237 | ---------- 238 | grad_out : torch.Tensor 239 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 240 | 241 | Returns 242 | ------- 243 | torch.Tensor 244 | (B, C, N) gradient of the features 245 | None 246 | """ 247 | idx, N = ctx.for_backwards 248 | 249 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 250 | 251 | return grad_features, None 252 | 253 | 254 | grouping_operation = GroupingOperation.apply 255 | 256 | 257 | class BallQuery(Function): 258 | @staticmethod 259 | def forward(ctx, radius, nsample, xyz, new_xyz): 260 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 261 | r""" 262 | 263 | Parameters 264 | ---------- 265 | radius : float 266 | radius of the balls 267 | nsample : int 268 | maximum number of features in the balls 269 | xyz : torch.Tensor 270 | (B, N, 3) xyz coordinates of the features 271 | new_xyz : torch.Tensor 272 | (B, npoint, 3) 
centers of the ball query 273 | 274 | Returns 275 | ------- 276 | torch.Tensor 277 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 278 | """ 279 | inds = _ext.ball_query(new_xyz, xyz, radius, nsample) 280 | ctx.mark_non_differentiable(inds) 281 | return inds 282 | 283 | @staticmethod 284 | def backward(ctx, a=None): 285 | return None, None, None, None 286 | 287 | 288 | ball_query = BallQuery.apply 289 | 290 | 291 | class QueryAndGroup(nn.Module): 292 | r""" 293 | Groups with a ball query of radius 294 | 295 | Parameters 296 | --------- 297 | radius : float32 298 | Radius of ball 299 | nsample : int32 300 | Maximum number of features to gather in the ball 301 | """ 302 | 303 | def __init__(self, radius, nsample, use_xyz=True, ret_grouped_xyz=False, normalize_xyz=False, sample_uniformly=False, ret_unique_cnt=False): 304 | # type: (QueryAndGroup, float, int, bool) -> None 305 | super(QueryAndGroup, self).__init__() 306 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 307 | self.ret_grouped_xyz = ret_grouped_xyz 308 | self.normalize_xyz = normalize_xyz 309 | self.sample_uniformly = sample_uniformly 310 | self.ret_unique_cnt = ret_unique_cnt 311 | if self.ret_unique_cnt: 312 | assert(self.sample_uniformly) 313 | 314 | def forward(self, xyz, new_xyz, features=None): 315 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 316 | r""" 317 | Parameters 318 | ---------- 319 | xyz : torch.Tensor 320 | xyz coordinates of the features (B, N, 3) 321 | new_xyz : torch.Tensor 322 | centriods (B, npoint, 3) 323 | features : torch.Tensor 324 | Descriptors of the features (B, C, N) 325 | 326 | Returns 327 | ------- 328 | new_features : torch.Tensor 329 | (B, 3 + C, npoint, nsample) tensor 330 | """ 331 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 332 | 333 | if self.sample_uniformly: 334 | unique_cnt = torch.zeros((idx.shape[0], idx.shape[1])) 335 | for i_batch in range(idx.shape[0]): 336 | for i_region in range(idx.shape[1]): 337 | unique_ind = torch.unique(idx[i_batch, i_region, :]) 338 | num_unique = unique_ind.shape[0] 339 | unique_cnt[i_batch, i_region] = num_unique 340 | sample_ind = torch.randint(0, num_unique, (self.nsample - num_unique,), dtype=torch.long) 341 | all_ind = torch.cat((unique_ind, unique_ind[sample_ind])) 342 | idx[i_batch, i_region, :] = all_ind 343 | 344 | 345 | xyz_trans = xyz.transpose(1, 2).contiguous() 346 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 347 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 348 | if self.normalize_xyz: 349 | grouped_xyz /= self.radius 350 | 351 | if features is not None: 352 | grouped_features = grouping_operation(features, idx) 353 | if self.use_xyz: 354 | new_features = torch.cat( 355 | [grouped_xyz, grouped_features], dim=1 356 | ) # (B, C + 3, npoint, nsample) 357 | else: 358 | new_features = grouped_features 359 | else: 360 | assert ( 361 | self.use_xyz 362 | ), "Cannot have not features and not use xyz as a feature!" 
363 | new_features = grouped_xyz 364 | 365 | ret = [new_features] 366 | if self.ret_grouped_xyz: 367 | ret.append(grouped_xyz) 368 | if self.ret_unique_cnt: 369 | ret.append(unique_cnt) 370 | if len(ret) == 1: 371 | return ret[0] 372 | else: 373 | return tuple(ret) 374 | 375 | 376 | class GroupAll(nn.Module): 377 | r""" 378 | Groups all features 379 | 380 | Parameters 381 | --------- 382 | """ 383 | 384 | def __init__(self, use_xyz=True, ret_grouped_xyz=False): 385 | # type: (GroupAll, bool) -> None 386 | super(GroupAll, self).__init__() 387 | self.use_xyz = use_xyz 388 | 389 | def forward(self, xyz, new_xyz, features=None): 390 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 391 | r""" 392 | Parameters 393 | ---------- 394 | xyz : torch.Tensor 395 | xyz coordinates of the features (B, N, 3) 396 | new_xyz : torch.Tensor 397 | Ignored 398 | features : torch.Tensor 399 | Descriptors of the features (B, C, N) 400 | 401 | Returns 402 | ------- 403 | new_features : torch.Tensor 404 | (B, C + 3, 1, N) tensor 405 | """ 406 | 407 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 408 | if features is not None: 409 | grouped_features = features.unsqueeze(2) 410 | if self.use_xyz: 411 | new_features = torch.cat( 412 | [grouped_xyz, grouped_features], dim=1 413 | ) # (B, 3 + C, 1, N) 414 | else: 415 | new_features = grouped_features 416 | else: 417 | new_features = grouped_xyz 418 | 419 | if self.ret_grouped_xyz: 420 | return new_features, grouped_xyz 421 | else: 422 | return new_features 423 | -------------------------------------------------------------------------------- /third_party/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
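The grouping modules above (ball_query, QueryAndGroup, GroupAll) are the building blocks of a PointNet++ set-abstraction layer. A small illustrative sketch of driving QueryAndGroup directly, assuming the CUDA extension is built; shapes follow its docstring:

import torch
from pointnet2_utils import QueryAndGroup, furthest_point_sample, gather_operation

xyz = torch.rand(2, 1024, 3).cuda()       # (B, N, 3)
feats = torch.rand(2, 32, 1024).cuda()    # (B, C, N)

inds = furthest_point_sample(xyz, 128)    # choose 128 centroids
new_xyz = gather_operation(
    xyz.transpose(1, 2).contiguous(), inds
).transpose(1, 2).contiguous()            # (B, 128, 3)

grouper = QueryAndGroup(radius=0.4, nsample=32, use_xyz=True)
new_features = grouper(xyz, new_xyz, feats)   # (B, 3 + 32, 128, 32)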
2 | 3 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 4 | import torch 5 | import torch.nn as nn 6 | from typing import List, Tuple 7 | 8 | class SharedMLP(nn.Sequential): 9 | 10 | def __init__( 11 | self, 12 | args: List[int], 13 | *, 14 | bn: bool = False, 15 | activation=nn.ReLU(inplace=True), 16 | preact: bool = False, 17 | first: bool = False, 18 | name: str = "" 19 | ): 20 | super().__init__() 21 | 22 | for i in range(len(args) - 1): 23 | self.add_module( 24 | name + 'layer{}'.format(i), 25 | Conv2d( 26 | args[i], 27 | args[i + 1], 28 | bn=(not first or not preact or (i != 0)) and bn, 29 | activation=activation 30 | if (not first or not preact or (i != 0)) else None, 31 | preact=preact 32 | ) 33 | ) 34 | 35 | 36 | class _BNBase(nn.Sequential): 37 | 38 | def __init__(self, in_size, batch_norm=None, name=""): 39 | super().__init__() 40 | self.add_module(name + "bn", batch_norm(in_size)) 41 | 42 | nn.init.constant_(self[0].weight, 1.0) 43 | nn.init.constant_(self[0].bias, 0) 44 | 45 | 46 | class BatchNorm1d(_BNBase): 47 | 48 | def __init__(self, in_size: int, *, name: str = ""): 49 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 50 | 51 | 52 | class BatchNorm2d(_BNBase): 53 | 54 | def __init__(self, in_size: int, name: str = ""): 55 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 56 | 57 | 58 | class BatchNorm3d(_BNBase): 59 | 60 | def __init__(self, in_size: int, name: str = ""): 61 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 62 | 63 | 64 | class _ConvBase(nn.Sequential): 65 | 66 | def __init__( 67 | self, 68 | in_size, 69 | out_size, 70 | kernel_size, 71 | stride, 72 | padding, 73 | activation, 74 | bn, 75 | init, 76 | conv=None, 77 | batch_norm=None, 78 | bias=True, 79 | preact=False, 80 | name="" 81 | ): 82 | super().__init__() 83 | 84 | bias = bias and (not bn) 85 | conv_unit = conv( 86 | in_size, 87 | out_size, 88 | kernel_size=kernel_size, 89 | stride=stride, 90 | padding=padding, 91 | bias=bias 92 | ) 93 | init(conv_unit.weight) 94 | if bias: 95 | nn.init.constant_(conv_unit.bias, 0) 96 | 97 | if bn: 98 | if not preact: 99 | bn_unit = batch_norm(out_size) 100 | else: 101 | bn_unit = batch_norm(in_size) 102 | 103 | if preact: 104 | if bn: 105 | self.add_module(name + 'bn', bn_unit) 106 | 107 | if activation is not None: 108 | self.add_module(name + 'activation', activation) 109 | 110 | self.add_module(name + 'conv', conv_unit) 111 | 112 | if not preact: 113 | if bn: 114 | self.add_module(name + 'bn', bn_unit) 115 | 116 | if activation is not None: 117 | self.add_module(name + 'activation', activation) 118 | 119 | 120 | class Conv1d(_ConvBase): 121 | 122 | def __init__( 123 | self, 124 | in_size: int, 125 | out_size: int, 126 | *, 127 | kernel_size: int = 1, 128 | stride: int = 1, 129 | padding: int = 0, 130 | activation=nn.ReLU(inplace=True), 131 | bn: bool = False, 132 | init=nn.init.kaiming_normal_, 133 | bias: bool = True, 134 | preact: bool = False, 135 | name: str = "" 136 | ): 137 | super().__init__( 138 | in_size, 139 | out_size, 140 | kernel_size, 141 | stride, 142 | padding, 143 | activation, 144 | bn, 145 | init, 146 | conv=nn.Conv1d, 147 | batch_norm=BatchNorm1d, 148 | bias=bias, 149 | preact=preact, 150 | name=name 151 | ) 152 | 153 | 154 | class Conv2d(_ConvBase): 155 | 156 | def __init__( 157 | self, 158 | in_size: int, 159 | out_size: int, 160 | *, 161 | kernel_size: Tuple[int, int] = (1, 1), 162 | stride: Tuple[int, int] = (1, 1), 163 | padding: Tuple[int, int] = (0, 0), 
164 | activation=nn.ReLU(inplace=True), 165 | bn: bool = False, 166 | init=nn.init.kaiming_normal_, 167 | bias: bool = True, 168 | preact: bool = False, 169 | name: str = "" 170 | ): 171 | super().__init__( 172 | in_size, 173 | out_size, 174 | kernel_size, 175 | stride, 176 | padding, 177 | activation, 178 | bn, 179 | init, 180 | conv=nn.Conv2d, 181 | batch_norm=BatchNorm2d, 182 | bias=bias, 183 | preact=preact, 184 | name=name 185 | ) 186 | 187 | 188 | class Conv3d(_ConvBase): 189 | 190 | def __init__( 191 | self, 192 | in_size: int, 193 | out_size: int, 194 | *, 195 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 196 | stride: Tuple[int, int, int] = (1, 1, 1), 197 | padding: Tuple[int, int, int] = (0, 0, 0), 198 | activation=nn.ReLU(inplace=True), 199 | bn: bool = False, 200 | init=nn.init.kaiming_normal_, 201 | bias: bool = True, 202 | preact: bool = False, 203 | name: str = "" 204 | ): 205 | super().__init__( 206 | in_size, 207 | out_size, 208 | kernel_size, 209 | stride, 210 | padding, 211 | activation, 212 | bn, 213 | init, 214 | conv=nn.Conv3d, 215 | batch_norm=BatchNorm3d, 216 | bias=bias, 217 | preact=preact, 218 | name=name 219 | ) 220 | 221 | 222 | class FC(nn.Sequential): 223 | 224 | def __init__( 225 | self, 226 | in_size: int, 227 | out_size: int, 228 | *, 229 | activation=nn.ReLU(inplace=True), 230 | bn: bool = False, 231 | init=None, 232 | preact: bool = False, 233 | name: str = "" 234 | ): 235 | super().__init__() 236 | 237 | fc = nn.Linear(in_size, out_size, bias=not bn) 238 | if init is not None: 239 | init(fc.weight) 240 | if not bn: 241 | nn.init.constant_(fc.bias, 0) 242 | 243 | if preact: 244 | if bn: 245 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 246 | 247 | if activation is not None: 248 | self.add_module(name + 'activation', activation) 249 | 250 | self.add_module(name + 'fc', fc) 251 | 252 | if not preact: 253 | if bn: 254 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 255 | 256 | if activation is not None: 257 | self.add_module(name + 'activation', activation) 258 | 259 | def set_bn_momentum_default(bn_momentum): 260 | 261 | def fn(m): 262 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 263 | m.momentum = bn_momentum 264 | 265 | return fn 266 | 267 | 268 | class BNMomentumScheduler(object): 269 | 270 | def __init__( 271 | self, model, bn_lambda, last_epoch=-1, 272 | setter=set_bn_momentum_default 273 | ): 274 | if not isinstance(model, nn.Module): 275 | raise RuntimeError( 276 | "Class '{}' is not a PyTorch nn Module".format( 277 | type(model).__name__ 278 | ) 279 | ) 280 | 281 | self.model = model 282 | self.setter = setter 283 | self.lmbd = bn_lambda 284 | 285 | self.step(last_epoch + 1) 286 | self.last_epoch = last_epoch 287 | 288 | def step(self, epoch=None): 289 | if epoch is None: 290 | epoch = self.last_epoch + 1 291 | 292 | self.last_epoch = epoch 293 | self.model.apply(self.setter(self.lmbd(epoch))) 294 | 295 | 296 | -------------------------------------------------------------------------------- /third_party/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
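Two of the pytorch_utils building blocks above in a short, illustrative sketch (pure PyTorch, no CUDA extension required): SharedMLP stacks 1x1 Conv2d layers over grouped point features, and BNMomentumScheduler rewrites the momentum of every BatchNorm layer each epoch. The decay schedule below is a hypothetical example, not one taken from this repo:

import torch
from pytorch_utils import SharedMLP, BNMomentumScheduler

mlp = SharedMLP([3 + 32, 64, 128], bn=True)   # two 1x1 convs with BN + ReLU
grouped = torch.rand(4, 3 + 32, 128, 32)      # (B, C, npoint, nsample) grouped features
out = mlp(grouped)                            # (4, 128, 128, 32)

bnm = BNMomentumScheduler(mlp, bn_lambda=lambda e: max(0.5 * 0.5 ** (e // 20), 0.01))
bnm.step(40)                                  # every BN layer now has momentum 0.125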
5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os.path as osp 10 | 11 | this_dir = osp.dirname(osp.abspath(__file__)) 12 | 13 | _ext_src_root = "_ext_src" 14 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 15 | "{}/src/*.cu".format(_ext_src_root) 16 | ) 17 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 18 | 19 | setup( 20 | name='pointnet2', 21 | ext_modules=[ 22 | CUDAExtension( 23 | name='pointnet2._ext', 24 | sources=_ext_sources, 25 | extra_compile_args={ 26 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 27 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 28 | }, 29 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 30 | ) 31 | ], 32 | cmdclass={ 33 | 'build_ext': BuildExtension 34 | } 35 | ) 36 | -------------------------------------------------------------------------------- /utils/box_intersection.pyx: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | cimport numpy as np 4 | cimport cython 5 | cdef bint boolean_variable = True 6 | np.import_array() 7 | 8 | 9 | FLOAT = np.float32 10 | 11 | @cython.boundscheck(False) 12 | @cython.wraparound(False) 13 | def computeIntersection(cp1, cp2, s, e): 14 | dc = [ cp1[0] - cp2[0], cp1[1] - cp2[1] ] 15 | dp = [ s[0] - e[0], s[1] - e[1] ] 16 | n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0] 17 | n2 = s[0] * e[1] - s[1] * e[0] 18 | n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0]) 19 | return [(n1*dp[0] - n2*dc[0]) * n3, (n1*dp[1] - n2*dc[1]) * n3] 20 | 21 | @cython.boundscheck(False) 22 | @cython.wraparound(False) 23 | cdef inline bint inside(cp1, cp2, p): 24 | return(cp2[0]-cp1[0])*(p[1]-cp1[1]) > (cp2[1]-cp1[1])*(p[0]-cp1[0]) 25 | 26 | @cython.boundscheck(False) 27 | def polygon_clip_unnest(float [:, :] subjectPolygon, float [:, :] clipPolygon): 28 | """ Clip a polygon with another polygon. 29 | 30 | Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python 31 | 32 | Args: 33 | subjectPolygon: a list of (x,y) 2d points, any polygon. 34 | clipPolygon: a list of (x,y) 2d points, has to be *convex* 35 | Note: 36 | **points have to be counter-clockwise ordered** 37 | 38 | Return: 39 | a list of (x,y) vertex point for the intersection polygon. 
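    Example (illustrative):
      clipping the unit square with a square shifted by (0.5, 0.5) returns the
      four corners of their 0.5 x 0.5 overlap:
        subj = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
        clip = np.array([[0.5, 0.5], [1.5, 0.5], [1.5, 1.5], [0.5, 1.5]], dtype=np.float32)
        polygon_clip_unnest(subj, clip)   # -> vertices of [0.5, 1] x [0.5, 1]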
40 | """ 41 | outputList = [subjectPolygon[x] for x in range(subjectPolygon.shape[0])] 42 | cp1 = clipPolygon[-1] 43 | cdef int lenc = len(clipPolygon) 44 | cdef int iidx = 0 45 | 46 | # for clipVertex in clipPolygon: 47 | for cidx in range(lenc): 48 | clipVertex = clipPolygon[cidx] 49 | cp2 = clipVertex 50 | inputList = outputList.copy() 51 | outputList.clear() 52 | s = inputList[-1] 53 | 54 | inc = len(inputList) 55 | 56 | # for subjectVertex in inputList: 57 | for iidx in range(inc): 58 | subjectVertex = inputList[iidx] 59 | e = subjectVertex 60 | if inside(cp1, cp2, e): 61 | if not inside(cp1, cp2, s): 62 | outputList.append(computeIntersection(cp1, cp2, s, e)) 63 | outputList.append(e) 64 | elif inside(cp1, cp2, s): 65 | outputList.append(computeIntersection(cp1, cp2, s, e)) 66 | s = e 67 | cp1 = cp2 68 | if len(outputList) == 0: 69 | break 70 | return outputList 71 | 72 | 73 | @cython.boundscheck(False) 74 | @cython.wraparound(False) 75 | cdef void copy_points(float[:, :] src, float[:, :] dst, Py_ssize_t num_points): 76 | cdef Py_ssize_t i 77 | for i in range(num_points): 78 | dst[i][0] = src[i][0] 79 | dst[i][1] = src[i][1] 80 | 81 | 82 | @cython.boundscheck(False) 83 | @cython.wraparound(False) 84 | cdef inline Py_ssize_t add_point(float[:, :] arr, float[:] point, Py_ssize_t num_points): 85 | # assert num_points < arr.shape[0] - 1 86 | # for j in range(dim): 87 | arr[num_points][0] = point[0] 88 | arr[num_points][1] = point[1] 89 | num_points = num_points + 1 90 | return num_points 91 | 92 | @cython.boundscheck(False) 93 | @cython.wraparound(False) 94 | cdef Py_ssize_t computeIntersection_and_add(float[:] cp1, float[:] cp2, float[:] s, float[:] e, float[:, :] arr, Py_ssize_t num_points): 95 | # dc_np = np.zeros(2, dtype=np.float32) 96 | cdef float[2] dc 97 | dc[0] = cp1[0] - cp2[0] 98 | dc[1] = cp1[1] - cp2[1] 99 | 100 | # dp_np = np.zeros(2, dtype=np.float32) 101 | cdef float[2] dp 102 | dp[0] = s[0] - e[0] 103 | dp[1] = s[1] - e[1] 104 | 105 | cdef float n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0] 106 | cdef float n2 = s[0] * e[1] - s[1] * e[0] 107 | cdef float n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0]) 108 | 109 | arr[num_points][0] = (n1*dp[0] - n2*dc[0]) * n3 110 | arr[num_points][1] = (n1*dp[1] - n2*dc[1]) * n3 111 | num_points = num_points + 1 112 | 113 | return num_points 114 | 115 | @cython.boundscheck(False) 116 | @cython.wraparound(False) 117 | def polygon_clip_float(float [:, :] subjectPolygon, float [:, :] clipPolygon): 118 | """ 119 | Assumes subjectPolygon and clipPolygon have 4 vertices 120 | """ 121 | cdef Py_ssize_t num_clip_points = clipPolygon.shape[0] 122 | cp1 = clipPolygon[num_clip_points - 1] 123 | 124 | MAX_INTERSECT_POINTS = 10 125 | num_intersect_points = 0 126 | outputList_np = np.zeros((MAX_INTERSECT_POINTS, 2), dtype=np.float32) 127 | cdef float[:, :] outputList = outputList_np 128 | 129 | inputList_np = np.zeros((MAX_INTERSECT_POINTS, 2), dtype=np.float32) 130 | cdef float[:, :] inputList = inputList_np 131 | 132 | copy_points(subjectPolygon, outputList, subjectPolygon.shape[0]) 133 | cdef Py_ssize_t noutput_list = subjectPolygon.shape[0] 134 | cdef Py_ssize_t ninput_list = 0 135 | cdef Py_ssize_t iidx = 0 136 | 137 | for cidx in range(num_clip_points): 138 | clipVertex = clipPolygon[cidx] 139 | cp2 = clipVertex 140 | 141 | copy_points(outputList, inputList, noutput_list) 142 | ninput_list = noutput_list 143 | noutput_list = 0 144 | 145 | s = inputList[ninput_list - 1] 146 | 147 | for iidx in range(ninput_list): 148 | e = inputList[iidx] 149 | if 
inside(cp1, cp2, e): 150 | if not inside(cp1, cp2, s): 151 | noutput_list = computeIntersection_and_add(cp1, cp2, s, e, outputList, noutput_list) 152 | 153 | noutput_list = add_point(outputList, e, noutput_list) 154 | elif inside(cp1, cp2, s): 155 | noutput_list = computeIntersection_and_add(cp1, cp2, s, e, outputList, noutput_list) 156 | s = e 157 | cp1 = cp2 158 | if noutput_list == 0: 159 | break 160 | return outputList_np, noutput_list 161 | 162 | 163 | 164 | @cython.boundscheck(False) 165 | @cython.wraparound(False) 166 | def box_intersection(float [:, :, :, :] rect1, 167 | float [:, :, :, :] rect2, 168 | float [:, :, :] non_rot_inter_areas, 169 | int[:] nums_k2, 170 | float [:, :, :] inter_areas, 171 | bint approximate): 172 | """ 173 | rect1 - B x K1 x 8 x 3 matrix of box corners 174 | rect2 - B x K2 x 8 x 3 matrix of box corners 175 | non_rot_inter_areas - intersection areas of boxes 176 | """ 177 | 178 | cdef Py_ssize_t B = rect1.shape[0] 179 | cdef Py_ssize_t K1 = rect1.shape[1] 180 | cdef Py_ssize_t K2 = rect2.shape[2] 181 | 182 | 183 | for b in range(B): 184 | for k1 in range(K1): 185 | for k2 in range(K2): 186 | if k2 >= nums_k2[b]: 187 | break 188 | 189 | if approximate and non_rot_inter_areas[b][k1][k2] == 0: 190 | continue 191 | 192 | ##### compute volume of intersection 193 | inter = polygon_clip_unnest(rect1[b, k1], rect2[b, k2]) 194 | ninter = len(inter) 195 | if ninter > 0: # there is some intersection between the boxes 196 | xs = np.array([x[0] for x in inter]).astype(dtype=FLOAT) 197 | ys = np.array([x[1] for x in inter]).astype(dtype=FLOAT) 198 | inter_areas[b,k1,k2] = 0.5 * np.abs(np.dot(xs,np.roll(ys,1))-np.dot(ys,np.roll(xs,1))) 199 | 200 | 201 | -------------------------------------------------------------------------------- /utils/cython_compile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from setuptools import setup, Extension 4 | from Cython.Build import cythonize 5 | import numpy as np 6 | 7 | 8 | # hacky way to find numpy include path 9 | # replace with actual path if this does not work 10 | np_include_path = np.__file__.replace("__init__.py", "core/include/") 11 | INCLUDE_PATH = [ 12 | np_include_path 13 | ] 14 | 15 | setup( 16 | ext_modules = cythonize( 17 | Extension( 18 | "box_intersection", 19 | sources=["box_intersection.pyx"], 20 | include_dirs=INCLUDE_PATH 21 | )), 22 | ) 23 | 24 | -------------------------------------------------------------------------------- /utils/cython_compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | python cython_compile.py build_ext --inplace 4 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
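The Cython module above is compiled in place by utils/cython_compile.sh. An illustrative sketch of calling it once built; the area step mirrors the shoelace formula used inside box_intersection():

import numpy as np
import box_intersection   # extension produced by cython_compile.sh

subj = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
clip = np.array([[0.5, 0.5], [1.5, 0.5], [1.5, 1.5], [0.5, 1.5]], dtype=np.float32)
pts, n = box_intersection.polygon_clip_float(subj, clip)   # first n rows of pts are valid

xs, ys = pts[:n, 0], pts[:n, 1]
area = 0.5 * np.abs(np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1)))
print(area)   # 0.25, the 0.5 x 0.5 overlap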
2 | import pickle 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | def is_distributed(): 9 | if not dist.is_available() or not dist.is_initialized(): 10 | return False 11 | return True 12 | 13 | 14 | def get_rank(): 15 | if not is_distributed(): 16 | return 0 17 | return dist.get_rank() 18 | 19 | 20 | def is_primary(): 21 | return get_rank() == 0 22 | 23 | 24 | def get_world_size(): 25 | if not is_distributed(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def barrier(): 31 | if not is_distributed(): 32 | return 33 | torch.distributed.barrier() 34 | 35 | 36 | def setup_print_for_distributed(is_primary): 37 | """ 38 | This function disables printing when not in primary process 39 | """ 40 | import builtins as __builtin__ 41 | builtin_print = __builtin__.print 42 | 43 | def print(*args, **kwargs): 44 | force = kwargs.pop('force', False) 45 | if is_primary or force: 46 | builtin_print(*args, **kwargs) 47 | 48 | __builtin__.print = print 49 | 50 | 51 | def init_distributed(gpu_id, global_rank, world_size, dist_url, dist_backend): 52 | torch.cuda.set_device(gpu_id) 53 | print( 54 | f"| distributed init (rank {global_rank}) (world {world_size}): {dist_url}", 55 | flush=True, 56 | ) 57 | torch.distributed.init_process_group( 58 | backend=dist_backend, 59 | init_method=dist_url, 60 | world_size=world_size, 61 | rank=global_rank, 62 | ) 63 | torch.distributed.barrier() 64 | setup_print_for_distributed(is_primary()) 65 | 66 | 67 | def all_reduce_sum(tensor): 68 | if not is_distributed(): 69 | return tensor 70 | dim_squeeze = False 71 | if tensor.ndim == 0: 72 | tensor = tensor[None, ...] 73 | dim_squeeze = True 74 | torch.distributed.all_reduce(tensor) 75 | if dim_squeeze: 76 | tensor = tensor.squeeze(0) 77 | return tensor 78 | 79 | 80 | def all_reduce_average(tensor): 81 | val = all_reduce_sum(tensor) 82 | return val / get_world_size() 83 | 84 | 85 | # Function from DETR - https://github.com/facebookresearch/detr/blob/master/util/misc.py 86 | def reduce_dict(input_dict, average=True): 87 | """ 88 | Args: 89 | input_dict (dict): all the values will be reduced 90 | average (bool): whether to do average or sum 91 | Reduce the values in the dictionary from all processes so that all processes 92 | have the averaged results. Returns a dict with the same fields as 93 | input_dict, after reduction. 
94 | """ 95 | world_size = get_world_size() 96 | if world_size < 2: 97 | return input_dict 98 | with torch.no_grad(): 99 | names = [] 100 | values = [] 101 | # sort the keys so that they are consistent across processes 102 | for k in sorted(input_dict.keys()): 103 | names.append(k) 104 | values.append(input_dict[k]) 105 | values = torch.stack(values, dim=0) 106 | torch.distributed.all_reduce(values) 107 | if average: 108 | values /= world_size 109 | reduced_dict = {k: v for k, v in zip(names, values)} 110 | return reduced_dict 111 | 112 | 113 | # Function from https://github.com/facebookresearch/detr/blob/master/util/misc.py 114 | def all_gather_pickle(data, device): 115 | """ 116 | Run all_gather on arbitrary picklable data (not necessarily tensors) 117 | Args: 118 | data: any picklable object 119 | Returns: 120 | list[data]: list of data gathered from each rank 121 | """ 122 | world_size = get_world_size() 123 | if world_size == 1: 124 | return [data] 125 | 126 | # serialized to a Tensor 127 | buffer = pickle.dumps(data) 128 | storage = torch.ByteStorage.from_buffer(buffer) 129 | tensor = torch.ByteTensor(storage).to(device) 130 | 131 | # obtain Tensor size of each rank 132 | local_size = torch.tensor([tensor.numel()], device=device) 133 | size_list = [torch.tensor([0], device=device) for _ in range(world_size)] 134 | dist.all_gather(size_list, local_size) 135 | size_list = [int(size.item()) for size in size_list] 136 | max_size = max(size_list) 137 | 138 | # receiving Tensor from all ranks 139 | # we pad the tensor because torch all_gather does not support 140 | # gathering tensors of different shapes 141 | tensor_list = [] 142 | for _ in size_list: 143 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) 144 | if local_size != max_size: 145 | padding = torch.empty( 146 | size=(max_size - local_size,), dtype=torch.uint8, device=device 147 | ) 148 | tensor = torch.cat((tensor, padding), dim=0) 149 | dist.all_gather(tensor_list, tensor) 150 | 151 | data_list = [] 152 | for size, tensor in zip(size_list, tensor_list): 153 | buffer = tensor.cpu().numpy().tobytes()[:size] 154 | data_list.append(pickle.loads(buffer)) 155 | 156 | return data_list 157 | 158 | 159 | def all_gather_dict(data): 160 | """ 161 | Run all_gather on data which is a dictionary of Tensors 162 | """ 163 | assert isinstance(data, dict) 164 | 165 | gathered_dict = {} 166 | for item_key in data: 167 | if isinstance(data[item_key], torch.Tensor): 168 | if is_distributed(): 169 | data[item_key] = data[item_key].contiguous() 170 | tensor_list = [torch.empty_like(data[item_key]) for _ in range(get_world_size())] 171 | dist.all_gather(tensor_list, data[item_key]) 172 | gathered_tensor = torch.cat(tensor_list, dim=0) 173 | else: 174 | gathered_tensor = data[item_key] 175 | gathered_dict[item_key] = gathered_tensor 176 | return gathered_dict 177 | -------------------------------------------------------------------------------- /utils/download_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
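The helpers above fall back to single-process behaviour when torch.distributed is not initialised, so they can be called unconditionally. A hypothetical logging pattern (the loss names are made up for illustration):

import torch
from utils.dist import all_reduce_average, reduce_dict, is_primary

loss_dict = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(1.2)}   # dummy values
loss_dict = reduce_dict(loss_dict)                      # averaged across ranks; no-op on one GPU
total = all_reduce_average(sum(loss_dict.values()))
if is_primary():
    print({k: v.item() for k, v in loss_dict.items()}, total.item())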
2 | 3 | 4 | import os 5 | from urllib import request 6 | import torch 7 | import pickle 8 | 9 | ## Define the weights you want and where to store them 10 | dataset = "scannet" 11 | encoder = "_masked" # or "" 12 | epoch = 1080 13 | base_url = "https://dl.fbaipublicfiles.com/3detr/checkpoints" 14 | local_dir = "/tmp/" 15 | 16 | ### Downloading the weights 17 | weights_file = f"{dataset}{encoder}_ep{epoch}.pth" 18 | metrics_file = f"{dataset}{encoder}_ep{epoch}_metrics.pkl" 19 | local_weights = os.path.join(local_dir, weights_file) 20 | local_metrics = os.path.join(local_dir, metrics_file) 21 | 22 | url = os.path.join(base_url, weights_file) 23 | request.urlretrieve(url, local_weights) 24 | print(f"Downloaded weights from {url} to {local_weights}") 25 | 26 | url = os.path.join(base_url, metrics_file) 27 | request.urlretrieve(url, local_metrics) 28 | print(f"Downloaded metrics from {url} to {local_metrics}") 29 | 30 | # weights can be simply loaded with pytorch 31 | weights = torch.load(local_weights, map_location=torch.device("cpu")) 32 | print("Weights loaded successfully.") 33 | 34 | # metrics can be loaded with pickle 35 | with open(local_metrics, "rb") as fh: 36 | metrics = pickle.load(fh) 37 | print("Metrics loaded successfully.") -------------------------------------------------------------------------------- /utils/eval_det.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | """ Generic Code for Object Detection Evaluation 4 | 5 | Input: 6 | For each class: 7 | For each image: 8 | Predictions: box, score 9 | Groundtruths: box 10 | 11 | Output: 12 | For each class: 13 | precision-recal and average precision 14 | 15 | Author: Charles R. Qi 16 | 17 | Ref: https://raw.githubusercontent.com/rbgirshick/py-faster-rcnn/master/lib/datasets/voc_eval.py 18 | """ 19 | import numpy as np 20 | from utils.box_util import box3d_iou 21 | 22 | 23 | def voc_ap(rec, prec, use_07_metric=False): 24 | """ap = voc_ap(rec, prec, [use_07_metric]) 25 | Compute VOC AP given precision and recall. 26 | If use_07_metric is true, uses the 27 | VOC 07 11 point method (default:False). 28 | """ 29 | if use_07_metric: 30 | # 11 point metric 31 | ap = 0.0 32 | for t in np.arange(0.0, 1.1, 0.1): 33 | if np.sum(rec >= t) == 0: 34 | p = 0 35 | else: 36 | p = np.max(prec[rec >= t]) 37 | ap = ap + p / 11.0 38 | else: 39 | # correct AP calculation 40 | # first append sentinel values at the end 41 | mrec = np.concatenate(([0.0], rec, [1.0])) 42 | mpre = np.concatenate(([0.0], prec, [0.0])) 43 | 44 | # compute the precision envelope 45 | for i in range(mpre.size - 1, 0, -1): 46 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 47 | 48 | # to calculate area under PR curve, look for points 49 | # where X axis (recall) changes value 50 | i = np.where(mrec[1:] != mrec[:-1])[0] 51 | 52 | # and sum (\Delta recall) * prec 53 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 54 | return ap 55 | 56 | 57 | def get_iou_obb(bb1, bb2): 58 | iou3d, iou2d = box3d_iou(bb1, bb2) 59 | return iou3d 60 | 61 | 62 | def get_iou_main(get_iou_func, args): 63 | return get_iou_func(*args) 64 | 65 | 66 | def eval_det_cls( 67 | pred, gt, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou_obb 68 | ): 69 | """Generic functions to compute precision/recall for object detection 70 | for a single class. 
71 | Input: 72 | pred: map of {img_id: [(bbox, score)]} where bbox is numpy array 73 | gt: map of {img_id: [bbox]} 74 | ovthresh: scalar, iou threshold 75 | use_07_metric: bool, if True use VOC07 11 point method 76 | Output: 77 | rec: numpy array of length nd 78 | prec: numpy array of length nd 79 | ap: scalar, average precision 80 | """ 81 | 82 | # construct gt objects 83 | class_recs = {} # {img_id: {'bbox': bbox list, 'det': matched list}} 84 | npos = 0 85 | for img_id in gt.keys(): 86 | bbox = np.array(gt[img_id]) 87 | det = [False] * len(bbox) 88 | npos += len(bbox) 89 | class_recs[img_id] = {"bbox": bbox, "det": det} 90 | # pad empty list to all other imgids 91 | for img_id in pred.keys(): 92 | if img_id not in gt: 93 | class_recs[img_id] = {"bbox": np.array([]), "det": []} 94 | 95 | # construct dets 96 | image_ids = [] 97 | confidence = [] 98 | BB = [] 99 | for img_id in pred.keys(): 100 | for box, score in pred[img_id]: 101 | image_ids.append(img_id) 102 | confidence.append(score) 103 | BB.append(box) 104 | confidence = np.array(confidence) 105 | BB = np.array(BB) # (nd,4 or 8,3 or 6) 106 | 107 | # sort by confidence 108 | sorted_ind = np.argsort(-confidence) 109 | sorted_scores = np.sort(-confidence) 110 | BB = BB[sorted_ind, ...] 111 | image_ids = [image_ids[x] for x in sorted_ind] 112 | 113 | # go down dets and mark TPs and FPs 114 | nd = len(image_ids) 115 | tp = np.zeros(nd) 116 | fp = np.zeros(nd) 117 | for d in range(nd): 118 | # if d%100==0: print(d) 119 | R = class_recs[image_ids[d]] 120 | bb = BB[d, ...].astype(float) 121 | ovmax = -np.inf 122 | BBGT = R["bbox"].astype(float) 123 | 124 | if BBGT.size > 0: 125 | # compute overlaps 126 | for j in range(BBGT.shape[0]): 127 | iou = get_iou_main(get_iou_func, (bb, BBGT[j, ...])) 128 | if iou > ovmax: 129 | ovmax = iou 130 | jmax = j 131 | 132 | # print d, ovmax 133 | if ovmax > ovthresh: 134 | if not R["det"][jmax]: 135 | tp[d] = 1.0 136 | R["det"][jmax] = 1 137 | else: 138 | fp[d] = 1.0 139 | else: 140 | fp[d] = 1.0 141 | 142 | # compute precision recall 143 | fp = np.cumsum(fp) 144 | tp = np.cumsum(tp) 145 | if npos == 0: 146 | rec = np.zeros_like(tp) 147 | else: 148 | rec = tp / float(npos) 149 | # print('NPOS: ', npos) 150 | # avoid divide by zero in case the first detection matches a difficult 151 | # ground truth 152 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 153 | ap = voc_ap(rec, prec, use_07_metric) 154 | 155 | return rec, prec, ap 156 | 157 | 158 | def eval_det_cls_wrapper(arguments): 159 | pred, gt, ovthresh, use_07_metric, get_iou_func = arguments 160 | rec, prec, ap = eval_det_cls(pred, gt, ovthresh, use_07_metric, get_iou_func) 161 | return (rec, prec, ap) 162 | 163 | 164 | def eval_det(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=None): 165 | """Generic functions to compute precision/recall for object detection 166 | for multiple classes. 
167 | Input: 168 | pred_all: map of {img_id: [(classname, bbox, score)]} 169 | gt_all: map of {img_id: [(classname, bbox)]} 170 | ovthresh: scalar, iou threshold 171 | use_07_metric: bool, if true use VOC07 11 point method 172 | Output: 173 | rec: {classname: rec} 174 | prec: {classname: prec_all} 175 | ap: {classname: scalar} 176 | """ 177 | pred = {} # map {classname: pred} 178 | gt = {} # map {classname: gt} 179 | for img_id in pred_all.keys(): 180 | for classname, bbox, score in pred_all[img_id]: 181 | if classname not in pred: 182 | pred[classname] = {} 183 | if img_id not in pred[classname]: 184 | pred[classname][img_id] = [] 185 | if classname not in gt: 186 | gt[classname] = {} 187 | if img_id not in gt[classname]: 188 | gt[classname][img_id] = [] 189 | pred[classname][img_id].append((bbox, score)) 190 | for img_id in gt_all.keys(): 191 | for classname, bbox in gt_all[img_id]: 192 | if classname not in gt: 193 | gt[classname] = {} 194 | if img_id not in gt[classname]: 195 | gt[classname][img_id] = [] 196 | gt[classname][img_id].append(bbox) 197 | 198 | rec = {} 199 | prec = {} 200 | ap = {} 201 | for classname in gt.keys(): 202 | # print('Computing AP for class: ', classname) 203 | rec[classname], prec[classname], ap[classname] = eval_det_cls( 204 | pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func 205 | ) 206 | # print(classname, ap[classname]) 207 | 208 | return rec, prec, ap 209 | 210 | 211 | from multiprocessing import Pool 212 | 213 | 214 | def eval_det_multiprocessing( 215 | pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou_obb 216 | ): 217 | """Generic functions to compute precision/recall for object detection 218 | for multiple classes. 219 | Input: 220 | pred_all: map of {img_id: [(classname, bbox, score)]} 221 | gt_all: map of {img_id: [(classname, bbox)]} 222 | ovthresh: scalar, iou threshold 223 | use_07_metric: bool, if true use VOC07 11 point method 224 | Output: 225 | rec: {classname: rec} 226 | prec: {classname: prec_all} 227 | ap: {classname: scalar} 228 | """ 229 | pred = {} # map {classname: pred} 230 | gt = {} # map {classname: gt} 231 | for img_id in pred_all.keys(): 232 | for classname, bbox, score in pred_all[img_id]: 233 | if classname not in pred: 234 | pred[classname] = {} 235 | if img_id not in pred[classname]: 236 | pred[classname][img_id] = [] 237 | if classname not in gt: 238 | gt[classname] = {} 239 | if img_id not in gt[classname]: 240 | gt[classname][img_id] = [] 241 | pred[classname][img_id].append((bbox, score)) 242 | for img_id in gt_all.keys(): 243 | for classname, bbox in gt_all[img_id]: 244 | if classname not in gt: 245 | gt[classname] = {} 246 | if img_id not in gt[classname]: 247 | gt[classname][img_id] = [] 248 | gt[classname][img_id].append(bbox) 249 | 250 | rec = {} 251 | prec = {} 252 | ap = {} 253 | p = Pool(processes=10) 254 | ret_values = p.map( 255 | eval_det_cls_wrapper, 256 | [ 257 | (pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func) 258 | for classname in gt.keys() 259 | if classname in pred 260 | ], 261 | ) 262 | p.close() 263 | for i, classname in enumerate(gt.keys()): 264 | if classname in pred: 265 | rec[classname], prec[classname], ap[classname] = ret_values[i] 266 | else: 267 | rec[classname] = 0 268 | prec[classname] = 0 269 | ap[classname] = 0 270 | # print(classname, ap[classname]) 271 | 272 | return rec, prec, ap 273 | -------------------------------------------------------------------------------- /utils/io.py: 
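A tiny worked example of the AP computation defined above (illustrative, assuming the repo root is on PYTHONPATH): with two ground-truth boxes and two ranked detections of which only the first is a true positive, recall stays at 0.5 while precision falls from 1.0 to 0.5, and the default integration gives AP = 0.5:

import numpy as np
from utils.eval_det import voc_ap

rec = np.array([0.5, 0.5])
prec = np.array([1.0, 0.5])
print(voc_ap(rec, prec))   # 0.5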
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import torch 4 | import os 5 | from utils.dist import is_primary 6 | 7 | 8 | def save_checkpoint( 9 | checkpoint_dir, 10 | model_no_ddp, 11 | optimizer, 12 | epoch, 13 | args, 14 | best_val_metrics, 15 | filename=None, 16 | ): 17 | if not is_primary(): 18 | return 19 | if filename is None: 20 | filename = f"checkpoint_{epoch:04d}.pth" 21 | checkpoint_name = os.path.join(checkpoint_dir, filename) 22 | 23 | sd = { 24 | "model": model_no_ddp.state_dict(), 25 | "optimizer": optimizer.state_dict(), 26 | "epoch": epoch, 27 | "args": args, 28 | "best_val_metrics": best_val_metrics, 29 | } 30 | torch.save(sd, checkpoint_name) 31 | 32 | 33 | def resume_if_possible(checkpoint_dir, model_no_ddp, optimizer): 34 | """ 35 | Resume if checkpoint is available. 36 | Return 37 | - epoch of loaded checkpoint. 38 | """ 39 | epoch = -1 40 | best_val_metrics = {} 41 | if not os.path.isdir(checkpoint_dir): 42 | return epoch, best_val_metrics 43 | 44 | last_checkpoint = os.path.join(checkpoint_dir, "checkpoint.pth") 45 | if not os.path.isfile(last_checkpoint): 46 | return epoch, best_val_metrics 47 | 48 | sd = torch.load(last_checkpoint, map_location=torch.device("cpu")) 49 | epoch = sd["epoch"] 50 | best_val_metrics = sd["best_val_metrics"] 51 | print(f"Found checkpoint at {epoch}. Resuming.") 52 | 53 | model_no_ddp.load_state_dict(sd["model"]) 54 | optimizer.load_state_dict(sd["optimizer"]) 55 | print( 56 | f"Loaded model and optimizer state at {epoch}. Loaded best val metrics so far." 57 | ) 58 | return epoch, best_val_metrics 59 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import torch 4 | 5 | try: 6 | from tensorboardX import SummaryWriter 7 | except ImportError: 8 | print("Cannot import tensorboard. Will log to txt files only.") 9 | SummaryWriter = None 10 | 11 | from utils.dist import is_primary 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, log_dir=None) -> None: 16 | self.log_dir = log_dir 17 | if SummaryWriter is not None and is_primary(): 18 | self.writer = SummaryWriter(self.log_dir) 19 | else: 20 | self.writer = None 21 | 22 | def log_scalars(self, scalar_dict, step, prefix=None): 23 | if self.writer is None: 24 | return 25 | for k in scalar_dict: 26 | v = scalar_dict[k] 27 | if isinstance(v, torch.Tensor): 28 | v = v.detach().cpu().item() 29 | if prefix is not None: 30 | k = prefix + k 31 | self.writer.add_scalar(k, v, step) 32 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
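The checkpoint helpers above are typically used as a pair; a minimal sketch with a stand-in model (the output directory and schedule are hypothetical):

import os
import torch.nn as nn
import torch.optim as optim
from utils.io import save_checkpoint, resume_if_possible

model = nn.Linear(8, 4)                     # stand-in for the detector
opt = optim.AdamW(model.parameters(), lr=5e-4)
os.makedirs("outputs/run0", exist_ok=True)

start_epoch, best = resume_if_possible("outputs/run0", model, opt)   # (-1, {}) if nothing found
for epoch in range(start_epoch + 1, 10):
    ...  # train one epoch
    save_checkpoint("outputs/run0", model, opt, epoch, args=None,
                    best_val_metrics=best, filename="checkpoint.pth")   # name resume_if_possible expects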
2 | import torch 3 | import numpy as np 4 | from collections import deque 5 | from typing import List 6 | from utils.dist import is_distributed, barrier, all_reduce_sum 7 | 8 | 9 | def my_worker_init_fn(worker_id): 10 | np.random.seed(np.random.get_state()[1][0] + worker_id) 11 | 12 | 13 | @torch.jit.ignore 14 | def to_list_1d(arr) -> List[float]: 15 | arr = arr.detach().cpu().numpy().tolist() 16 | return arr 17 | 18 | 19 | @torch.jit.ignore 20 | def to_list_3d(arr) -> List[List[List[float]]]: 21 | arr = arr.detach().cpu().numpy().tolist() 22 | return arr 23 | 24 | 25 | def huber_loss(error, delta=1.0): 26 | """ 27 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 28 | x = error = pred - gt or dist(pred,gt) 29 | 0.5 * |x|^2 if |x|<=d 30 | 0.5 * d^2 + d * (|x|-d) if |x|>d 31 | """ 32 | abs_error = torch.abs(error) 33 | quadratic = torch.clamp(abs_error, max=delta) 34 | linear = abs_error - quadratic 35 | loss = 0.5 * quadratic ** 2 + delta * linear 36 | return loss 37 | 38 | 39 | # From https://github.com/facebookresearch/detr/blob/master/util/misc.py 40 | class SmoothedValue(object): 41 | """Track a series of values and provide access to smoothed values over a 42 | window or the global series average. 43 | """ 44 | 45 | def __init__(self, window_size=20, fmt=None): 46 | if fmt is None: 47 | fmt = "{median:.4f} ({global_avg:.4f})" 48 | self.deque = deque(maxlen=window_size) 49 | self.total = 0.0 50 | self.count = 0 51 | self.fmt = fmt 52 | 53 | def update(self, value, n=1): 54 | self.deque.append(value) 55 | self.count += n 56 | self.total += value * n 57 | 58 | def synchronize_between_processes(self): 59 | """ 60 | Warning: does not synchronize the deque! 61 | """ 62 | if not is_distributed(): 63 | return 64 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 65 | barrier() 66 | all_reduce_sum(t) 67 | t = t.tolist() 68 | self.count = int(t[0]) 69 | self.total = t[1] 70 | 71 | @property 72 | def median(self): 73 | d = torch.tensor(list(self.deque)) 74 | return d.median().item() 75 | 76 | @property 77 | def avg(self): 78 | d = torch.tensor(list(self.deque), dtype=torch.float32) 79 | return d.mean().item() 80 | 81 | @property 82 | def global_avg(self): 83 | return self.total / self.count 84 | 85 | @property 86 | def max(self): 87 | return max(self.deque) 88 | 89 | @property 90 | def value(self): 91 | return self.deque[-1] 92 | 93 | def __str__(self): 94 | return self.fmt.format( 95 | median=self.median, 96 | avg=self.avg, 97 | global_avg=self.global_avg, 98 | max=self.max, 99 | value=self.value, 100 | ) 101 | -------------------------------------------------------------------------------- /utils/nms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
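A brief, illustrative use of the misc.py utilities above:

import torch
from utils.misc import SmoothedValue, huber_loss

meter = SmoothedValue(window_size=3)
for v in (2.0, 4.0, 6.0, 8.0):
    meter.update(v)
print(meter.median, meter.global_avg)   # 6.0 over the last 3 values, 5.0 over all 4

err = torch.tensor([0.2, 2.0])
print(huber_loss(err, delta=1.0))       # tensor([0.0200, 1.5000])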
2 | 3 | import numpy as np 4 | 5 | # boxes are axis aigned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score) 6 | """ Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ 7 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py 8 | """ 9 | 10 | 11 | def nms_2d(boxes, overlap_threshold): 12 | x1 = boxes[:, 0] 13 | y1 = boxes[:, 1] 14 | x2 = boxes[:, 2] 15 | y2 = boxes[:, 3] 16 | score = boxes[:, 4] 17 | area = (x2 - x1) * (y2 - y1) 18 | 19 | I = np.argsort(score) 20 | pick = [] 21 | while I.size != 0: 22 | last = I.size 23 | i = I[-1] 24 | pick.append(i) 25 | suppress = [last - 1] 26 | for pos in range(last - 1): 27 | j = I[pos] 28 | xx1 = max(x1[i], x1[j]) 29 | yy1 = max(y1[i], y1[j]) 30 | xx2 = min(x2[i], x2[j]) 31 | yy2 = min(y2[i], y2[j]) 32 | w = xx2 - xx1 33 | h = yy2 - yy1 34 | if w > 0 and h > 0: 35 | o = w * h / area[j] 36 | print("Overlap is", o) 37 | if o > overlap_threshold: 38 | suppress.append(pos) 39 | I = np.delete(I, suppress) 40 | return pick 41 | 42 | 43 | def nms_2d_faster(boxes, overlap_threshold, old_type=False): 44 | x1 = boxes[:, 0] 45 | y1 = boxes[:, 1] 46 | x2 = boxes[:, 2] 47 | y2 = boxes[:, 3] 48 | score = boxes[:, 4] 49 | area = (x2 - x1) * (y2 - y1) 50 | 51 | I = np.argsort(score) 52 | pick = [] 53 | while I.size != 0: 54 | last = I.size 55 | i = I[-1] 56 | pick.append(i) 57 | 58 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 59 | yy1 = np.maximum(y1[i], y1[I[: last - 1]]) 60 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 61 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 62 | 63 | w = np.maximum(0, xx2 - xx1) 64 | h = np.maximum(0, yy2 - yy1) 65 | 66 | if old_type: 67 | o = (w * h) / area[I[: last - 1]] 68 | else: 69 | inter = w * h 70 | o = inter / (area[i] + area[I[: last - 1]] - inter) 71 | 72 | I = np.delete( 73 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 74 | ) 75 | 76 | return pick 77 | 78 | 79 | def nms_3d_faster(boxes, overlap_threshold, old_type=False): 80 | x1 = boxes[:, 0] 81 | y1 = boxes[:, 1] 82 | z1 = boxes[:, 2] 83 | x2 = boxes[:, 3] 84 | y2 = boxes[:, 4] 85 | z2 = boxes[:, 5] 86 | score = boxes[:, 6] 87 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 88 | 89 | I = np.argsort(score) 90 | pick = [] 91 | while I.size != 0: 92 | last = I.size 93 | i = I[-1] 94 | pick.append(i) 95 | 96 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 97 | yy1 = np.maximum(y1[i], y1[I[: last - 1]]) 98 | zz1 = np.maximum(z1[i], z1[I[: last - 1]]) 99 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 100 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 101 | zz2 = np.minimum(z2[i], z2[I[: last - 1]]) 102 | 103 | l = np.maximum(0, xx2 - xx1) 104 | w = np.maximum(0, yy2 - yy1) 105 | h = np.maximum(0, zz2 - zz1) 106 | 107 | if old_type: 108 | o = (l * w * h) / area[I[: last - 1]] 109 | else: 110 | inter = l * w * h 111 | o = inter / (area[i] + area[I[: last - 1]] - inter) 112 | 113 | I = np.delete( 114 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 115 | ) 116 | 117 | return pick 118 | 119 | 120 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False): 121 | x1 = boxes[:, 0] 122 | y1 = boxes[:, 1] 123 | z1 = boxes[:, 2] 124 | x2 = boxes[:, 3] 125 | y2 = boxes[:, 4] 126 | z2 = boxes[:, 5] 127 | score = boxes[:, 6] 128 | cls = boxes[:, 7] 129 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 130 | 131 | I = np.argsort(score) 132 | pick = [] 133 | while I.size != 0: 134 | last = I.size 135 | i = I[-1] 136 | pick.append(i) 137 | 138 | xx1 = np.maximum(x1[i], x1[I[: last - 1]]) 139 | yy1 = 
np.maximum(y1[i], y1[I[: last - 1]]) 140 | zz1 = np.maximum(z1[i], z1[I[: last - 1]]) 141 | xx2 = np.minimum(x2[i], x2[I[: last - 1]]) 142 | yy2 = np.minimum(y2[i], y2[I[: last - 1]]) 143 | zz2 = np.minimum(z2[i], z2[I[: last - 1]]) 144 | cls1 = cls[i] 145 | cls2 = cls[I[: last - 1]] 146 | 147 | l = np.maximum(0, xx2 - xx1) 148 | w = np.maximum(0, yy2 - yy1) 149 | h = np.maximum(0, zz2 - zz1) 150 | 151 | if old_type: 152 | o = (l * w * h) / area[I[: last - 1]] 153 | else: 154 | inter = l * w * h 155 | o = inter / (area[i] + area[I[: last - 1]] - inter) 156 | o = o * (cls1 == cls2) 157 | 158 | I = np.delete( 159 | I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])) 160 | ) 161 | 162 | return pick 163 | -------------------------------------------------------------------------------- /utils/pc_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | """ Utility functions for processing point clouds. 4 | 5 | Author: Charles R. Qi and Or Litany 6 | """ 7 | 8 | import os 9 | import sys 10 | import torch 11 | 12 | # Point cloud IO 13 | import numpy as np 14 | from plyfile import PlyData, PlyElement 15 | 16 | # Mesh IO 17 | import trimesh 18 | 19 | # ---------------------------------------- 20 | # Point Cloud Sampling 21 | # ---------------------------------------- 22 | 23 | 24 | def random_sampling(pc, num_sample, replace=None, return_choices=False): 25 | """Input is NxC, output is num_samplexC""" 26 | if replace is None: 27 | replace = pc.shape[0] < num_sample 28 | choices = np.random.choice(pc.shape[0], num_sample, replace=replace) 29 | if return_choices: 30 | return pc[choices], choices 31 | else: 32 | return pc[choices] 33 | 34 | 35 | # ---------------------------------------- 36 | # Simple Point manipulations 37 | # ---------------------------------------- 38 | def shift_scale_points(pred_xyz, src_range, dst_range=None): 39 | """ 40 | pred_xyz: B x N x 3 41 | src_range: [[B x 3], [B x 3]] - min and max XYZ coords 42 | dst_range: [[B x 3], [B x 3]] - min and max XYZ coords 43 | """ 44 | if dst_range is None: 45 | dst_range = [ 46 | torch.zeros((src_range[0].shape[0], 3), device=src_range[0].device), 47 | torch.ones((src_range[0].shape[0], 3), device=src_range[0].device), 48 | ] 49 | 50 | if pred_xyz.ndim == 4: 51 | src_range = [x[:, None] for x in src_range] 52 | dst_range = [x[:, None] for x in dst_range] 53 | 54 | assert src_range[0].shape[0] == pred_xyz.shape[0] 55 | assert dst_range[0].shape[0] == pred_xyz.shape[0] 56 | assert src_range[0].shape[-1] == pred_xyz.shape[-1] 57 | assert src_range[0].shape == src_range[1].shape 58 | assert dst_range[0].shape == dst_range[1].shape 59 | assert src_range[0].shape == dst_range[1].shape 60 | 61 | src_diff = src_range[1][:, None, :] - src_range[0][:, None, :] 62 | dst_diff = dst_range[1][:, None, :] - dst_range[0][:, None, :] 63 | prop_xyz = ( 64 | ((pred_xyz - src_range[0][:, None, :]) * dst_diff) / src_diff 65 | ) + dst_range[0][:, None, :] 66 | return prop_xyz 67 | 68 | 69 | def scale_points(pred_xyz, mult_factor): 70 | if pred_xyz.ndim == 4: 71 | mult_factor = mult_factor[:, None] 72 | scaled_xyz = pred_xyz * mult_factor[:, None, :] 73 | return scaled_xyz 74 | 75 | 76 | def rotate_point_cloud(points, rotation_matrix=None): 77 | """Input: (n,3), Output: (n,3)""" 78 | # Rotate in-place around Z axis. 
79 | if rotation_matrix is None: 80 | rotation_angle = np.random.uniform() * 2 * np.pi 81 | sinval, cosval = np.sin(rotation_angle), np.cos(rotation_angle) 82 | rotation_matrix = np.array( 83 | [[cosval, sinval, 0], [-sinval, cosval, 0], [0, 0, 1]] 84 | ) 85 | ctr = points.mean(axis=0) 86 | rotated_data = np.dot(points - ctr, rotation_matrix) + ctr 87 | return rotated_data, rotation_matrix 88 | 89 | 90 | def rotate_pc_along_y(pc, rot_angle): 91 | """Input ps is NxC points with first 3 channels as XYZ 92 | z is facing forward, x is left ward, y is downward 93 | """ 94 | cosval = np.cos(rot_angle) 95 | sinval = np.sin(rot_angle) 96 | rotmat = np.array([[cosval, -sinval], [sinval, cosval]]) 97 | pc[:, [0, 2]] = np.dot(pc[:, [0, 2]], np.transpose(rotmat)) 98 | return pc 99 | 100 | 101 | def roty(t): 102 | """Rotation about the y-axis.""" 103 | c = np.cos(t) 104 | s = np.sin(t) 105 | return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]]) 106 | 107 | 108 | def roty_batch(t): 109 | """Rotation about the y-axis. 110 | t: (x1,x2,...xn) 111 | return: (x1,x2,...,xn,3,3) 112 | """ 113 | input_shape = t.shape 114 | output = np.zeros(tuple(list(input_shape) + [3, 3])) 115 | c = np.cos(t) 116 | s = np.sin(t) 117 | output[..., 0, 0] = c 118 | output[..., 0, 2] = s 119 | output[..., 1, 1] = 1 120 | output[..., 2, 0] = -s 121 | output[..., 2, 2] = c 122 | return output 123 | 124 | 125 | def rotz(t): 126 | """Rotation about the z-axis.""" 127 | c = np.cos(t) 128 | s = np.sin(t) 129 | return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) 130 | 131 | 132 | def point_cloud_to_bbox(points): 133 | """Extract the axis aligned box from a pcl or batch of pcls 134 | Args: 135 | points: Nx3 points or BxNx3 136 | output is 6 dim: xyz pos of center and 3 lengths 137 | """ 138 | which_dim = len(points.shape) - 2 # first dim if a single cloud and second if batch 139 | mn, mx = points.min(which_dim), points.max(which_dim) 140 | lengths = mx - mn 141 | cntr = 0.5 * (mn + mx) 142 | return np.concatenate([cntr, lengths], axis=which_dim) 143 | 144 | 145 | def write_bbox(scene_bbox, out_filename): 146 | """Export scene bbox to meshes 147 | Args: 148 | scene_bbox: (N x 6 numpy array): xyz pos of center and 3 lengths 149 | out_filename: (string) filename 150 | 151 | Note: 152 | To visualize the boxes in MeshLab. 153 | 1. Select the objects (the boxes) 154 | 2. Filters -> Polygon and Quad Mesh -> Turn into Quad-Dominant Mesh 155 | 3. Select Wireframe view. 156 | """ 157 | 158 | def convert_box_to_trimesh_fmt(box): 159 | ctr = box[:3] 160 | lengths = box[3:] 161 | trns = np.eye(4) 162 | trns[0:3, 3] = ctr 163 | trns[3, 3] = 1.0 164 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 165 | return box_trimesh_fmt 166 | 167 | scene = trimesh.scene.Scene() 168 | for box in scene_bbox: 169 | scene.add_geometry(convert_box_to_trimesh_fmt(box)) 170 | 171 | mesh_list = trimesh.util.concatenate(scene.dump()) 172 | # save to ply file 173 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 174 | 175 | return 176 | 177 | 178 | def write_oriented_bbox(scene_bbox, out_filename, colors=None): 179 | """Export oriented (around Z axis) scene bbox to meshes 180 | Args: 181 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 182 | and heading angle around Z axis. 183 | Y forward, X right, Z upward. heading angle of positive X is 0, 184 | heading angle of positive Y is 90 degrees. 
185 | out_filename: (string) filename 186 | """ 187 | 188 | def heading2rotmat(heading_angle): 189 | pass 190 | rotmat = np.zeros((3, 3)) 191 | rotmat[2, 2] = 1 192 | cosval = np.cos(heading_angle) 193 | sinval = np.sin(heading_angle) 194 | rotmat[0:2, 0:2] = np.array([[cosval, -sinval], [sinval, cosval]]) 195 | return rotmat 196 | 197 | def convert_oriented_box_to_trimesh_fmt(box): 198 | ctr = box[:3] 199 | lengths = box[3:6] 200 | trns = np.eye(4) 201 | trns[0:3, 3] = ctr 202 | trns[3, 3] = 1.0 203 | trns[0:3, 0:3] = heading2rotmat(box[6]) 204 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 205 | return box_trimesh_fmt 206 | 207 | if colors is not None: 208 | if colors.shape[0] != len(scene_bbox): 209 | colors = [colors for _ in range(len(scene_bbox))] 210 | colors = np.array(colors).astype(np.uint8) 211 | assert colors.shape[0] == len(scene_bbox) 212 | assert colors.shape[1] == 4 213 | 214 | scene = trimesh.scene.Scene() 215 | for idx, box in enumerate(scene_bbox): 216 | box_tr = convert_oriented_box_to_trimesh_fmt(box) 217 | if colors is not None: 218 | box_tr.visual.main_color[:] = colors[idx] 219 | box_tr.visual.vertex_colors[:] = colors[idx] 220 | for facet in box_tr.facets: 221 | box_tr.visual.face_colors[facet] = colors[idx] 222 | scene.add_geometry(box_tr) 223 | 224 | mesh_list = trimesh.util.concatenate(scene.dump()) 225 | # save to ply file 226 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 227 | 228 | return 229 | 230 | 231 | def write_oriented_bbox_camera_coord(scene_bbox, out_filename): 232 | """Export oriented (around Y axis) scene bbox to meshes 233 | Args: 234 | scene_bbox: (N x 7 numpy array): xyz pos of center and 3 lengths (dx,dy,dz) 235 | and heading angle around Y axis. 236 | Z forward, X rightward, Y downward. heading angle of positive X is 0, 237 | heading angle of negative Z is 90 degrees. 238 | out_filename: (string) filename 239 | """ 240 | 241 | def heading2rotmat(heading_angle): 242 | pass 243 | rotmat = np.zeros((3, 3)) 244 | rotmat[1, 1] = 1 245 | cosval = np.cos(heading_angle) 246 | sinval = np.sin(heading_angle) 247 | rotmat[0, :] = np.array([cosval, 0, sinval]) 248 | rotmat[2, :] = np.array([-sinval, 0, cosval]) 249 | return rotmat 250 | 251 | def convert_oriented_box_to_trimesh_fmt(box): 252 | ctr = box[:3] 253 | lengths = box[3:6] 254 | trns = np.eye(4) 255 | trns[0:3, 3] = ctr 256 | trns[3, 3] = 1.0 257 | trns[0:3, 0:3] = heading2rotmat(box[6]) 258 | box_trimesh_fmt = trimesh.creation.box(lengths, trns) 259 | return box_trimesh_fmt 260 | 261 | scene = trimesh.scene.Scene() 262 | for box in scene_bbox: 263 | scene.add_geometry(convert_oriented_box_to_trimesh_fmt(box)) 264 | 265 | mesh_list = trimesh.util.concatenate(scene.dump()) 266 | # save to ply file 267 | trimesh.io.export.export_mesh(mesh_list, out_filename, file_type="ply") 268 | 269 | return 270 | 271 | 272 | def write_lines_as_cylinders(pcl, filename, rad=0.005, res=64): 273 | """Create lines represented as cylinders connecting pairs of 3D points 274 | Args: 275 | pcl: (N x 2 x 3 numpy array): N pairs of xyz pos 276 | filename: (string) filename for the output mesh (ply) file 277 | rad: radius for the cylinder 278 | res: number of sections used to create the cylinder 279 | """ 280 | scene = trimesh.scene.Scene() 281 | for src, tgt in pcl: 282 | # compute line 283 | vec = tgt - src 284 | M = trimesh.geometry.align_vectors([0, 0, 1], vec, False) 285 | vec = tgt - src # compute again since align_vectors modifies vec in-place! 
286 | M[:3, 3] = 0.5 * src + 0.5 * tgt 287 | height = np.sqrt(np.dot(vec, vec)) 288 | scene.add_geometry( 289 | trimesh.creation.cylinder( 290 | radius=rad, height=height, sections=res, transform=M 291 | ) 292 | ) 293 | mesh_list = trimesh.util.concatenate(scene.dump()) 294 | trimesh.io.export.export_mesh(mesh_list, "%s.ply" % (filename), file_type="ply") 295 | -------------------------------------------------------------------------------- /utils/random_cuboid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import numpy as np 3 | 4 | 5 | def check_aspect(crop_range, aspect_min): 6 | xy_aspect = np.min(crop_range[:2]) / np.max(crop_range[:2]) 7 | xz_aspect = np.min(crop_range[[0, 2]]) / np.max(crop_range[[0, 2]]) 8 | yz_aspect = np.min(crop_range[1:]) / np.max(crop_range[1:]) 9 | return ( 10 | (xy_aspect >= aspect_min) 11 | or (xz_aspect >= aspect_min) 12 | or (yz_aspect >= aspect_min) 13 | ) 14 | 15 | 16 | class RandomCuboid(object): 17 | """ 18 | RandomCuboid augmentation from DepthContrast [https://arxiv.org/abs/2101.02691] 19 | We slightly modify this operation to account for object detection. 20 | This augmentation randomly crops a cuboid from the input and 21 | ensures that the cropped cuboid contains at least one bounding box 22 | """ 23 | 24 | def __init__( 25 | self, 26 | min_points, 27 | aspect=0.8, 28 | min_crop=0.5, 29 | max_crop=1.0, 30 | box_filter_policy="center", 31 | ): 32 | self.aspect = aspect 33 | self.min_crop = min_crop 34 | self.max_crop = max_crop 35 | self.min_points = min_points 36 | self.box_filter_policy = box_filter_policy 37 | 38 | def __call__(self, point_cloud, target_boxes, per_point_labels=None): 39 | range_xyz = np.max(point_cloud[:, 0:3], axis=0) - np.min( 40 | point_cloud[:, 0:3], axis=0 41 | ) 42 | 43 | for _ in range(100): 44 | crop_range = self.min_crop + np.random.rand(3) * ( 45 | self.max_crop - self.min_crop 46 | ) 47 | if not check_aspect(crop_range, self.aspect): 48 | continue 49 | 50 | sample_center = point_cloud[np.random.choice(len(point_cloud)), 0:3] 51 | 52 | new_range = range_xyz * crop_range / 2.0 53 | 54 | max_xyz = sample_center + new_range 55 | min_xyz = sample_center - new_range 56 | 57 | upper_idx = ( 58 | np.sum((point_cloud[:, 0:3] <= max_xyz).astype(np.int32), 1) == 3 59 | ) 60 | lower_idx = ( 61 | np.sum((point_cloud[:, 0:3] >= min_xyz).astype(np.int32), 1) == 3 62 | ) 63 | 64 | new_pointidx = (upper_idx) & (lower_idx) 65 | 66 | if np.sum(new_pointidx) < self.min_points: 67 | continue 68 | 69 | new_point_cloud = point_cloud[new_pointidx, :] 70 | 71 | # filtering policy is the only modification from DepthContrast 72 | if self.box_filter_policy == "center": 73 | # remove boxes whose center does not lie within the new_point_cloud 74 | new_boxes = target_boxes 75 | if ( 76 | target_boxes.sum() > 0 77 | ): # ground truth contains no bounding boxes. Common in SUNRGBD. 78 | box_centers = target_boxes[:, 0:3] 79 | new_pc_min_max = np.min(new_point_cloud[:, 0:3], axis=0), np.max( 80 | new_point_cloud[:, 0:3], axis=0 81 | ) 82 | keep_boxes = np.logical_and( 83 | np.all(box_centers >= new_pc_min_max[0], axis=1), 84 | np.all(box_centers <= new_pc_min_max[1], axis=1), 85 | ) 86 | if keep_boxes.sum() == 0: 87 | # current data augmentation removes all boxes in the pointcloud. fail! 
88 | continue 89 | new_boxes = target_boxes[keep_boxes] 90 | if per_point_labels is not None: 91 | new_per_point_labels = [x[new_pointidx] for x in per_point_labels] 92 | else: 93 | new_per_point_labels = None 94 | # all conditions are met: return the cropped points, surviving boxes and labels 95 | return new_point_cloud, new_boxes, new_per_point_labels 96 | 97 | # fallback: no valid crop was found after 100 attempts; return the input unchanged 98 | return point_cloud, target_boxes, per_point_labels 99 | --------------------------------------------------------------------------------
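Because the file dumps above run together, a few hedged usage sketches are collected here rather than inline. First, the NMS helpers in utils/nms.py take boxes as rows of (x1, y1, z1, x2, y2, z2, score); the coordinates below are fabricated for illustration:

# Illustrative call to nms_3d_faster from utils/nms.py.
# Two heavily overlapping boxes and one disjoint box; made-up coordinates.
import numpy as np
from utils.nms import nms_3d_faster

boxes = np.array([
    [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.90],   # kept (highest score)
    [0.1, 0.1, 0.1, 1.1, 1.1, 1.1, 0.80],   # suppressed by the first box (IoU ~ 0.57)
    [5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 0.70],   # kept (no overlap)
])
keep = nms_3d_faster(boxes, overlap_threshold=0.25)
print(sorted(keep))   # expected: [0, 2]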
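Second, shift_scale_points in utils/pc_util.py linearly remaps per-batch coordinates from a source min/max range to a destination range (the unit cube by default); a minimal sketch with fabricated tensors:

# Illustrative use of shift_scale_points: map points from their bounding range
# into the unit cube. All tensors below are fabricated for the example.
import torch
from utils.pc_util import shift_scale_points

points = torch.tensor([[[0.0, 0.0, 0.0],
                        [2.0, 4.0, 6.0]]])            # B=1, N=2, 3
src_range = [points.min(dim=1).values,                # per-batch min, shape (1, 3)
             points.max(dim=1).values]                # per-batch max, shape (1, 3)
normalized = shift_scale_points(points, src_range)    # dst_range defaults to [0, 1]^3
print(normalized)   # expected: [[[0., 0., 0.], [1., 1., 1.]]]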
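Third, RandomCuboid in utils/random_cuboid.py repeatedly samples a crop until it contains at least min_points points and at least one box center, falling back to the uncropped input after 100 attempts; a sketch with synthetic data, where the 7-value box encoding (center, size, heading) is an assumption about the dataset format:

# Illustrative call to the RandomCuboid augmentation; the point cloud and the
# single ground-truth box below are synthetic.
import numpy as np
from utils.random_cuboid import RandomCuboid

rng = np.random.RandomState(0)
point_cloud = rng.uniform(-2.0, 2.0, size=(5000, 3)).astype(np.float32)
# One box assumed encoded as (cx, cy, cz, dx, dy, dz, heading); only the first
# three values (the center) matter for the "center" filtering policy.
target_boxes = np.array([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0]], dtype=np.float32)

aug = RandomCuboid(min_points=1000, aspect=0.8, min_crop=0.5, max_crop=1.0)
new_pc, new_boxes, _ = aug(point_cloud, target_boxes)
print(new_pc.shape[0] <= point_cloud.shape[0], new_boxes.shape)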