├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── assets └── global_calib.txt ├── configs ├── augmentation │ └── inference_preprocessing.yaml ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ ├── rich_progress_bar.yaml │ └── wandb.yaml ├── convert.yaml ├── datamodule │ └── kitti_datamodule.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── detector │ ├── yolov5.yaml │ └── yolov5_kitti.yaml ├── eval.yaml ├── evaluate.yaml ├── experiment │ └── sample.yaml ├── extras │ └── default.yaml ├── hparams_search │ ├── mnist_optuna.yaml │ └── optuna.yaml ├── hydra │ └── default.yaml ├── inference.yaml ├── local │ └── .gitkeep ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── mlflow.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ └── regressor.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── dgx.yaml │ ├── gpu.yaml │ ├── kaggle.yaml │ └── mps.yaml ├── convert.py ├── data └── datasplit.py ├── docs ├── assets │ ├── demo.gif │ ├── logo.png │ └── show.png ├── command.md ├── index.md └── javascripts │ └── mathjax.js ├── inference.py ├── kitti_object_eval ├── LICENSE ├── README.md ├── eval.py ├── evaluate.py ├── kitti_common.py ├── rotate_iou.py └── run.sh ├── logs └── .gitkeep ├── mkdocs.yml ├── notebooks └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── scripts ├── frames_to_video.py ├── generate_sets.py ├── get_weights.py ├── kitti_to_yolo.py ├── post_weights.py ├── schedule.sh ├── video_to_frame.py └── video_to_gif.py ├── setup.py ├── src ├── __init__.py ├── datamodules │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── kitti_dataset.py │ └── kitti_datamodule.py ├── eval.py ├── models │ ├── __init__.py │ ├── components │ │ ├── __init__.py │ │ └── base.py │ └── regressor.py ├── train.py └── utils │ ├── Calib.py │ ├── Math.py │ ├── Plotting.py │ ├── __init__.py │ ├── averages.py │ ├── class_averages-L4.txt │ ├── class_averages-kitti6.txt │ ├── class_averages.txt │ ├── eval.py │ ├── kitti_common.py │ ├── pylogger.py │ ├── rich_utils.py │ ├── rotate_iou.py │ └── utils.py ├── tests ├── __init__.py ├── conftest.py ├── helpers │ ├── __init__.py │ ├── package_available.py │ ├── run_if.py │ └── run_sh_command.py ├── test_configs.py ├── test_eval.py ├── test_mnist_datamodule.py ├── test_sweeps.py └── test_train.py ├── tmp └── .gitkeep └── weights └── get_regressor_weights.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Virtualenv 2 | /.venv/ 3 | /venv/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | /bin/ 14 | /build/ 15 | /develop-eggs/ 16 | /dist/ 17 | /eggs/ 18 | /lib/ 19 | /lib64/ 20 | /output/ 21 | /parts/ 22 | /sdist/ 23 | /var/ 24 | /*.egg-info/ 25 | /.installed.cfg 26 | /*.egg 27 | /.eggs 28 | 29 | # AUTHORS and ChangeLog will be generated while packaging 30 | /AUTHORS 31 | /ChangeLog 32 | 33 | # BCloud / BuildSubmitter 34 | /build_submitter.* 35 | /logger_client_log 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | .tox/ 43 | .coverage 44 | .cache 45 | .pytest_cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Sphinx documentation 53 | 
/docs/_build/ 54 | 55 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2021-2022 Megvii Inc. All rights reserved. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | format: ## Run pre-commit hooks 17 | pre-commit run -a 18 | 19 | sync: ## Merge changes from main branch to your current branch 20 | git pull 21 | git pull origin main 22 | 23 | test: ## Run not slow tests 24 | pytest -k "not slow" 25 | 26 | test-full: ## Run all tests 27 | pytest 28 | 29 | train: ## Train the model 30 | python src/train.py 31 | 32 | debug: ## Enter debugging mode with pdb 33 | # 34 | # tips: 35 | # - use "import pdb; pdb.set_trace()" to set breakpoint 36 | # - use "h" to print all commands 37 | # - use "n" to execute the next line 38 | # - use "c" to run until the breakpoint is hit 39 | # - use "l" to print src code around current line, "ll" for full function code 40 | # - docs: https://docs.python.org/3/library/pdb.html 41 | # 42 | python -m pdb src/train.py debug=default 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # YOLO3D: 3D Object Detection with YOLO 4 |
5 | 6 | ## Introduction 7 | 8 | YOLO3D is inspired by [Mousavian et al.](https://arxiv.org/abs/1612.00496) and their paper **3D Bounding Box Estimation Using Deep Learning and Geometry**. YOLO3D takes a different approach: the 2D ground-truth boxes stand in for the output of the first-stage detector, and those 2D results are then fed to the regressor model. 9 | 10 | ## Quickstart 11 | ```bash 12 | git clone git@github.com:ApolloAuto/apollo-model-yolo3d.git 13 | ``` 14 | 15 | ### create env for YOLO3D 16 | ```shell 17 | cd apollo-model-yolo3d 18 | 19 | conda create -n apollo_yolo3d python=3.8 numpy 20 | conda activate apollo_yolo3d 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ### datasets 25 | Here we use KITTI data for training. You can download the KITTI dataset from the [official website](http://www.cvlibs.net/datasets/kitti/). After that, extract the dataset to `data/KITTI`. 26 | 27 | ```shell 28 | ln -s /your/KITTI/path data/KITTI 29 | ``` 30 | 31 | ```bash 32 | ├── data 33 | │   └── KITTI 34 | │   ├── calib 35 | │   ├── images_2 36 | │   └── labels_2 37 | ``` 38 | Modify the [datasplit](data/datasplit.py) script to split the train and val data as you like. 39 | 40 | ```shell 41 | cd data 42 | python datasplit.py 43 | ``` 44 | 45 | ### train 46 | Modify [train.yaml](configs/train.yaml) to configure and train your model. 47 | 48 | ```shell 49 | python src/train.py experiment=sample 50 | ``` 51 | > log path: /logs \ 52 | > model path: /weights 53 | 54 | ### convert 55 | Modify the [convert.yaml](configs/convert.yaml) file to convert a .ckpt checkpoint into a .pt model. 56 | 57 | ```shell 58 | python convert.py 59 | ``` 60 | 61 | ### inference 62 | To show the real inference ability of the model, we crop images according to the ground-truth 2D boxes and feed the crops to YOLO3D; you can use the following command to plot 3D results. 63 | 64 | Modify the [inference.yaml](configs/inference.yaml) file to change the .pt model path. 65 | Setting **export_onnx=True** also exports an ONNX model.
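As a quick aside (a hedged sketch, not an official recipe): to also dump an ONNX model for deployment, the same entry point can be run with the `export_onnx=True` override, reusing the overrides from the plotting command shown next. The paths below are the defaults assumed elsewhere in this README; adjust them to your setup.

```shell
# assumed layout: KITTI data under ./data/KITTI and weights under ./weights
python inference.py \
    source_dir=./data/KITTI \
    detector.classes=6 \
    regressor_weights=./weights/pytorch-kitti.pt \
    export_onnx=True \
    func=image
```

The plotting command itself is: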
66 | 67 | ```shell 68 | python inference.py \ 69 | source_dir=./data/KITTI \ 70 | detector.classes=6 \ 71 | regressor_weights=./weights/pytorch-kitti.pt \ 72 | export_onnx=False \ 73 | func=image 74 | ``` 75 | 76 | - source_dir: path os datasets, include /image_2 and /label_2 folder 77 | - detector.classes: kitti class 78 | - regressor_weights: your model 79 | - export_onnx: export onnx model for apollo 80 | 81 | > result path: /outputs 82 | 83 | ### evaluate 84 | generate label for 3d result: 85 | ```shell 86 | python inference.py \ 87 | source_dir=./data/KITTI \ 88 | detector.classes=6 \ 89 | regressor_weights=./weights/pytorch-kitti.pt \ 90 | export_onnx=False \ 91 | func=label 92 | ``` 93 | > result path: /data/KITTI/result 94 | 95 | ```bash 96 | ├── data 97 | │   └── KITTI 98 | │   ├── calib 99 | │   ├── images_2 100 | │   ├── labels_2 101 | │   └── result 102 | ``` 103 | 104 | modify label_path、result_path and label_split_file in [kitti_object_eval](kitti_object_eval) folder script run.sh, with the help of it we can calculate mAP: 105 | ```shell 106 | cd kitti_object_eval 107 | sh run.sh 108 | ``` 109 | 110 | ## Acknowledgement 111 | - [yolo3d-lighting](https://github.com/ruhyadi/yolo3d-lightning) 112 | - [skhadem/3D-BoundingBox](https://github.com/skhadem/3D-BoundingBox) 113 | - [Mousavian et al.](https://arxiv.org/abs/1612.00496) 114 | ``` 115 | @misc{mousavian20173d, 116 | title={3D Bounding Box Estimation Using Deep Learning and Geometry}, 117 | author={Arsalan Mousavian and Dragomir Anguelov and John Flynn and Jana Kosecka}, 118 | year={2017}, 119 | eprint={1612.00496}, 120 | archivePrefix={arXiv}, 121 | primaryClass={cs.CV} 122 | } 123 | ``` -------------------------------------------------------------------------------- /assets/global_calib.txt: -------------------------------------------------------------------------------- 1 | # KITTI 2 | P_rect_02: 7.188560e+02 0.000000e+00 6.071928e+02 4.538225e+01 0.000000e+00 7.188560e+02 1.852157e+02 -1.130887e-01 0.000000e+00 0.000000e+00 1.000000e+00 3.779761e-03 3 | 4 | calib_time: 09-Jan-2012 14:00:15 5 | corner_dist: 9.950000e-02 6 | S_00: 1.392000e+03 5.120000e+02 7 | K_00: 9.799200e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.741183e+02 2.486443e+02 0.000000e+00 0.000000e+00 1.000000e+00 8 | D_00: -3.745594e-01 2.049385e-01 1.110145e-03 1.379375e-03 -7.084798e-02 9 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 10 | T_00: -9.251859e-17 8.326673e-17 -7.401487e-17 11 | S_rect_00: 1.241000e+03 3.760000e+02 12 | R_rect_00: 9.999454e-01 7.259129e-03 -7.519551e-03 -7.292213e-03 9.999638e-01 -4.381729e-03 7.487471e-03 4.436324e-03 9.999621e-01 13 | P_rect_00: 7.188560e+02 0.000000e+00 6.071928e+02 0.000000e+00 0.000000e+00 7.188560e+02 1.852157e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 14 | S_01: 1.392000e+03 5.120000e+02 15 | K_01: 9.903522e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.855674e+02 2.607319e+02 0.000000e+00 0.000000e+00 1.000000e+00 16 | D_01: -3.712084e-01 1.978723e-01 -3.709831e-05 -3.440494e-04 -6.724045e-02 17 | R_01: 9.993440e-01 1.814887e-02 -3.134011e-02 -1.842595e-02 9.997935e-01 -8.575221e-03 3.117801e-02 9.147067e-03 9.994720e-01 18 | T_01: -5.370000e-01 5.964270e-03 -1.274584e-02 19 | S_rect_01: 1.241000e+03 3.760000e+02 20 | R_rect_01: 9.996568e-01 -1.110284e-02 2.372712e-02 1.099810e-02 9.999292e-01 4.539964e-03 -2.377585e-02 -4.277453e-03 9.997082e-01 21 | P_rect_01: 7.188560e+02 0.000000e+00 6.071928e+02 
-3.861448e+02 0.000000e+00 7.188560e+02 1.852157e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 22 | S_02: 1.392000e+03 5.120000e+02 23 | K_02: 9.601149e+02 0.000000e+00 6.947923e+02 0.000000e+00 9.548911e+02 2.403547e+02 0.000000e+00 0.000000e+00 1.000000e+00 24 | D_02: -3.685917e-01 1.928022e-01 4.069233e-04 7.247536e-04 -6.276909e-02 25 | R_02: 9.999788e-01 -5.008404e-03 -4.151018e-03 4.990516e-03 9.999783e-01 -4.308488e-03 4.172506e-03 4.287682e-03 9.999821e-01 26 | T_02: 5.954406e-02 -7.675338e-04 3.582565e-03 27 | S_rect_02: 1.241000e+03 3.760000e+02 28 | R_rect_02: 9.999191e-01 1.228161e-02 -3.316013e-03 -1.228209e-02 9.999246e-01 -1.245511e-04 3.314233e-03 1.652686e-04 9.999945e-01 29 | S_03: 1.392000e+03 5.120000e+02 30 | K_03: 9.049931e+02 0.000000e+00 6.957698e+02 0.000000e+00 9.004945e+02 2.389820e+02 0.000000e+00 0.000000e+00 1.000000e+00 31 | D_03: -3.735725e-01 2.066816e-01 -6.133284e-04 -1.193269e-04 -7.600861e-02 32 | R_03: 9.995578e-01 1.656369e-02 -2.469315e-02 -1.663353e-02 9.998582e-01 -2.625576e-03 2.464616e-02 3.035149e-03 9.996916e-01 33 | T_03: -4.738786e-01 5.991982e-03 -3.215069e-03 34 | S_rect_03: 1.241000e+03 3.760000e+02 35 | R_rect_03: 9.998092e-01 -9.354781e-03 1.714961e-02 9.382303e-03 9.999548e-01 -1.525064e-03 -1.713457e-02 1.685675e-03 9.998518e-01 36 | P_rect_03: 7.188560e+02 0.000000e+00 6.071928e+02 -3.372877e+02 0.000000e+00 7.188560e+02 1.852157e+02 2.369057e+00 0.000000e+00 0.000000e+00 1.000000e+00 4.915215e-03 37 | -------------------------------------------------------------------------------- /configs/augmentation/inference_preprocessing.yaml: -------------------------------------------------------------------------------- 1 | to_tensor: 2 | _target_: torchvision.transforms.ToTensor 3 | normalize: 4 | _target_: torchvision.transforms.Normalize 5 | mean: [0.406, 0.456, 0.485] 6 | std: [0.225, 0.224, 0.229] -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - model_summary.yaml 5 | - rich_progress_bar.yaml 6 | - _self_ 7 | 8 | # model save config 9 | model_checkpoint: 10 | dirpath: "weights" 11 | filename: "epoch_{epoch:03d}" 12 | monitor: "val/loss" 13 | mode: "min" 14 | save_last: True 15 | save_top_k: 1 16 | auto_insert_metric_name: False 17 | 18 | early_stopping: 19 | monitor: "val/loss" 20 | patience: 100 21 | mode: "min" 22 | 23 | model_summary: 24 | max_depth: -1 25 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html 2 | 3 | # Monitor a metric and stop training when it stops improving. 4 | # Look at the above link for more detailed information. 5 | early_stopping: 6 | _target_: pytorch_lightning.callbacks.EarlyStopping 7 | monitor: ??? # quantity to be monitored, must be specified !!! 8 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 9 | patience: 3 # number of checks with no improvement after which training will be stopped 10 | verbose: False # verbosity mode 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | strict: True # whether to crash the training if monitor is not found in the validation metrics 13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html 2 | 3 | # Save the model periodically by monitoring a quantity. 4 | # Look at the above link for more detailed information. 5 | model_checkpoint: 6 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 7 | dirpath: null # directory to save the model file 8 | filename: null # checkpoint filename 9 | monitor: null # name of the logged metric which determines when model is improving 10 | verbose: False # verbosity mode 11 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 12 | save_top_k: 1 # save k best models (determined by above metric) 13 | mode: "min" # "max" means higher metric value is better, can be also "min" 14 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 15 | save_weights_only: False # if True, then only the model’s weights will be saved 16 | every_n_train_steps: null # number of training steps between checkpoints 17 | train_time_interval: null # checkpoints are monitored at the specified time interval 18 | every_n_epochs: null # number of epochs between checkpoints 19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 20 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html 2 | 3 | # Generates a summary of all layers in a LightningModule with rich text formatting. 4 | # Look at the above link for more detailed information. 
5 | model_summary: 6 | _target_: pytorch_lightning.callbacks.RichModelSummary 7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 8 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html 2 | 3 | # Create a progress bar with rich text formatting. 4 | # Look at the above link for more detailed information. 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | -------------------------------------------------------------------------------- /configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | upload_code_as_artifact: 10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 11 | code_dir: ${original_work_dir}/src 12 | 13 | upload_ckpts_as_artifact: 14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 15 | ckpt_dir: "checkpoints/" 16 | upload_best_only: True 17 | 18 | # log_f1_precision_recall_heatmap: 19 | # _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 20 | 21 | # log_confusion_matrix: 22 | # _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 23 | 24 | # log_image_predictions: 25 | # _target_: src.callbacks.wandb_callbacks.LogImagePredictions 26 | # num_samples: 8 -------------------------------------------------------------------------------- /configs/convert.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - model: regressor.yaml 7 | 8 | # enable color logging 9 | - override hydra/hydra_logging: colorlog 10 | - override hydra/job_logging: colorlog 11 | 12 | # pretty print config at the start of the run using Rich library 13 | print_config: True 14 | 15 | # disable python warnings if they annoy you 16 | ignore_warnings: True 17 | 18 | # root 19 | root: ${hydra:runtime.cwd} 20 | 21 | # TODO: cahnge to your checkpoint file 22 | checkpoint_dir: ${root}/weights/last.ckpt 23 | 24 | # dump dir 25 | dump_dir: ${root}/weights 26 | 27 | # input sample shape 28 | input_sample: 29 | __target__: torch.randn 30 | size: (1, 3, 224, 224) 31 | 32 | # convert to 33 | convert_to: "pytorch" # [pytorch, onnx, tensorrt] 34 | 35 | # TODO: model name without extension 36 | name: ${dump_dir}/pytorch-kitti 37 | 38 | # convert_to: "onnx" # [pytorch, onnx, tensorrt] 39 | # name: ${dump_dir}/onnx-3d-0817-5 40 | -------------------------------------------------------------------------------- /configs/datamodule/kitti_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.kitti_datamodule.KITTIDataModule 2 | 3 | dataset_path: ${paths.data_dir} # data_dir is specified in config.yaml 4 | train_sets: 
${paths.data_dir}/train_80.txt 5 | val_sets: ${paths.data_dir}/val_80.txt 6 | test_sets: ${paths.data_dir}/test_80.txt 7 | batch_size: 64 8 | num_worker: 32 -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | datamodule: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/detector/yolov5.yaml: -------------------------------------------------------------------------------- 1 | _target_: inference.detector_yolov5 2 | 3 | model_path: ${root}/weights/detector_yolov5s.pt 4 | cfg_path: ${root}/yolov5/models/yolov5s.yaml 5 | classes: 5 6 | device: 'cpu' -------------------------------------------------------------------------------- /configs/detector/yolov5_kitti.yaml: 
-------------------------------------------------------------------------------- 1 | # KITTI to YOLO 2 | 3 | path: ../data/KITTI/ # dataset root dir 4 | train: train_yolo.txt # train images (relative to 'path') 3712 images 5 | val: val_yolo.txt # val images (relative to 'path') 3768 images 6 | 7 | # Classes 8 | nc: 5 # number of classes 9 | names: ['car', 'van', 'truck', 'pedestrian', 'cyclist'] -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - datamodule: mnist.yaml # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /configs/evaluate.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - detector: yolov5.yaml 7 | - model: regressor.yaml 8 | - augmentation: inference_preprocessing.yaml 9 | 10 | # debugging config (enable through command line, e.g. `python train.py debug=default) 11 | - debug: null 12 | 13 | # enable color logging 14 | - override hydra/hydra_logging: colorlog 15 | - override hydra/job_logging: colorlog 16 | 17 | # run name 18 | name: evaluate 19 | 20 | # directory 21 | root: ${hydra:runtime.cwd} 22 | 23 | # predictions/output directory 24 | # pred_dir: ${root}/${hydra:run.dir}/${name} 25 | 26 | # calib_file 27 | calib_file: ${root}/assets/global_calib.txt 28 | 29 | # regressor weights 30 | regressor_weights: ${root}/weights/regressor_resnet18.pt 31 | 32 | # validation images directory 33 | val_images_path: ${root}/data/KITTI/images_2 34 | 35 | # validation sets directory 36 | val_sets: ${root}/data/KITTI/ImageSets/val.txt 37 | 38 | # class to evaluated 39 | classes: 6 40 | 41 | # class_to_name = { 42 | # 0: 'Car', 43 | # 1: 'Cyclist', 44 | # 2: 'Truck', 45 | # 3: 'Van', 46 | # 4: 'Pedestrian', 47 | # 5: 'Tram', 48 | # } 49 | 50 | # gt label path 51 | gt_dir: ${root}/data/KITTI/label_2 52 | 53 | # dt label path 54 | pred_dir: ${root}/data/KITTI/result 55 | 56 | # device to inference 57 | device: 'cuda:0' -------------------------------------------------------------------------------- /configs/experiment/sample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /datamodule: kitti_datamodule.yaml 8 | - override /model: regressor.yaml 9 | - override /callbacks: default.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: dgx.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | seed: 42069 17 | 18 | # name of the run determines folder name in logs 19 | name: "new_network" 20 | 21 | datamodule: 22 | train_sets: ${paths.data_dir}/ImageSets/train.txt 23 | val_sets: ${paths.data_dir}/ImageSets/val.txt 24 | test_sets: ${paths.data_dir}/ImageSets/test.txt 25 | 26 | 
trainer: 27 | min_epochs: 1 28 | max_epochs: 200 29 | # limit_train_batches: 1.0 30 | # limit_val_batches: 1.0 31 | gpus: [0] 32 | strategy: ddp -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | datamodule.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hparams_search/optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/loss" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 2 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: 'minimize' 34 | 35 | # total number of runs that will be executed 36 | n_trials: 10 37 | 38 | # choose Optuna hyperparameter sampler 39 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 40 | sampler: 41 | _target_: optuna.samplers.TPESampler 42 | seed: 42069 43 | n_startup_trials: 10 # number of random sampling runs before optimization starts 44 | 45 | # define range of hyperparameters 46 | params: 47 | model.lr: interval(0.0001, 0.001) 48 | datamodule.batch_size: choice(32, 64, 128) 49 | model.optimizer: choice(adam, sgd) -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/inference.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - _self_ 6 | - detector: yolov5.yaml 7 | - model: regressor.yaml 8 | - augmentation: inference_preprocessing.yaml 9 | 10 | # debugging config (enable through command line, e.g. 
`python train.py debug=default) 11 | - debug: null 12 | 13 | # enable color logging 14 | - override hydra/hydra_logging: colorlog 15 | - override hydra/job_logging: colorlog 16 | 17 | # run name 18 | name: inference 19 | 20 | # directory 21 | root: ${hydra:runtime.cwd} 22 | output_dir: ${root}/${hydra:run.dir}/inference 23 | 24 | # calib_file 25 | calib_file: ${root}/assets/global_calib.txt 26 | 27 | # save 2D bounding box 28 | save_det2d: False 29 | 30 | # show and save result 31 | save_result: True 32 | 33 | # save result in txt 34 | # save_txt: True 35 | 36 | # regressor weights 37 | regressor_weights: ${root}/weights/regressor_resnet18.pt 38 | # regressor_weights: ${root}/weights/mobilenetv3-best.pt 39 | 40 | # inference type 41 | inference_type: pytorch # [pytorch, onnx, openvino, tensorrt] 42 | 43 | # source directory 44 | # source_dir: ${root}/tmp/kitti/ 45 | source_dir: ${root}/tmp/video_001 46 | 47 | # device to inference 48 | device: 'cpu' 49 | 50 | export_onnx: False 51 | 52 | func: "label" # image/label 53 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet.yaml 5 | - csv.yaml 6 | # - mlflow.yaml 7 | # - neptune.yaml 8 | - tensorboard.yaml 9 | - wandb.yaml 10 | -------------------------------------------------------------------------------- /configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 
5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "yolo3d-regressor" 11 | log_model: True # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/regressor.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.regressor.RegressorModel 2 | 3 | net: 4 | _target_: src.models.components.base.RegressorNet 5 | backbone: 6 | _target_: torchvision.models.resnet18 # change model on this 7 | pretrained: True 8 | bins: 2 9 | 10 | optimizer: adam 11 | 12 | lr: 0.0001 13 | momentum: 0.9 14 | w: 0.8 15 | alpha: 0.2 -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py` 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/KITTI 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - datamodule: kitti_datamodule.yaml 8 | - model: regressor.yaml 9 | - callbacks: default.yaml 10 | - logger: null # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) 11 | - trainer: dgx.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | # tags to help you identify your experiments 34 | # you can overwrite this in experiment configs 35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 36 | # appending lists from command line is currently not supported :( 37 | # https://github.com/facebookresearch/hydra/issues/1547 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: False 46 | 47 | # simply provide checkpoint path to resume training 48 | # ckpt_path: weights/last.ckpt 49 | ckpt_path: null 50 | 51 | # seed for random number generators in pytorch, numpy and python.random 52 | seed: null 53 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 4 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 25 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # set True to to ensure deterministic results 15 | # makes training slower but gives more reproducibility than just setting seeds 16 | deterministic: False 17 | 
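A note on the trainer configs in this folder (`cpu`, `ddp`, `ddp_sim`, `default`, and the `dgx`/`gpu`/`kaggle`/`mps` variants that follow): they are Hydra config groups, so a run selects one by name on the command line and can override individual fields on top of it. A hedged usage sketch, assuming the same override syntax used in the README commands (the group names are simply the file names in `configs/trainer/`):

```shell
# select the single-GPU trainer and shorten the run
python src/train.py trainer=gpu trainer.max_epochs=50

# select the DDP trainer and change the number of devices
python src/train.py trainer=ddp trainer.devices=2
```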
-------------------------------------------------------------------------------- /configs/trainer/dgx.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0] 8 | num_nodes: 1 9 | sync_batchnorm: True -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/kaggle.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | gpus: 0 4 | 5 | min_epochs: 1 6 | max_epochs: 10 7 | 8 | # number of validation steps to execute at the beginning of the training 9 | # num_sanity_val_steps: 0 10 | 11 | # ckpt path 12 | resume_from_checkpoint: null 13 | 14 | # disable progress_bar 15 | enable_progress_bar: False -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | """ Convert checkpoint to model (.pt/.pth/.onnx) """ 2 | 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from pytorch_lightning import LightningModule 6 | from src import utils 7 | 8 | import dotenv 9 | import hydra 10 | from omegaconf import DictConfig 11 | import os 12 | 13 | # load environment variables from `.env` file if it exists 14 | # recursively searches for `.env` in all folders starting from work dir 15 | dotenv.load_dotenv(override=True) 16 | log = utils.get_pylogger(__name__) 17 | 18 | @hydra.main(config_path="configs/", config_name="convert.yaml") 19 | def convert(config: DictConfig): 20 | 21 | # assert model conversion target 22 | assert config.get('convert_to') in ['pytorch', 'torchscript', 'onnx', 'tensorrt'], \ 23 | "Please choose one of [pytorch, torchscript, onnx, tensorrt]" 24 | 25 | # Init lightning model 26 | log.info(f"Instantiating model <{config.model._target_}>") 27 | model: LightningModule = hydra.utils.instantiate(config.model) 28 | # regressor: LightningModule = hydra.utils.instantiate(config.model) 29 | # regressor.load_state_dict(torch.load(config.get("regressor_weights"), map_location="cpu")) 30 | # regressor.eval().to(config.get("device")) 31 | 32 | # Convert relative ckpt path to absolute path if necessary 33 | log.info(f"Load checkpoint <{config.get('checkpoint_dir')}>") 34 | ckpt_path = config.get("checkpoint_dir") 35 | if ckpt_path and not os.path.isabs(ckpt_path): 36 | ckpt_path = os.path.join(hydra.utils.get_original_cwd(), ckpt_path) 37 | 38 | # load model checkpoint 39 | model = model.load_from_checkpoint(ckpt_path) 40 | model.cuda() 41 | 42 | # input sample 43 | input_sample = config.get('input_sample') 44 | 45 | # Convert 46 | if config.get('convert_to') == 'pytorch': 47 | log.info("Convert to Pytorch (.pt)") 48 | torch.save(model.state_dict(), f'{config.get("name")}.pt') 49 | log.info(f"Saved model {config.get('name')}.pt") 50 | if config.get('convert_to') == 'onnx': 51 |
log.info("Convert to ONNX (.onnx)") 52 | model.cuda() 53 | input_sample = torch.rand((1, 3, 224, 224), device=torch.device('cuda')) 54 | model.to_onnx(f'{config.get("name")}.onnx', input_sample, export_params=True) 55 | log.info(f"Saved model {config.get('name')}.onnx") 56 | 57 | if __name__ == '__main__': 58 | 59 | convert() -------------------------------------------------------------------------------- /data/datasplit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Baidu apollo, Inc. 3 | # All Rights Reserved 4 | 5 | import os 6 | import random 7 | 8 | # TODO: change this to your own data path 9 | pnglabelfilepath = r'./KITTI/label_2' 10 | savePath = r"./KITTI/ImageSets/" 11 | 12 | target_png = os.listdir(pnglabelfilepath) 13 | total_png = [] 14 | for t in target_png: 15 | if t.endswith(".txt"): 16 | id = str(int(t.split('.')[0])).zfill(6) 17 | total_png.append(id + '.png') 18 | 19 | print("--- iter for image finished ---") 20 | 21 | # TODO: change this ratio to your own 22 | train_percent = 0.85 23 | val_percent = 0.1 24 | test_percent = 0.05 25 | 26 | num = len(total_png) 27 | # train = random.sample(num,0.9*num) 28 | list = list(range(num)) 29 | 30 | num_train = int(num * train_percent) 31 | num_val = int(num * val_percent) 32 | 33 | 34 | train = random.sample(list, num_train) 35 | num1 = len(train) 36 | for i in range(num1): 37 | list.remove(train[i]) 38 | 39 | val_test = [i for i in list if not i in train] 40 | val = random.sample(val_test, num_val) 41 | num2 = len(val) 42 | for i in range(num2): 43 | list.remove(val[i]) 44 | 45 | 46 | def mkdir(path): 47 | folder = os.path.exists(path) 48 | if not folder: 49 | os.makedirs(path) 50 | print("--- creating new folder... 
---") 51 | print("--- finished ---") 52 | else: 53 | print("--- pass to create new folder ---") 54 | 55 | 56 | mkdir(savePath) 57 | 58 | ftrain = open(os.path.join(savePath, 'train.txt'), 'w') 59 | fval = open(os.path.join(savePath, 'val.txt'), 'w') 60 | ftest = open(os.path.join(savePath, 'test.txt'), 'w') 61 | 62 | for i in train: 63 | name = total_png[i][:-4]+ '\n' 64 | ftrain.write(name) 65 | 66 | 67 | for i in val: 68 | name = total_png[i][:-4] + '\n' 69 | fval.write(name) 70 | 71 | 72 | for i in list: 73 | name = total_png[i][:-4] + '\n' 74 | ftest.write(name) 75 | 76 | ftrain.close() 77 | -------------------------------------------------------------------------------- /docs/assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/demo.gif -------------------------------------------------------------------------------- /docs/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/logo.png -------------------------------------------------------------------------------- /docs/assets/show.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/show.png -------------------------------------------------------------------------------- /docs/command.md: -------------------------------------------------------------------------------- 1 | # Quick Command 2 | 3 | ## Train Regressor Model 4 | 5 | - Train original 6 | ```bash 7 | python src/train.py 8 | ``` 9 | 10 | - With experiment 11 | ```bash 12 | python src/train.py \ 13 | experiment=sample 14 | ``` 15 | 16 | ## Train Detector Model 17 | ### Yolov5 18 | 19 | - Multi GPU Training 20 | ```bash 21 | cd yolov5 22 | python -m torch.distributed.launch \ 23 | --nproc_per_node 4 train.py \ 24 | --epochs 10 \ 25 | --batch 64 \ 26 | --data ../configs/detector/yolov5_kitti.yaml \ 27 | --weights yolov5s.pt \ 28 | --device 0,1,2,3 29 | ``` 30 | 31 | - Single GPU Training 32 | ```bash 33 | cd yolov5 34 | python train.py \ 35 | --data ../configs/detector/yolov5_kitti.yaml \ 36 | --weights yolov5s.pt \ 37 | --img 640 38 | ``` 39 | 40 | ## Hyperparameter Tuning with Hydra 41 | 42 | ```bash 43 | python src/train.py -m \ 44 | hparams_search=regressor_optuna \ 45 | experiment=sample_optuna 46 | ``` -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # YOLO3D: 3D Object Detection with YOLO 2 | 3 |
4 | 5 | Python 6 | PyTorch 7 | Lightning 8 | 9 | Config: hydra 10 | Code style: black 11 | Template
12 | 13 |
14 | 15 | ## ⚠️  Cautions 16 | > This repository is currently under development 17 | 18 | ## 📼  Demo 19 |
20 | 21 | ![demo](./assets/demo.gif) 22 | 23 |
24 | 25 | ## 📌  Introduction 26 | 27 | Unofficial implementation of [Mousavian et al.](https://arxiv.org/abs/1612.00496), **3D Bounding Box Estimation Using Deep Learning and Geometry**. YOLO3D departs from the original pipeline: the detector is **YOLOv5** instead of Faster-RCNN, and the regressor is **ResNet18/VGG11** instead of VGG19. 28 | 29 | ## 🚀  Quickstart 30 | > We use Hydra as the config manager; if you are unfamiliar with Hydra, you can visit the official website or see the tutorial on this site. 31 | 32 | ### 🍿  Inference 33 | You can use pretrained weights from the [Releases](https://github.com/ruhyadi/yolo3d-lightning/releases) page; download them with `scripts/get_weights.py`: 34 | ```bash 35 | # download pretrained model 36 | python scripts/get_weights.py \ 37 | --tag v0.1 \ 38 | --dir ./weights 39 | ``` 40 | Inference with `inference.py`: 41 | ```bash 42 | python inference.py \ 43 | source_dir="./data/demo/images" \ 44 | detector.model_path="./weights/detector_yolov5s.pt" \ 45 | regressor_weights="./weights/regressor_resnet18.pt" 46 | ``` 47 | 48 | ### ⚔️  Training 49 | Two models are trained here: the **detector** and the **regressor**. For now, the only supported detector is **YOLOv5**, while the regressor can be any model supported by **torchvision**. 50 | 51 | #### 🧭  Training YOLOv5 Detector 52 | The first step is to convert the `label_2` annotations from KITTI format to YOLO format. You can use `scripts/kitti_to_yolo.py` as follows. 53 | 54 | ```bash 55 | cd yolo3d-lightning/scripts 56 | python kitti_to_yolo.py \ 57 | --dataset_path ../data/KITTI/training/ \ 58 | --classes ["car", "van", "truck", "pedestrian", "cyclist"] \ 59 | --img_width 1224 \ 60 | --img_height 370 61 | ``` 62 | 63 | The next step is to follow the [wiki provided by ultralytics](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data). **Note:** *the README will be updated in the future*. 64 | 65 | #### 🪀  Training Regressor 66 | Next, you can train the regressor model. The regressor can be any of the models available in `torchvision`, or you can define a custom one. 67 | 68 | The first step is to create the train and validation sets. You can use `scripts/generate_sets.py` (the images folder may also be named `image_2`): 69 | 70 | ```bash 71 | cd yolo3d-lightning/scripts 72 | python generate_sets.py \ 73 | --images_path ../data/KITTI/training/images \ 74 | --dump_dir ../data/KITTI/training \ 75 | --postfix _80 \ 76 | --train_size 0.8 77 | ``` 78 | 79 | In the next step we will only use models that are available in `torchvision`. The easiest way is to edit the configuration in `configs/model/regressor.yaml`, as shown below: 80 | 81 | ```yaml 82 | _target_: src.models.regressor.RegressorModel 83 | 84 | net: 85 | _target_: src.models.components.base.RegressorNet 86 | backbone: 87 | _target_: torchvision.models.resnet18 # edit this 88 | pretrained: True # maybe this too 89 | bins: 2 90 | 91 | lr: 0.001 92 | momentum: 0.9 93 | w: 0.4 94 | alpha: 0.6 95 | ``` 96 | 97 | The next step is to create an experiment configuration in `configs/experiment/your_exp.yaml`. If in doubt, you can refer to [`configs/experiment/demo.yaml`](./configs/experiment/demo.yaml). 98 | 99 | 
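For illustration only, a minimal experiment config could look like the sketch below (the config group names follow this repository's `configs/` layout; the actual `demo.yaml` may differ):

```yaml
# @package _global_
# illustrative sketch of configs/experiment/your_exp.yaml (not the actual demo.yaml)
defaults:
  - override /datamodule: kitti_datamodule.yaml
  - override /model: regressor.yaml
  - override /trainer: gpu.yaml

task_name: "train"
tags: ["kitti", "regressor"]
seed: 42

trainer:
  max_epochs: 25

model:
  lr: 0.001
```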
Once the experiment configuration has been created, you can simply run `train.py` as follows: 100 | 101 | ```bash 102 | cd yolo3d-lightning 103 | python train.py \ 104 | experiment=demo 105 | ``` 106 | 107 | 108 | 109 | 110 | 111 | 112 | ## ❤️  Acknowledgement 113 | 114 | - [YOLOv5 by Ultralytics](https://github.com/ultralytics/yolov5) 115 | - [skhadem/3D-BoundingBox](https://github.com/skhadem/3D-BoundingBox) 116 | - [Mousavian et al.](https://arxiv.org/abs/1612.00496) 117 | ``` 118 | @misc{mousavian20173d, 119 | title={3D Bounding Box Estimation Using Deep Learning and Geometry}, 120 | author={Arsalan Mousavian and Dragomir Anguelov and John Flynn and Jana Kosecka}, 121 | year={2017}, 122 | eprint={1612.00496}, 123 | archivePrefix={arXiv}, 124 | primaryClass={cs.CV} 125 | } 126 | ``` -------------------------------------------------------------------------------- /docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | 16 | 17 | MathJax.typesetPromise() 18 | }) -------------------------------------------------------------------------------- /kitti_object_eval/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /kitti_object_eval/README.md: -------------------------------------------------------------------------------- 1 | # Note 2 | 3 | This code is from [traveller59/kitti-object-eval-python](https://github.com/traveller59/kitti-object-eval-python) 4 | 5 | # kitti-object-eval-python 6 | Fast KITTI object detection evaluation in Python (finishes in less than 10 seconds); supports 2d/bev/3d/aos as well as coco-style AP. If you use the command-line interface, numba needs some time to compile the jit functions. 7 | 8 | _WARNING_: The "coco" metric isn't official; only "AP (Average Precision)" is. 9 | ## Dependencies 10 | Only Python 3.6+ is supported; it needs `numpy`, `skimage`, `numba`, `fire`, `scipy`. If you have Anaconda, just install `cudatoolkit` in Anaconda.
Otherwise, please reference to this [page](https://github.com/numba/numba#custom-python-environments) to set up llvm and cuda for numba. 11 | * Install by conda: 12 | ``` 13 | conda install -c numba cudatoolkit=x.x (8.0, 9.0, 10.0, depend on your environment) 14 | ``` 15 | ## Usage 16 | * commandline interface: 17 | ``` 18 | python evaluate.py evaluate --label_path=/path/to/your_gt_label_folder --result_path=/path/to/your_result_folder --label_split_file=/path/to/val.txt --current_class=0 --coco=False 19 | ``` 20 | * python interface: 21 | ```Python 22 | import kitti_common as kitti 23 | from eval import get_official_eval_result, get_coco_eval_result 24 | def _read_imageset_file(path): 25 | with open(path, 'r') as f: 26 | lines = f.readlines() 27 | return [int(line) for line in lines] 28 | det_path = "/path/to/your_result_folder" 29 | dt_annos = kitti.get_label_annos(det_path) 30 | gt_path = "/path/to/your_gt_label_folder" 31 | gt_split_file = "/path/to/val.txt" # from https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz 32 | val_image_ids = _read_imageset_file(gt_split_file) 33 | gt_annos = kitti.get_label_annos(gt_path, val_image_ids) 34 | print(get_official_eval_result(gt_annos, dt_annos, 0)) # 6s in my computer 35 | print(get_coco_eval_result(gt_annos, dt_annos, 0)) # 18s in my computer 36 | ``` 37 | -------------------------------------------------------------------------------- /kitti_object_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | import fire 3 | import kitti_common as kitti 4 | from eval import get_official_eval_result, get_coco_eval_result 5 | 6 | 7 | def _read_imageset_file(path): 8 | with open(path, 'r') as f: 9 | lines = f.readlines() 10 | return [int(line) for line in lines] 11 | 12 | 13 | def evaluate(label_path, # gt 14 | result_path, # dt 15 | label_split_file, 16 | current_class=0, # 0: bbox, 1: bev, 2: 3d 17 | coco=False, 18 | score_thresh=-1): 19 | dt_annos = kitti.get_label_annos(result_path) 20 | # print("dt_annos[0] is ", dt_annos[0], " shape is ", len(dt_annos)) 21 | 22 | # if score_thresh > 0: 23 | # dt_annos = kitti.filter_annos_low_score(dt_annos, score_thresh) 24 | # val_image_ids = _read_imageset_file(label_split_file) 25 | 26 | gt_annos = kitti.get_label_annos(label_path) 27 | # print("gt_annos[0] is ", gt_annos[0], " shape is ", len(gt_annos)) 28 | 29 | if coco: 30 | print(get_coco_eval_result(gt_annos, dt_annos, current_class)) 31 | else: 32 | print("not coco") 33 | print(get_official_eval_result(gt_annos, dt_annos, current_class)) 34 | 35 | 36 | if __name__ == '__main__': 37 | fire.Fire() 38 | -------------------------------------------------------------------------------- /kitti_object_eval/rotate_iou.py: -------------------------------------------------------------------------------- 1 | ##################### 2 | # Based on https://github.com/hongzhenwang/RRPN-revise 3 | # Licensed under The MIT License 4 | # Author: yanyan, scrin@foxmail.com 5 | ##################### 6 | import math 7 | 8 | import numba 9 | import numpy as np 10 | from numba import cuda 11 | 12 | @numba.jit(nopython=True) 13 | def div_up(m, n): 14 | return m // n + (m % n > 0) 15 | 16 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 17 | def trangle_area(a, b, c): 18 | return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) * 19 | (b[0] - c[0])) / 2.0 20 | 21 | 22 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 23 | def area(int_pts, num_of_inter): 24 | area_val = 0.0 
25 | for i in range(num_of_inter - 2): 26 | area_val += abs( 27 | trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4], 28 | int_pts[2 * i + 4:2 * i + 6])) 29 | return area_val 30 | 31 | 32 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 33 | def sort_vertex_in_convex_polygon(int_pts, num_of_inter): 34 | if num_of_inter > 0: 35 | center = cuda.local.array((2, ), dtype=numba.float32) 36 | center[:] = 0.0 37 | for i in range(num_of_inter): 38 | center[0] += int_pts[2 * i] 39 | center[1] += int_pts[2 * i + 1] 40 | center[0] /= num_of_inter 41 | center[1] /= num_of_inter 42 | v = cuda.local.array((2, ), dtype=numba.float32) 43 | vs = cuda.local.array((16, ), dtype=numba.float32) 44 | for i in range(num_of_inter): 45 | v[0] = int_pts[2 * i] - center[0] 46 | v[1] = int_pts[2 * i + 1] - center[1] 47 | d = math.sqrt(v[0] * v[0] + v[1] * v[1]) 48 | v[0] = v[0] / d 49 | v[1] = v[1] / d 50 | if v[1] < 0: 51 | v[0] = -2 - v[0] 52 | vs[i] = v[0] 53 | j = 0 54 | temp = 0 55 | for i in range(1, num_of_inter): 56 | if vs[i - 1] > vs[i]: 57 | temp = vs[i] 58 | tx = int_pts[2 * i] 59 | ty = int_pts[2 * i + 1] 60 | j = i 61 | while j > 0 and vs[j - 1] > temp: 62 | vs[j] = vs[j - 1] 63 | int_pts[j * 2] = int_pts[j * 2 - 2] 64 | int_pts[j * 2 + 1] = int_pts[j * 2 - 1] 65 | j -= 1 66 | 67 | vs[j] = temp 68 | int_pts[j * 2] = tx 69 | int_pts[j * 2 + 1] = ty 70 | 71 | 72 | @cuda.jit( 73 | '(float32[:], float32[:], int32, int32, float32[:])', 74 | device=True, 75 | inline=True) 76 | def line_segment_intersection(pts1, pts2, i, j, temp_pts): 77 | A = cuda.local.array((2, ), dtype=numba.float32) 78 | B = cuda.local.array((2, ), dtype=numba.float32) 79 | C = cuda.local.array((2, ), dtype=numba.float32) 80 | D = cuda.local.array((2, ), dtype=numba.float32) 81 | 82 | A[0] = pts1[2 * i] 83 | A[1] = pts1[2 * i + 1] 84 | 85 | B[0] = pts1[2 * ((i + 1) % 4)] 86 | B[1] = pts1[2 * ((i + 1) % 4) + 1] 87 | 88 | C[0] = pts2[2 * j] 89 | C[1] = pts2[2 * j + 1] 90 | 91 | D[0] = pts2[2 * ((j + 1) % 4)] 92 | D[1] = pts2[2 * ((j + 1) % 4) + 1] 93 | BA0 = B[0] - A[0] 94 | BA1 = B[1] - A[1] 95 | DA0 = D[0] - A[0] 96 | CA0 = C[0] - A[0] 97 | DA1 = D[1] - A[1] 98 | CA1 = C[1] - A[1] 99 | acd = DA1 * CA0 > CA1 * DA0 100 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0]) 101 | if acd != bcd: 102 | abc = CA1 * BA0 > BA1 * CA0 103 | abd = DA1 * BA0 > BA1 * DA0 104 | if abc != abd: 105 | DC0 = D[0] - C[0] 106 | DC1 = D[1] - C[1] 107 | ABBA = A[0] * B[1] - B[0] * A[1] 108 | CDDC = C[0] * D[1] - D[0] * C[1] 109 | DH = BA1 * DC0 - BA0 * DC1 110 | Dx = ABBA * DC0 - BA0 * CDDC 111 | Dy = ABBA * DC1 - BA1 * CDDC 112 | temp_pts[0] = Dx / DH 113 | temp_pts[1] = Dy / DH 114 | return True 115 | return False 116 | 117 | 118 | @cuda.jit( 119 | '(float32[:], float32[:], int32, int32, float32[:])', 120 | device=True, 121 | inline=True) 122 | def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts): 123 | a = cuda.local.array((2, ), dtype=numba.float32) 124 | b = cuda.local.array((2, ), dtype=numba.float32) 125 | c = cuda.local.array((2, ), dtype=numba.float32) 126 | d = cuda.local.array((2, ), dtype=numba.float32) 127 | 128 | a[0] = pts1[2 * i] 129 | a[1] = pts1[2 * i + 1] 130 | 131 | b[0] = pts1[2 * ((i + 1) % 4)] 132 | b[1] = pts1[2 * ((i + 1) % 4) + 1] 133 | 134 | c[0] = pts2[2 * j] 135 | c[1] = pts2[2 * j + 1] 136 | 137 | d[0] = pts2[2 * ((j + 1) % 4)] 138 | d[1] = pts2[2 * ((j + 1) % 4) + 1] 139 | 140 | area_abc = trangle_area(a, b, c) 141 | area_abd = trangle_area(a, b, d) 142 | 143 | if area_abc * area_abd >= 0: 
144 | return False 145 | 146 | area_cda = trangle_area(c, d, a) 147 | area_cdb = area_cda + area_abc - area_abd 148 | 149 | if area_cda * area_cdb >= 0: 150 | return False 151 | t = area_cda / (area_abd - area_abc) 152 | 153 | dx = t * (b[0] - a[0]) 154 | dy = t * (b[1] - a[1]) 155 | temp_pts[0] = a[0] + dx 156 | temp_pts[1] = a[1] + dy 157 | return True 158 | 159 | 160 | @cuda.jit('(float32, float32, float32[:])', device=True, inline=True) 161 | def point_in_quadrilateral(pt_x, pt_y, corners): 162 | ab0 = corners[2] - corners[0] 163 | ab1 = corners[3] - corners[1] 164 | 165 | ad0 = corners[6] - corners[0] 166 | ad1 = corners[7] - corners[1] 167 | 168 | ap0 = pt_x - corners[0] 169 | ap1 = pt_y - corners[1] 170 | 171 | abab = ab0 * ab0 + ab1 * ab1 172 | abap = ab0 * ap0 + ab1 * ap1 173 | adad = ad0 * ad0 + ad1 * ad1 174 | adap = ad0 * ap0 + ad1 * ap1 175 | 176 | return abab >= abap and abap >= 0 and adad >= adap and adap >= 0 177 | 178 | 179 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 180 | def quadrilateral_intersection(pts1, pts2, int_pts): 181 | num_of_inter = 0 182 | for i in range(4): 183 | if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2): 184 | int_pts[num_of_inter * 2] = pts1[2 * i] 185 | int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1] 186 | num_of_inter += 1 187 | if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1): 188 | int_pts[num_of_inter * 2] = pts2[2 * i] 189 | int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1] 190 | num_of_inter += 1 191 | temp_pts = cuda.local.array((2, ), dtype=numba.float32) 192 | for i in range(4): 193 | for j in range(4): 194 | has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts) 195 | if has_pts: 196 | int_pts[num_of_inter * 2] = temp_pts[0] 197 | int_pts[num_of_inter * 2 + 1] = temp_pts[1] 198 | num_of_inter += 1 199 | 200 | return num_of_inter 201 | 202 | 203 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True) 204 | def rbbox_to_corners(corners, rbbox): 205 | # generate clockwise corners and rotate it clockwise 206 | angle = rbbox[4] 207 | a_cos = math.cos(angle) 208 | a_sin = math.sin(angle) 209 | center_x = rbbox[0] 210 | center_y = rbbox[1] 211 | x_d = rbbox[2] 212 | y_d = rbbox[3] 213 | corners_x = cuda.local.array((4, ), dtype=numba.float32) 214 | corners_y = cuda.local.array((4, ), dtype=numba.float32) 215 | corners_x[0] = -x_d / 2 216 | corners_x[1] = -x_d / 2 217 | corners_x[2] = x_d / 2 218 | corners_x[3] = x_d / 2 219 | corners_y[0] = -y_d / 2 220 | corners_y[1] = y_d / 2 221 | corners_y[2] = y_d / 2 222 | corners_y[3] = -y_d / 2 223 | for i in range(4): 224 | corners[2 * 225 | i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x 226 | corners[2 * i 227 | + 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y 228 | 229 | 230 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True) 231 | def inter(rbbox1, rbbox2): 232 | corners1 = cuda.local.array((8, ), dtype=numba.float32) 233 | corners2 = cuda.local.array((8, ), dtype=numba.float32) 234 | intersection_corners = cuda.local.array((16, ), dtype=numba.float32) 235 | 236 | rbbox_to_corners(corners1, rbbox1) 237 | rbbox_to_corners(corners2, rbbox2) 238 | 239 | num_intersection = quadrilateral_intersection(corners1, corners2, 240 | intersection_corners) 241 | sort_vertex_in_convex_polygon(intersection_corners, num_intersection) 242 | # print(intersection_corners.reshape([-1, 2])[:num_intersection]) 243 | 244 | return area(intersection_corners, num_intersection) 245 | 246 | 247 | 
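# NOTE: `criterion` in devRotateIoUEval below selects the returned overlap measure:
#   criterion == -1 -> standard IoU: area_inter / (area1 + area2 - area_inter)
#   criterion ==  0 -> area_inter / area1 (overlap relative to the first box)
#   criterion ==  1 -> area_inter / area2 (overlap relative to the second box)
#   otherwise       -> raw intersection area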
@cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True) 248 | def devRotateIoUEval(rbox1, rbox2, criterion=-1): 249 | area1 = rbox1[2] * rbox1[3] 250 | area2 = rbox2[2] * rbox2[3] 251 | area_inter = inter(rbox1, rbox2) 252 | if criterion == -1: 253 | return area_inter / (area1 + area2 - area_inter) 254 | elif criterion == 0: 255 | return area_inter / area1 256 | elif criterion == 1: 257 | return area_inter / area2 258 | else: 259 | return area_inter 260 | 261 | @cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False) 262 | def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1): 263 | threadsPerBlock = 8 * 8 264 | row_start = cuda.blockIdx.x 265 | col_start = cuda.blockIdx.y 266 | tx = cuda.threadIdx.x 267 | row_size = min(N - row_start * threadsPerBlock, threadsPerBlock) 268 | col_size = min(K - col_start * threadsPerBlock, threadsPerBlock) 269 | block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 270 | block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 271 | 272 | dev_query_box_idx = threadsPerBlock * col_start + tx 273 | dev_box_idx = threadsPerBlock * row_start + tx 274 | if (tx < col_size): 275 | block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0] 276 | block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1] 277 | block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2] 278 | block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3] 279 | block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4] 280 | if (tx < row_size): 281 | block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0] 282 | block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1] 283 | block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2] 284 | block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3] 285 | block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4] 286 | cuda.syncthreads() 287 | if tx < row_size: 288 | for i in range(col_size): 289 | offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i 290 | dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5], 291 | block_boxes[tx * 5:tx * 5 + 5], criterion) 292 | 293 | 294 | def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0): 295 | """rotated box iou running in gpu. 500x faster than cpu version 296 | (take 5ms in one example with numba.cuda code). 297 | convert from [this project]( 298 | https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation). 299 | 300 | Args: 301 | boxes (float tensor: [N, 5]): rbboxes. format: centers, dims, 302 | angles(clockwise when positive) 303 | query_boxes (float tensor: [K, 5]): [description] 304 | device_id (int, optional): Defaults to 0. 
[description] 305 | 306 | Returns: 307 | [type]: [description] 308 | """ 309 | box_dtype = boxes.dtype 310 | boxes = boxes.astype(np.float32) 311 | query_boxes = query_boxes.astype(np.float32) 312 | N = boxes.shape[0] 313 | K = query_boxes.shape[0] 314 | iou = np.zeros((N, K), dtype=np.float32) 315 | if N == 0 or K == 0: 316 | return iou 317 | threadsPerBlock = 8 * 8 318 | cuda.select_device(device_id) 319 | blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock)) 320 | 321 | stream = cuda.stream() 322 | with stream.auto_synchronize(): 323 | boxes_dev = cuda.to_device(boxes.reshape([-1]), stream) 324 | query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream) 325 | iou_dev = cuda.to_device(iou.reshape([-1]), stream) 326 | rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream]( 327 | N, K, boxes_dev, query_boxes_dev, iou_dev, criterion) 328 | iou_dev.copy_to_host(iou.reshape([-1]), stream=stream) 329 | return iou.astype(boxes.dtype) -------------------------------------------------------------------------------- /kitti_object_eval/run.sh: -------------------------------------------------------------------------------- 1 | python evaluate.py evaluate \ 2 | --label_path=/home/your/path/data/KITTI/label_2 \ 3 | --result_path=/home/your/path/data/KITTI/result \ 4 | --label_split_file=/home/your/path/data/KITTI/ImageSets/val.txt \ 5 | --current_class=0,1,2 -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/logs/.gitkeep -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: YOLO3D 3 | site_url: https://ruhyadi.github.io/yolo3d-lightning 4 | site_author: Didi Ruhyadi 5 | site_description: >- 6 | YOLO3D: 3D Object Detection with YOLO 7 | 8 | # Repository 9 | repo_name: ruhyadi/yolo3d-lightning 10 | repo_url: https://github.com/ruhyadi/yolo3d-lightning 11 | edit_uri: "" 12 | 13 | # Copyright 14 | copyright: Copyright © 2020 - 2022 Didi Ruhyadi 15 | 16 | # Configuration 17 | theme: 18 | name: material 19 | language: en 20 | 21 | # Don't include MkDocs' JavaScript 22 | include_search_page: false 23 | search_index_only: true 24 | 25 | features: 26 | - content.code.annotate 27 | # - content.tabs.link 28 | # - header.autohide 29 | # - navigation.expand 30 | - navigation.indexes 31 | # - navigation.instant 32 | - navigation.sections 33 | - navigation.tabs 34 | # - navigation.tabs.sticky 35 | - navigation.top 36 | - navigation.tracking 37 | - search.highlight 38 | - search.share 39 | - search.suggest 40 | # - toc.integrate 41 | palette: 42 | - scheme: default 43 | primary: white 44 | accent: indigo 45 | toggle: 46 | icon: material/weather-night 47 | name: Vampire Mode 48 | - scheme: slate 49 | primary: indigo 50 | accent: blue 51 | toggle: 52 | icon: material/weather-sunny 53 | name: Beware of Your Eyes 54 | font: 55 | text: Noto Serif 56 | code: Noto Mono 57 | favicon: assets/logo.png 58 | logo: assets/logo.png 59 | icon: 60 | repo: fontawesome/brands/github 61 | 62 | # Plugins 63 | plugins: 64 | 65 | # Customization 66 | extra: 67 | social: 68 | - icon: fontawesome/brands/github 69 | link: https://github.com/ruhyadi 70 | - icon: fontawesome/brands/docker 71 | link: 
https://hub.docker.com/r/ruhyadi 72 | - icon: fontawesome/brands/twitter 73 | link: https://twitter.com/ 74 | - icon: fontawesome/brands/linkedin 75 | link: https://linkedin.com/in/didiruhyadi 76 | - icon: fontawesome/brands/instagram 77 | link: https://instagram.com/didiir_ 78 | 79 | extra_javascript: 80 | - javascripts/mathjax.js 81 | - https://polyfill.io/v3/polyfill.min.js?features=es6 82 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 83 | 84 | # Extensions 85 | markdown_extensions: 86 | - admonition 87 | - abbr 88 | - pymdownx.snippets 89 | - attr_list 90 | - def_list 91 | - footnotes 92 | - meta 93 | - md_in_html 94 | - toc: 95 | permalink: true 96 | - pymdownx.arithmatex: 97 | generic: true 98 | - pymdownx.betterem: 99 | smart_enable: all 100 | - pymdownx.caret 101 | - pymdownx.details 102 | - pymdownx.emoji: 103 | emoji_index: !!python/name:materialx.emoji.twemoji 104 | emoji_generator: !!python/name:materialx.emoji.to_svg 105 | - pymdownx.highlight: 106 | anchor_linenums: true 107 | - pymdownx.inlinehilite 108 | - pymdownx.keys 109 | - pymdownx.magiclink: 110 | repo_url_shorthand: true 111 | user: squidfunk 112 | repo: mkdocs-material 113 | - pymdownx.mark 114 | - pymdownx.smartsymbols 115 | - pymdownx.superfences: 116 | custom_fences: 117 | - name: mermaid 118 | class: mermaid 119 | format: !!python/name:pymdownx.superfences.fence_code_format 120 | - pymdownx.tabbed: 121 | alternate_style: true 122 | - pymdownx.tasklist: 123 | custom_checkbox: true 124 | - pymdownx.tilde 125 | 126 | # Page tree 127 | nav: 128 | - Home: 129 | - Home: index.md -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/notebooks/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--color=yes", 4 | "--durations=0", 5 | "--strict-markers", 6 | "--doctest-modules", 7 | ] 8 | filterwarnings = [ 9 | "ignore::DeprecationWarning", 10 | "ignore::UserWarning", 11 | ] 12 | log_cli = "True" 13 | markers = [ 14 | "slow: slow tests", 15 | ] 16 | minversion = "6.0" 17 | testpaths = "tests/" 18 | 19 | [tool.coverage.report] 20 | exclude_lines = [ 21 | "pragma: nocover", 22 | "raise NotImplementedError", 23 | "raise NotImplementedError()", 24 | "if __name__ == .__main__.:", 25 | ] 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=1.8.0 3 | torchvision>=0.9.1 4 | pytorch-lightning==1.6.5 5 | torchmetrics==0.9.2 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.2.0 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | 18 | # --------- others --------- # 19 | pyrootutils # standardizing the project root setup 20 | pre-commit # hooks for applying linters on commit 21 | rich # beautiful text formatting in terminal 22 | pytest # tests 23 | sh # for running bash commands in some tests 24 | -------------------------------------------------------------------------------- 
/scripts/frames_to_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate frames to vid 3 | Usage: 4 | python scripts/frames_to_video.py \ 5 | --imgs_path /path/to/imgs \ 6 | --vid_path /path/to/vid \ 7 | --fps 24 \ 8 | --frame_size 1242 375 \ 9 | --resize 10 | 11 | python scripts/frames_to_video.py \ 12 | --imgs_path outputs/2023-05-13/22-51-34/inference \ 13 | --vid_path tmp/output_videos/001.mp4 \ 14 | --fps 3 \ 15 | --frame_size 1550 387 \ 16 | --resize 17 | """ 18 | 19 | import argparse 20 | import cv2 21 | from glob import glob 22 | import os 23 | from tqdm import tqdm 24 | 25 | def generate(imgs_path, vid_path, fps=30, frame_size=(1242, 375), resize=True): 26 | """Generate frames to vid""" 27 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 28 | vid_writer = cv2.VideoWriter(vid_path, fourcc, fps, frame_size) 29 | imgs_glob = sorted(glob(os.path.join(imgs_path, "*.png"))) 30 | if resize: 31 | for img_path in tqdm(imgs_glob): 32 | img = cv2.imread(img_path) 33 | img = cv2.resize(img, frame_size) 34 | vid_writer.write(img) 35 | else: 36 | for img_path in imgs_glob: 37 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 38 | vid_writer.write(img) 39 | vid_writer.release() 40 | print('[INFO] Video saved to {}'.format(vid_path)) 41 | 42 | if __name__ == "__main__": 43 | # create argparser 44 | parser = argparse.ArgumentParser(description="Generate frames to vid") 45 | parser.add_argument("--imgs_path", type=str, default="outputs/2022-10-23/21-03-50/inference", help="path to imgs") 46 | parser.add_argument("--vid_path", type=str, default="outputs/videos/004.mp4", help="path to vid") 47 | parser.add_argument("--fps", type=int, default=24, help="fps") 48 | parser.add_argument("--frame_size", type=int, nargs=2, default=(int(1242), int(375)), help="frame size") 49 | parser.add_argument("--resize", action="store_true", help="resize") 50 | args = parser.parse_args() 51 | 52 | # generate vid 53 | generate(args.imgs_path, args.vid_path, args.fps, args.frame_size) -------------------------------------------------------------------------------- /scripts/generate_sets.py: -------------------------------------------------------------------------------- 1 | """Create training and validation sets""" 2 | 3 | from glob import glob 4 | import os 5 | import argparse 6 | 7 | 8 | def generate_sets( 9 | images_path: str, 10 | dump_dir: str, 11 | postfix: str = "", 12 | train_size: float = 0.8, 13 | is_yolo: bool = False, 14 | ): 15 | images = glob(os.path.join(images_path, "*.png")) 16 | ids = [id_.split("/")[-1].split(".")[0] for id_ in images] 17 | 18 | train_sets = sorted(ids[: int(len(ids) * train_size)]) 19 | val_sets = sorted(ids[int(len(ids) * train_size) :]) 20 | 21 | for name, sets in zip(["train", "val"], [train_sets, val_sets]): 22 | name = os.path.join(dump_dir, f"{name}{postfix}.txt") 23 | with open(name, "w") as f: 24 | for id in sets: 25 | if is_yolo: 26 | f.write(f"./images/{id}.png\n") 27 | else: 28 | f.write(f"{id}\n") 29 | 30 | print(f"[INFO] Training set: {len(train_sets)}") 31 | print(f"[INFO] Validation set: {len(val_sets)}") 32 | print(f"[INFO] Total: {len(train_sets) + len(val_sets)}") 33 | print(f"[INFO] Success Generate Sets") 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser(description="Create training and validation sets") 38 | parser.add_argument("--images_path", type=str, default="./data/KITTI/images") 39 | parser.add_argument("--dump_dir", type=str, default="./data/KITTI") 40 | 
parser.add_argument("--postfix", type=str, default="_95") 41 | parser.add_argument("--train_size", type=float, default=0.95) 42 | parser.add_argument("--is_yolo", action="store_true") 43 | args = parser.parse_args() 44 | 45 | generate_sets( 46 | images_path=args.images_path, 47 | dump_dir=args.dump_dir, 48 | postfix=args.postfix, 49 | train_size=args.train_size, 50 | is_yolo=False, 51 | ) 52 | -------------------------------------------------------------------------------- /scripts/get_weights.py: -------------------------------------------------------------------------------- 1 | """Download pretrained weights from github release""" 2 | 3 | from pprint import pprint 4 | import requests 5 | import os 6 | import shutil 7 | import argparse 8 | from zipfile import ZipFile 9 | 10 | def get_assets(tag): 11 | """Get release assets by tag name""" 12 | url = 'https://api.github.com/repos/ruhyadi/yolo3d-lightning/releases/tags/' + tag 13 | response = requests.get(url) 14 | return response.json()['assets'] 15 | 16 | def download_assets(assets, dir): 17 | """Download assets to dir""" 18 | for asset in assets: 19 | url = asset['browser_download_url'] 20 | filename = asset['name'] 21 | print('[INFO] Downloading {}'.format(filename)) 22 | response = requests.get(url, stream=True) 23 | with open(os.path.join(dir, filename), 'wb') as f: 24 | shutil.copyfileobj(response.raw, f) 25 | del response 26 | 27 | with ZipFile(os.path.join(dir, filename), 'r') as zip_file: 28 | zip_file.extractall(dir) 29 | os.remove(os.path.join(dir, filename)) 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description='Download pretrained weights') 33 | parser.add_argument('--tag', type=str, default='v0.1', help='tag name') 34 | parser.add_argument('--dir', type=str, default='./', help='directory to save weights') 35 | args = parser.parse_args() 36 | 37 | assets = get_assets(args.tag) 38 | download_assets(assets, args.dir) 39 | -------------------------------------------------------------------------------- /scripts/kitti_to_yolo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert KITTI format to YOLO format. 
3 | """ 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from tqdm import tqdm 9 | import argparse 10 | 11 | from typing import Tuple 12 | 13 | 14 | class KITTI2YOLO: 15 | def __init__( 16 | self, 17 | dataset_path: str = "../data/KITTI", 18 | classes: Tuple = ["car", "van", "truck", "pedestrian", "cyclist"], 19 | img_width: int = 1224, 20 | img_height: int = 370, 21 | ): 22 | 23 | self.dataset_path = dataset_path 24 | self.img_width = img_width 25 | self.img_height = img_height 26 | self.classes = classes 27 | self.ids = {self.classes[i]: i for i in range(len(self.classes))} 28 | 29 | # create new directory 30 | self.label_path = os.path.join(self.dataset_path, "labels") 31 | if not os.path.isdir(self.label_path): 32 | os.makedirs(self.label_path) 33 | else: 34 | print("[INFO] Directory already exist...") 35 | 36 | def convert(self): 37 | files = glob(os.path.join(self.dataset_path, "label_2", "*.txt")) 38 | for file in tqdm(files): 39 | with open(file, "r") as f: 40 | filename = os.path.join(self.label_path, file.split("/")[-1]) 41 | dump_txt = open(filename, "w") 42 | for line in f: 43 | parse_line = self.parse_line(line) 44 | if parse_line["name"].lower() not in self.classes: 45 | continue 46 | 47 | xmin, ymin, xmax, ymax = parse_line["bbox_camera"] 48 | xcenter = ((xmax - xmin) / 2 + xmin) / self.img_width 49 | if xcenter > 1.0: 50 | xcenter = 1.0 51 | ycenter = ((ymax - ymin) / 2 + ymin) / self.img_height 52 | if ycenter > 1.0: 53 | ycenter = 1.0 54 | width = (xmax - xmin) / self.img_width 55 | if width > 1.0: 56 | width = 1.0 57 | height = (ymax - ymin) / self.img_height 58 | if height > 1.0: 59 | height = 1.0 60 | 61 | bbox_yolo = f"{self.ids[parse_line['name'].lower()]} {xcenter:.3f} {ycenter:.3f} {width:.3f} {height:.3f}" 62 | dump_txt.write(bbox_yolo + "\n") 63 | 64 | dump_txt.close() 65 | 66 | def parse_line(self, line): 67 | parts = line.split(" ") 68 | output = { 69 | "name": parts[0].strip(), 70 | "xyz_camera": (float(parts[11]), float(parts[12]), float(parts[13])), 71 | "wlh": (float(parts[9]), float(parts[10]), float(parts[8])), 72 | "yaw_camera": float(parts[14]), 73 | "bbox_camera": ( 74 | float(parts[4]), 75 | float(parts[5]), 76 | float(parts[6]), 77 | float(parts[7]), 78 | ), 79 | "truncation": float(parts[1]), 80 | "occlusion": float(parts[2]), 81 | "alpha": float(parts[3]), 82 | } 83 | 84 | # Add score if specified 85 | if len(parts) > 15: 86 | output["score"] = float(parts[15]) 87 | else: 88 | output["score"] = np.nan 89 | 90 | return output 91 | 92 | 93 | if __name__ == "__main__": 94 | 95 | # argparser 96 | parser = argparse.ArgumentParser(description="KITTI to YOLO Convertion") 97 | parser.add_argument("--dataset_path", type=str, default="../data/KITTI") 98 | parser.add_argument( 99 | "--classes", 100 | type=Tuple, 101 | default=["car", "van", "truck", "pedestrian", "cyclist"], 102 | ) 103 | parser.add_argument("--img_width", type=int, default=1224) 104 | parser.add_argument("--img_height", type=int, default=370) 105 | args = parser.parse_args() 106 | 107 | kitit2yolo = KITTI2YOLO( 108 | dataset_path=args.dataset_path, 109 | classes=args.classes, 110 | img_width=args.img_width, 111 | img_height=args.img_height, 112 | ) 113 | kitit2yolo.convert() 114 | -------------------------------------------------------------------------------- /scripts/post_weights.py: -------------------------------------------------------------------------------- 1 | """Upload weights to github release""" 2 | 3 | from pprint import pprint 4 | import requests 5 | 
import os 6 | import dotenv 7 | import argparse 8 | from zipfile import ZipFile 9 | 10 | dotenv.load_dotenv() 11 | 12 | 13 | def create_release(tag, name, description, target="main"): 14 | """Create release""" 15 | token = os.environ.get("GITHUB_TOKEN") 16 | headers = { 17 | "Accept": "application/vnd.github.v3+json", 18 | "Authorization": f"token {token}", 19 | "Content-Type": "application/zip" 20 | } 21 | url = "https://api.github.com/repos/ruhyadi/yolo3d-lightning/releases" 22 | payload = { 23 | "tag_name": tag, 24 | "target_commitish": target, 25 | "name": name, 26 | "body": description, 27 | "draft": True, 28 | "prerelease": False, 29 | "generate_release_notes": True, 30 | } 31 | print("[INFO] Creating release {}".format(tag)) 32 | response = requests.post(url, json=payload, headers=headers) 33 | print("[INFO] Release created id: {}".format(response.json()["id"])) 34 | 35 | return response.json() 36 | 37 | 38 | def post_assets(assets, release_id): 39 | """Post assets to release""" 40 | token = os.environ.get("GITHUB_TOKEN") 41 | headers = { 42 | "Accept": "application/vnd.github.v3+json", 43 | "Authorization": f"token {token}", 44 | "Content-Type": "application/zip" 45 | } 46 | for asset in assets: 47 | asset_path = os.path.join(os.getcwd(), asset) 48 | with ZipFile(f"{asset_path}.zip", "w") as zip_file: 49 | zip_file.write(asset) 50 | asset_path = f"{asset_path}.zip" 51 | filename = asset_path.split("/")[-1] 52 | url = ( 53 | "https://uploads.github.com/repos/ruhyadi/yolo3d-lightning/releases/" 54 | + str(release_id) 55 | + f"/assets?name={filename}" 56 | ) 57 | print("[INFO] Uploading {}".format(filename)) 58 | response = requests.post(url, files={"name": open(asset_path, "rb")}, headers=headers) 59 | pprint(response.json()) 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser(description="Upload weights to github release") 64 | parser.add_argument("--tag", type=str, default="v0.6", help="tag name") 65 | parser.add_argument("--name", type=str, default="Release v0.6", help="release name") 66 | parser.add_argument("--description", type=str, default="v0.6", help="release description") 67 | parser.add_argument("--assets", type=tuple, default=["weights/mobilenetv3-best.pt", "weights/mobilenetv3-last.pt", "logs/train/runs/2022-09-28_10-36-08/checkpoints/epoch_007.ckpt", "logs/train/runs/2022-09-28_10-36-08/checkpoints/last.ckpt"], help="directory to save weights",) 68 | args = parser.parse_args() 69 | 70 | release_id = create_release(args.tag, args.name, args.description)["id"] 71 | post_assets(args.assets, release_id) 72 | -------------------------------------------------------------------------------- /scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /scripts/video_to_frame.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert video to frame 3 | Usage: 4 | python video_to_frame.py \ 5 | --video_path /path/to/video \ 6 | --output_path /path/to/output/folder \ 7 | --fps 24 8 | 9 | python scripts/video_to_frame.py \ 10 | --video_path tmp/video/20230513_100429.mp4 \ 11 | --output_path tmp/video_001 \ 12 | --fps 20 13 | """ 14 | 15 | import argparse 
16 | 17 | import os 18 | import cv2 19 | 20 | 21 | def video_to_frame(video_path: str, output_path: str, fps: int = 5): 22 | """ 23 | Convert video to frame 24 | 25 | Args: 26 | video_path: path to video 27 | output_path: path to output folder 28 | fps: how many frames per second to save 29 | """ 30 | if not os.path.exists(output_path): 31 | os.makedirs(output_path) 32 | 33 | cap = cv2.VideoCapture(video_path) 34 | frame_count = 0 35 | while cap.isOpened(): 36 | ret, frame = cap.read() 37 | if not ret: 38 | break 39 | if frame_count % fps == 0: 40 | cv2.imwrite(os.path.join(output_path, f"{frame_count:06d}.jpg"), frame) 41 | frame_count += 1 42 | 43 | cap.release() 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--video_path", type=str, required=True) 48 | parser.add_argument("--output_path", type=str, required=True) 49 | parser.add_argument("--fps", type=int, default=30) 50 | args = parser.parse_args() 51 | 52 | video_to_frame(args.video_path, args.output_path, args.fps) -------------------------------------------------------------------------------- /scripts/video_to_gif.py: -------------------------------------------------------------------------------- 1 | """Convert video to gif with moviepy""" 2 | 3 | import argparse 4 | import moviepy.editor as mpy 5 | 6 | def generate(video_path, gif_path, fps): 7 | """Generate gif from video""" 8 | clip = mpy.VideoFileClip(video_path) 9 | clip.write_gif(gif_path, fps=fps) 10 | clip.close() 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser(description="Convert video to gif") 14 | parser.add_argument("--video_path", type=str, default="outputs/videos/004.mp4", help="Path to video") 15 | parser.add_argument("--gif_path", type=str, default="outputs/gif/002.gif", help="Path to gif") 16 | parser.add_argument("--fps", type=int, default=5, help="GIF fps") 17 | args = parser.parse_args() 18 | 19 | # generate gif 20 | generate(args.video_path, args.gif_path, args.fps) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | setup( 6 | name="src", 7 | version="0.0.1", 8 | description="Describe Your Cool Project", 9 | author="", 10 | author_email="", 11 | url="https://github.com/user/project", # REPLACE WITH YOUR OWN GITHUB PROJECT LINK 12 | install_requires=["pytorch-lightning", "hydra-core"], 13 | packages=find_packages(), 14 | ) 15 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/__init__.py -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/datamodules/__init__.py -------------------------------------------------------------------------------- /src/datamodules/components/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/datamodules/components/__init__.py -------------------------------------------------------------------------------- /src/datamodules/kitti_datamodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset lightning class 3 | """ 4 | 5 | from pytorch_lightning import LightningDataModule 6 | from torch.utils.data import DataLoader 7 | from torchvision.transforms import transforms 8 | 9 | from src.datamodules.components.kitti_dataset import KITTIDataset, KITTIDataset2, KITTIDataset3 10 | 11 | class KITTIDataModule(LightningDataModule): 12 | def __init__( 13 | self, 14 | dataset_path: str = './data/KITTI', 15 | train_sets: str = './data/KITTI/train.txt', 16 | val_sets: str = './data/KITTI/val.txt', 17 | test_sets: str = './data/KITTI/test.txt', 18 | batch_size: int = 32, 19 | num_worker: int = 4, 20 | ): 21 | super().__init__() 22 | 23 | # save hyperparameters 24 | self.save_hyperparameters(logger=False) 25 | 26 | # transforms 27 | # TODO: using albumentations 28 | self.dataset_transforms = transforms.Compose([ 29 | transforms.ToTensor(), 30 | transforms.Normalize((0.1307,), (0.3081,)) 31 | ]) 32 | 33 | def setup(self, stage=None): 34 | """ Split dataset to training and validation """ 35 | self.KITTI_train = KITTIDataset(self.hparams.dataset_path, self.hparams.train_sets) 36 | self.KITTI_val = KITTIDataset(self.hparams.dataset_path, self.hparams.val_sets) 37 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets) 38 | # TODO: add test datasets dan test sets 39 | 40 | def train_dataloader(self): 41 | return DataLoader( 42 | dataset=self.KITTI_train, 43 | batch_size=self.hparams.batch_size, 44 | num_workers=self.hparams.num_worker, 45 | shuffle=True 46 | ) 47 | 48 | def val_dataloader(self): 49 | return DataLoader( 50 | dataset=self.KITTI_val, 51 | batch_size=self.hparams.batch_size, 52 | num_workers=self.hparams.num_worker, 53 | shuffle=False 54 | ) 55 | 56 | # def test_dataloader(self): 57 | # return DataLoader( 58 | # dataset=self.KITTI_test, 59 | # batch_size=self.hparams.batch_size, 60 | # num_workers=self.hparams.num_worker, 61 | # shuffle=False 62 | # ) 63 | 64 | class KITTIDataModule2(LightningDataModule): 65 | def __init__( 66 | self, 67 | dataset_path: str = './data/KITTI', 68 | train_sets: str = './data/KITTI/train.txt', 69 | val_sets: str = './data/KITTI/val.txt', 70 | test_sets: str = './data/KITTI/test.txt', 71 | batch_size: int = 32, 72 | num_worker: int = 4, 73 | ): 74 | super().__init__() 75 | 76 | # save hyperparameters 77 | self.save_hyperparameters(logger=False) 78 | 79 | def setup(self, stage=None): 80 | """ Split dataset to training and validation """ 81 | self.KITTI_train = KITTIDataset2(self.hparams.dataset_path, self.hparams.train_sets) 82 | self.KITTI_val = KITTIDataset2(self.hparams.dataset_path, self.hparams.val_sets) 83 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets) 84 | # TODO: add test datasets dan test sets 85 | 86 | def train_dataloader(self): 87 | return DataLoader( 88 | dataset=self.KITTI_train, 89 | batch_size=self.hparams.batch_size, 90 | num_workers=self.hparams.num_worker, 91 | shuffle=True 92 | ) 93 | 94 | def val_dataloader(self): 95 | return DataLoader( 96 | dataset=self.KITTI_val, 97 | batch_size=self.hparams.batch_size, 98 | num_workers=self.hparams.num_worker, 99 | shuffle=False 100 | ) 101 | 102 | class 
KITTIDataModule3(LightningDataModule): 103 | def __init__( 104 | self, 105 | dataset_path: str = './data/KITTI', 106 | train_sets: str = './data/KITTI/train.txt', 107 | val_sets: str = './data/KITTI/val.txt', 108 | test_sets: str = './data/KITTI/test.txt', 109 | batch_size: int = 32, 110 | num_worker: int = 4, 111 | ): 112 | super().__init__() 113 | 114 | # save hyperparameters 115 | self.save_hyperparameters(logger=False) 116 | 117 | # transforms 118 | # TODO: using albumentations 119 | self.dataset_transforms = transforms.Compose([ 120 | transforms.ToTensor(), 121 | transforms.Normalize((0.1307,), (0.3081,)) 122 | ]) 123 | 124 | def setup(self, stage=None): 125 | """ Split dataset to training and validation """ 126 | self.KITTI_train = KITTIDataset3(self.hparams.dataset_path, self.hparams.train_sets) 127 | self.KITTI_val = KITTIDataset3(self.hparams.dataset_path, self.hparams.val_sets) 128 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets) 129 | # TODO: add test datasets dan test sets 130 | 131 | def train_dataloader(self): 132 | return DataLoader( 133 | dataset=self.KITTI_train, 134 | batch_size=self.hparams.batch_size, 135 | num_workers=self.hparams.num_worker, 136 | shuffle=True 137 | ) 138 | 139 | def val_dataloader(self): 140 | return DataLoader( 141 | dataset=self.KITTI_val, 142 | batch_size=self.hparams.batch_size, 143 | num_workers=self.hparams.num_worker, 144 | shuffle=False 145 | ) 146 | 147 | 148 | if __name__ == '__main__': 149 | 150 | from time import time 151 | 152 | start1 = time() 153 | datamodule1 = KITTIDataModule( 154 | dataset_path='./data/KITTI', 155 | train_sets='./data/KITTI/train_95.txt', 156 | val_sets='./data/KITTI/val_95.txt', 157 | test_sets='./data/KITTI/test_95.txt', 158 | batch_size=5, 159 | ) 160 | datamodule1.setup() 161 | trainloader = datamodule1.val_dataloader() 162 | 163 | for img, label in trainloader: 164 | print(label["Orientation"]) 165 | break 166 | 167 | results1 = (time() - start1) * 1000 168 | 169 | start2 = time() 170 | datamodule2 = KITTIDataModule3( 171 | dataset_path='./data/KITTI', 172 | train_sets='./data/KITTI/train_95.txt', 173 | val_sets='./data/KITTI/val_95.txt', 174 | test_sets='./data/KITTI/test_95.txt', 175 | batch_size=5, 176 | ) 177 | datamodule2.setup() 178 | trainloader = datamodule2.val_dataloader() 179 | 180 | for img, label in trainloader: 181 | print(label["orientation"]) 182 | break 183 | 184 | results2 = (time() - start2) * 1000 185 | 186 | print(f'Time taken for datamodule1: {results1} ms') 187 | print(f'Time taken for datamodule2: {results2} ms') 188 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | 3 | root = pyrootutils.setup_root( 4 | search_from=__file__, 5 | indicator=[".git", "pyproject.toml"], 6 | pythonpath=True, 7 | dotenv=True, 8 | ) 9 | 10 | # ------------------------------------------------------------------------------------ # 11 | # `pyrootutils.setup_root(...)` is recommended at the top of each start file 12 | # to make the environment more robust and consistent 13 | # 14 | # the line above searches for ".git" or "pyproject.toml" in present and parent dirs 15 | # to determine the project root dir 16 | # 17 | # adds root dir to the PYTHONPATH (if `pythonpath=True`) 18 | # so this file can be run from any place without installing project as a package 19 | # 20 | # sets PROJECT_ROOT environment variable which is used in 
"configs/paths/default.yaml" 21 | # this makes all paths relative to the project root 22 | # 23 | # additionally loads environment variables from ".env" file (if `dotenv=True`) 24 | # 25 | # you can get away without using `pyrootutils.setup_root(...)` if you: 26 | # 1. move this file to the project root dir or install project as a package 27 | # 2. modify paths in "configs/paths/default.yaml" to not use PROJECT_ROOT 28 | # 3. always run this file from the project root dir 29 | # 30 | # https://github.com/ashleve/pyrootutils 31 | # ------------------------------------------------------------------------------------ # 32 | 33 | from typing import List, Tuple 34 | 35 | import hydra 36 | from omegaconf import DictConfig 37 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer 38 | from pytorch_lightning.loggers import LightningLoggerBase 39 | 40 | from src import utils 41 | 42 | log = utils.get_pylogger(__name__) 43 | 44 | 45 | @utils.task_wrapper 46 | def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: 47 | """Evaluates given checkpoint on a datamodule testset. 48 | 49 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities 50 | before and after the call. 51 | 52 | Args: 53 | cfg (DictConfig): Configuration composed by Hydra. 54 | 55 | Returns: 56 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 57 | """ 58 | 59 | assert cfg.ckpt_path 60 | 61 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 62 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 63 | 64 | log.info(f"Instantiating model <{cfg.model._target_}>") 65 | model: LightningModule = hydra.utils.instantiate(cfg.model) 66 | 67 | log.info("Instantiating loggers...") 68 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) 69 | 70 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 71 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) 72 | 73 | object_dict = { 74 | "cfg": cfg, 75 | "datamodule": datamodule, 76 | "model": model, 77 | "logger": logger, 78 | "trainer": trainer, 79 | } 80 | 81 | if logger: 82 | log.info("Logging hyperparameters!") 83 | utils.log_hyperparameters(object_dict) 84 | 85 | log.info("Starting testing!") 86 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) 87 | 88 | # for predictions use trainer.predict(...) 
89 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) 90 | 91 | metric_dict = trainer.callback_metrics 92 | 93 | return metric_dict, object_dict 94 | 95 | 96 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="eval.yaml") 97 | def main(cfg: DictConfig) -> None: 98 | evaluate(cfg) 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/models/components/__init__.py -------------------------------------------------------------------------------- /src/models/components/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | KITTI Regressor Model 3 | """ 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torchvision import models 8 | 9 | class RegressorNet(nn.Module): 10 | def __init__( 11 | self, 12 | backbone: nn.Module, 13 | bins: int, 14 | ): 15 | super().__init__() 16 | # init model 17 | self.in_features = self._get_in_features(backbone) 18 | self.model = nn.Sequential(*(list(backbone.children())[:-2])) 19 | self.bins = bins 20 | 21 | # # orientation head, for orientation estimation 22 | # self.orientation = nn.Sequential( 23 | # nn.Linear(self.in_features, 256), 24 | # nn.ReLU(True), 25 | # nn.Dropout(), 26 | # nn.Linear(256, 256), 27 | # nn.ReLU(True), 28 | # nn.Dropout(), 29 | # nn.Linear(256, self.bins*2) # 4 bins 30 | # ) 31 | 32 | # # confident head, for orientation estimation 33 | # self.confidence = nn.Sequential( 34 | # nn.Linear(self.in_features, 256), 35 | # nn.ReLU(True), 36 | # nn.Dropout(), 37 | # nn.Linear(256, 256), 38 | # nn.ReLU(True), 39 | # nn.Dropout(), 40 | # nn.Linear(256, self.bins), 41 | # nn.Sigmoid() 42 | # ) 43 | self.orientation = nn.Sequential( 44 | nn.Linear(self.in_features, 1024), 45 | nn.ReLU(True), 46 | nn.Dropout(), 47 | nn.Linear(1024, 1024), 48 | nn.ReLU(True), 49 | nn.Dropout(), 50 | nn.Linear(1024, self.bins*2) # 4 bins 51 | ) 52 | 53 | # confident head, for orientation estimation 54 | self.confidence = nn.Sequential( 55 | nn.Linear(self.in_features, 512), 56 | nn.ReLU(True), 57 | nn.Dropout(), 58 | nn.Linear(512, 512), 59 | nn.ReLU(True), 60 | nn.Dropout(), 61 | nn.Linear(512, self.bins), 62 | nn.Sigmoid() 63 | ) 64 | # dimension head 65 | self.dimension = nn.Sequential( 66 | nn.Linear(self.in_features, 512), 67 | nn.ReLU(True), 68 | nn.Dropout(), 69 | nn.Linear(512, 512), 70 | nn.ReLU(True), 71 | nn.Dropout(), 72 | nn.Linear(512, 3) # x, y, z 73 | ) 74 | 75 | def forward(self, x): 76 | x = self.model(x) 77 | x = x.view(-1, self.in_features) 78 | 79 | orientation = self.orientation(x) 80 | orientation = orientation.view(-1, self.bins, 2) 81 | # orientation = F.normalize(orientation, dim=2) 82 | # TODO: export model use this 83 | orientation = orientation.div(orientation.norm(dim=2, keepdim=True)) 84 | 85 | confidence = self.confidence(x) 86 | 87 | dimension = self.dimension(x) 88 
| 89 | return orientation, confidence, dimension 90 | 91 | def _get_in_features(self, net: nn.Module): 92 | 93 | # TODO: add more models 94 | in_features = { 95 | 'resnet': (lambda: net.fc.in_features * 7 * 7), # 512 * 7 * 7 = 25088 96 | 'vgg': (lambda: net.classifier[0].in_features), # 512 * 7 * 7 = 25088 97 | # 'mobilenetv3_large': (lambda: (net.classifier[0].in_features) * 7 * 7), # 960 * 7 * 7 = 47040 98 | 'mobilenetv3': (lambda: (net.classifier[0].in_features) * 7 * 7), # 576 * 7 * 7 = 28416 99 | } 100 | 101 | return in_features[(net.__class__.__name__).lower()]() 102 | 103 | 104 | class RegressorNet2(nn.Module): 105 | def __init__( 106 | self, 107 | backbone: nn.Module, 108 | bins: int, 109 | ): 110 | super().__init__() 111 | 112 | # init model 113 | self.in_features = self._get_in_features(backbone) 114 | self.model = nn.Sequential(*(list(backbone.children())[:-2])) 115 | self.bins = bins 116 | 117 | # orientation head, for orientation estimation 118 | # TODO: inprove 256 to 1024 119 | self.orientation = nn.Sequential( 120 | nn.Linear(self.in_features, 256), 121 | nn.LeakyReLU(0.1), 122 | nn.Dropout(), 123 | nn.Linear(256, self.bins*2), # 4 bins 124 | nn.LeakyReLU(0.1) 125 | ) 126 | 127 | # confident head, for orientation estimation 128 | self.confidence = nn.Sequential( 129 | nn.Linear(self.in_features, 256), 130 | nn.LeakyReLU(0.1), 131 | nn.Dropout(), 132 | nn.Linear(256, self.bins), 133 | nn.LeakyReLU(0.1) 134 | ) 135 | 136 | # dimension head 137 | self.dimension = nn.Sequential( 138 | nn.Linear(self.in_features, 512), 139 | nn.LeakyReLU(0.1), 140 | nn.Dropout(), 141 | nn.Linear(512, 3), # x, y, z 142 | nn.LeakyReLU(0.1) 143 | ) 144 | 145 | def forward(self, x): 146 | x = self.model(x) 147 | x = x.view(-1, self.in_features) 148 | 149 | orientation = self.orientation(x) 150 | orientation = orientation.view(-1, self.bins, 2) 151 | # TODO: export model use this 152 | orientation = orientation.div(orientation.norm(dim=2, keepdim=True)) 153 | 154 | confidence = self.confidence(x) 155 | 156 | dimension = self.dimension(x) 157 | 158 | return orientation, confidence, dimension 159 | 160 | def _get_in_features(self, net: nn.Module): 161 | 162 | # TODO: add more models 163 | in_features = { 164 | 'resnet': (lambda: net.fc.in_features * 7 * 7), 165 | 'vgg': (lambda: net.classifier[0].in_features) 166 | } 167 | 168 | return in_features[(net.__class__.__name__).lower()]() 169 | 170 | 171 | def OrientationLoss(orient_batch, orientGT_batch, confGT_batch): 172 | """ 173 | Orientation loss function 174 | """ 175 | batch_size = orient_batch.size()[0] 176 | indexes = torch.max(confGT_batch, dim=1)[1] 177 | 178 | # extract important bin 179 | orientGT_batch = orientGT_batch[torch.arange(batch_size), indexes] 180 | orient_batch = orient_batch[torch.arange(batch_size), indexes] 181 | 182 | theta_diff = torch.atan2(orientGT_batch[:,1], orientGT_batch[:,0]) 183 | estimated_theta_diff = torch.atan2(orient_batch[:,1], orient_batch[:,0]) 184 | 185 | return 2 - 2 * torch.cos(theta_diff - estimated_theta_diff).mean() 186 | # return -torch.cos(theta_diff - estimated_theta_diff).mean() 187 | 188 | 189 | def orientation_loss2(y_pred, y_true): 190 | """ 191 | Orientation loss function 192 | input: y_true -- (batch_size, bin, 2) ground truth orientation value in cos and sin form. 
193 | y_pred -- (batch_size, bin, 2) estimated orientation value from the ConvNet 194 | output: loss -- loss values for orientation 195 | """ 196 | 197 | # sin^2 + cons^2 198 | anchors = torch.sum(y_true ** 2, dim=2) 199 | # check which bin valid 200 | anchors = torch.gt(anchors, 0.5) 201 | # add valid bin 202 | anchors = torch.sum(anchors.type(torch.float32), dim=1) 203 | 204 | # cos(true)cos(estimate) + sin(true)sin(estimate) 205 | loss = (y_true[:, : ,0] * y_pred[:, :, 0] + y_true[:, :, 1] * y_pred[:, :, 1]) 206 | # the mean value in each bin 207 | loss = torch.sum(loss, dim=1) / anchors 208 | # sum the value at each bin 209 | loss = torch.mean(loss) 210 | loss = 2 - 2 * loss 211 | 212 | return loss 213 | 214 | def get_model(backbone: str): 215 | """ 216 | Get truncated model and in_features 217 | """ 218 | 219 | # list of support model name 220 | # TODO: add more models 221 | list_model = ['resnet18', 'vgg11'] 222 | # model_name = str(backbone.__class__.__name__).lower() 223 | assert backbone in list_model, f"Model not support, please choose {list_model}" 224 | 225 | # TODO: change if else with attributes 226 | in_features = None 227 | model = None 228 | if backbone == 'resnet18': 229 | backbone = models.resnet18(pretrained=True) 230 | in_features = backbone.fc.in_features * 7 * 7 231 | model = nn.Sequential(*(list(backbone.children())[:-2])) 232 | elif backbone == 'vgg11': 233 | backbone = models.vgg11(pretrained=True) 234 | in_features = backbone.classifier[0].in_features 235 | model = backbone.features 236 | 237 | return [model, in_features] 238 | 239 | 240 | if __name__ == '__main__': 241 | 242 | # from torchvision.models import resnet18 243 | # from torchsummary import summary 244 | 245 | # backbone = resnet18(pretrained=False) 246 | # model = RegressorNet(backbone, 2) 247 | 248 | # input_size = (3, 224, 224) 249 | # summary(model, input_size, device='cpu') 250 | 251 | # test orientation loss 252 | y_true = torch.tensor([[[0.0, 0.0], [0.9362, 0.3515]]]) 253 | y_pred = torch.tensor([[[0.0, 0.0], [0.9362, 0.3515]]]) 254 | 255 | print(y_true, "\n", y_pred) 256 | print(orientation_loss2(y_pred, y_true)) 257 | -------------------------------------------------------------------------------- /src/models/regressor.py: -------------------------------------------------------------------------------- 1 | """ 2 | KITTI Regressor Model 3 | """ 4 | import torch 5 | from torch import nn 6 | from pytorch_lightning import LightningModule 7 | 8 | from src.models.components.base import OrientationLoss, orientation_loss2 9 | 10 | 11 | class RegressorModel(LightningModule): 12 | def __init__( 13 | self, 14 | net: nn.Module, 15 | optimizer: str = "adam", 16 | lr: float = 0.0001, 17 | momentum: float = 0.9, 18 | w: float = 0.4, 19 | alpha: float = 0.6, 20 | ): 21 | super().__init__() 22 | 23 | # save hyperparamters 24 | self.save_hyperparameters(logger=False) 25 | 26 | # init model 27 | self.net = net 28 | 29 | # loss functions 30 | self.conf_loss_func = nn.CrossEntropyLoss() 31 | self.dim_loss_func = nn.MSELoss() 32 | self.orient_loss_func = OrientationLoss 33 | 34 | # TODO: export model use this 35 | def forward(self, x): 36 | output = self.net(x) 37 | orient = output[0] 38 | conf = output[1] 39 | dim = output[2] 40 | return [orient, conf, dim] 41 | 42 | # def forward(self, x): 43 | # return self.net(x) 44 | 45 | def on_train_start(self): 46 | # by default lightning executes validation step sanity checks before training starts, 47 | # so we need to make sure val_acc_best doesn't store accuracy from 
these checks 48 | # self.val_acc_best.reset() 49 | pass 50 | 51 | def step(self, batch): 52 | x, y = batch 53 | 54 | # convert to float 55 | x = x.float() 56 | truth_orient = y["Orientation"].float() 57 | truth_conf = y["Confidence"].float() # front or back 58 | truth_dim = y["Dimensions"].float() 59 | 60 | # predict y_hat 61 | preds = self(x) 62 | [orient, conf, dim] = preds 63 | 64 | # compute loss 65 | orient_loss = self.orient_loss_func(orient, truth_orient, truth_conf) 66 | dim_loss = self.dim_loss_func(dim, truth_dim) 67 | # truth_conf = torch.max(truth_conf, dim=1)[1] 68 | conf_loss = self.conf_loss_func(conf, truth_conf) 69 | 70 | loss_theta = conf_loss + 1.5 * self.hparams.w * orient_loss 71 | loss = self.hparams.alpha * dim_loss + loss_theta 72 | 73 | return [loss, loss_theta, orient_loss, dim_loss, conf_loss], preds, y 74 | 75 | def training_step(self, batch, batch_idx): 76 | loss, preds, targets = self.step(batch) 77 | 78 | # logging 79 | self.log_dict( 80 | { 81 | "train/loss": loss[0], 82 | "train/theta_loss": loss[1], 83 | "train/orient_loss": loss[2], 84 | "train/dim_loss": loss[3], 85 | "train/conf_loss": loss[4], 86 | }, 87 | on_step=False, 88 | on_epoch=True, 89 | prog_bar=False, 90 | ) 91 | return {"loss": loss[0], "preds": preds, "targets": targets} 92 | 93 | def training_epoch_end(self, outputs): 94 | # `outputs` is a list of dicts returned from `training_step()` 95 | pass 96 | 97 | def validation_step(self, batch, batch_idx): 98 | loss, preds, targets = self.step(batch) 99 | 100 | # logging 101 | self.log_dict( 102 | { 103 | "val/loss": loss[0], 104 | "val/theta_loss": loss[1], 105 | "val/orient_loss": loss[2], 106 | "val/dim_loss": loss[3], 107 | "val/conf_loss": loss[4], 108 | }, 109 | on_step=False, 110 | on_epoch=True, 111 | prog_bar=False, 112 | ) 113 | return {"loss": loss[0], "preds": preds, "targets": targets} 114 | 115 | def validation_epoch_end(self, outputs): 116 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean() 117 | 118 | # log to tensorboard 119 | self.log("val/avg_loss", avg_val_loss) 120 | return {"loss": avg_val_loss} 121 | 122 | def on_epoch_end(self): 123 | # reset metrics at the end of every epoch 124 | pass 125 | 126 | def configure_optimizers(self): 127 | if self.hparams.optimizer.lower() == "adam": 128 | optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr) 129 | elif self.hparams.optimizer.lower() == "sgd": 130 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, 131 | momentum=self.hparams.momentum 132 | ) 133 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5) 134 | return optimizer 135 | 136 | class RegressorModel2(LightningModule): 137 | def __init__( 138 | self, 139 | net: nn.Module, 140 | lr: float = 0.0001, 141 | momentum: float = 0.9, 142 | w: float = 0.4, 143 | alpha: float = 0.6, 144 | ): 145 | super().__init__() 146 | 147 | # save hyperparamters 148 | self.save_hyperparameters(logger=False) 149 | 150 | # init model 151 | self.net = net 152 | 153 | # loss functions 154 | self.conf_loss_func = nn.CrossEntropyLoss() 155 | self.dim_loss_func = nn.MSELoss() 156 | self.orient_loss_func = orientation_loss2 157 | 158 | def forward(self, x): 159 | return self.net(x) 160 | 161 | def on_train_start(self): 162 | # by default lightning executes validation step sanity checks before training starts, 163 | # so we need to make sure val_acc_best doesn't store accuracy from these checks 164 | # self.val_acc_best.reset() 165 | pass 166 | 167 | def 
step(self, batch): 168 | x, y = batch 169 | 170 | # convert to float 171 | x = x.float() 172 | gt_orient = y["orientation"].float() 173 | gt_conf = y["confidence"].float() 174 | gt_dims = y["dimensions"].float() 175 | 176 | # predict y_true 177 | predictions = self(x) 178 | [pred_orient, pred_conf, pred_dims] = predictions 179 | 180 | # compute loss 181 | loss_orient = self.orient_loss_func(pred_orient, gt_orient) 182 | loss_dims = self.dim_loss_func(pred_dims, gt_dims) 183 | gt_conf = torch.max(gt_conf, dim=1)[1] 184 | loss_conf = self.conf_loss_func(pred_conf, gt_conf) 185 | # weighting loss => see paper 186 | loss_theta = loss_conf + (self.hparams.w * loss_orient) 187 | loss = (self.hparams.alpha * loss_dims) + loss_theta 188 | 189 | return [loss, loss_theta, loss_orient, loss_conf, loss_dims], predictions, y 190 | 191 | def training_step(self, batch, batch_idx): 192 | loss, preds, targets = self.step(batch) 193 | 194 | # logging 195 | self.log_dict( 196 | { 197 | "train/loss": loss[0], 198 | "train/theta_loss": loss[1], 199 | "train/orient_loss": loss[2], 200 | "train/conf_loss": loss[3], 201 | "train/dim_loss": loss[4], 202 | }, 203 | on_step=False, 204 | on_epoch=True, 205 | prog_bar=False, 206 | ) 207 | return {"loss": loss[0], "preds": preds, "targets": targets} 208 | 209 | def training_epoch_end(self, outputs): 210 | # `outputs` is a list of dicts returned from `training_step()` 211 | pass 212 | 213 | def validation_step(self, batch, batch_idx): 214 | loss, preds, targets = self.step(batch) 215 | 216 | # logging 217 | self.log_dict( 218 | { 219 | "val/loss": loss[0], 220 | "val/theta_loss": loss[1], 221 | "val/orient_loss": loss[2], 222 | "val/conf_loss": loss[3], 223 | "val/dim_loss": loss[4], 224 | }, 225 | on_step=False, 226 | on_epoch=True, 227 | prog_bar=False, 228 | ) 229 | return {"loss": loss[0], "preds": preds, "targets": targets} 230 | 231 | def validation_epoch_end(self, outputs): 232 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean() 233 | 234 | # log to tensorboard 235 | self.log("val/avg_loss", avg_val_loss) 236 | return {"loss": avg_val_loss} 237 | 238 | def on_epoch_end(self): 239 | # reset metrics at the end of every epoch 240 | pass 241 | 242 | def configure_optimizers(self): 243 | # optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr) 244 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, 245 | momentum=self.hparams.momentum 246 | ) 247 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2) 248 | 249 | return optimizer 250 | 251 | class RegressorModel3(LightningModule): 252 | def __init__( 253 | self, 254 | net: nn.Module, 255 | optimizer: str = "adam", 256 | lr: float = 0.0001, 257 | momentum: float = 0.9, 258 | w: float = 0.4, 259 | alpha: float = 0.6, 260 | ): 261 | super().__init__() 262 | 263 | # save hyperparamters 264 | self.save_hyperparameters(logger=False) 265 | 266 | # init model 267 | self.net = net 268 | 269 | # loss functions 270 | self.conf_loss_func = nn.CrossEntropyLoss() 271 | self.dim_loss_func = nn.MSELoss() 272 | self.orient_loss_func = OrientationLoss 273 | 274 | def forward(self, x): 275 | return self.net(x) 276 | 277 | def on_train_start(self): 278 | # by default lightning executes validation step sanity checks before training starts, 279 | # so we need to make sure val_acc_best doesn't store accuracy from these checks 280 | # self.val_acc_best.reset() 281 | pass 282 | 283 | def step(self, batch): 284 | x, y = batch 285 | 286 | # convert to 
float 287 | x = x.float() 288 | gt_orient = y["orientation"].float() 289 | gt_conf = y["confidence"].float() 290 | gt_dims = y["dimensions"].float() 291 | 292 | # predict y_true 293 | predictions = self(x) 294 | [pred_orient, pred_conf, pred_dims] = predictions 295 | 296 | # compute loss 297 | loss_orient = self.orient_loss_func(pred_orient, gt_orient, gt_conf) 298 | loss_dims = self.dim_loss_func(pred_dims, gt_dims) 299 | gt_conf = torch.max(gt_conf, dim=1)[1] 300 | loss_conf = self.conf_loss_func(pred_conf, gt_conf) 301 | # weighting loss => see paper 302 | loss_theta = loss_conf + (self.hparams.w * loss_orient) 303 | loss = (self.hparams.alpha * loss_dims) + loss_theta 304 | 305 | return [loss, loss_theta, loss_orient, loss_conf, loss_dims], predictions, y 306 | 307 | def training_step(self, batch, batch_idx): 308 | loss, preds, targets = self.step(batch) 309 | 310 | # logging 311 | self.log_dict( 312 | { 313 | "train/loss": loss[0], 314 | "train/theta_loss": loss[1], 315 | "train/orient_loss": loss[2], 316 | "train/conf_loss": loss[3], 317 | "train/dim_loss": loss[4], 318 | }, 319 | on_step=False, 320 | on_epoch=True, 321 | prog_bar=False, 322 | ) 323 | return {"loss": loss[0], "preds": preds, "targets": targets} 324 | 325 | def training_epoch_end(self, outputs): 326 | # `outputs` is a list of dicts returned from `training_step()` 327 | pass 328 | 329 | def validation_step(self, batch, batch_idx): 330 | loss, preds, targets = self.step(batch) 331 | 332 | # logging 333 | self.log_dict( 334 | { 335 | "val/loss": loss[0], 336 | "val/theta_loss": loss[1], 337 | "val/orient_loss": loss[2], 338 | "val/conf_loss": loss[3], 339 | "val/dim_loss": loss[4], 340 | }, 341 | on_step=False, 342 | on_epoch=True, 343 | prog_bar=False, 344 | ) 345 | return {"loss": loss[0], "preds": preds, "targets": targets} 346 | 347 | def validation_epoch_end(self, outputs): 348 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean() 349 | 350 | # log to tensorboard 351 | self.log("val/avg_loss", avg_val_loss) 352 | return {"loss": avg_val_loss} 353 | 354 | def on_epoch_end(self): 355 | # reset metrics at the end of every epoch 356 | pass 357 | 358 | def configure_optimizers(self): 359 | if self.hparams.optimizer.lower() == "adam": 360 | optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr) 361 | elif self.hparams.optimizer.lower() == "sgd": 362 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, 363 | momentum=self.hparams.momentum 364 | ) 365 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5) 366 | 367 | return optimizer 368 | 369 | if __name__ == "__main__": 370 | 371 | from src.models.components.base import RegressorNet 372 | from torchvision.models import resnet18 373 | 374 | model1 = RegressorModel( 375 | net=RegressorNet(backbone=resnet18(pretrained=False), bins=2), 376 | ) 377 | 378 | print(model1) 379 | 380 | model2 = RegressorModel3( 381 | net=RegressorNet(backbone=resnet18(pretrained=False), bins=2), 382 | ) 383 | 384 | print(model2) 385 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | 3 | root = pyrootutils.setup_root( 4 | search_from=__file__, 5 | indicator=[".git", "pyproject.toml"], 6 | pythonpath=True, 7 | dotenv=True, 8 | ) 9 | 10 | # ------------------------------------------------------------------------------------ # 11 | # `pyrootutils.setup_root(...)` 
is recommended at the top of each start file 12 | # to make the environment more robust and consistent 13 | # 14 | # the line above searches for ".git" or "pyproject.toml" in present and parent dirs 15 | # to determine the project root dir 16 | # 17 | # adds root dir to the PYTHONPATH (if `pythonpath=True`) 18 | # so this file can be run from any place without installing project as a package 19 | # 20 | # sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml" 21 | # this makes all paths relative to the project root 22 | # 23 | # additionally loads environment variables from ".env" file (if `dotenv=True`) 24 | # 25 | # you can get away without using `pyrootutils.setup_root(...)` if you: 26 | # 1. move this file to the project root dir or install project as a package 27 | # 2. modify paths in "configs/paths/default.yaml" to not use PROJECT_ROOT 28 | # 3. always run this file from the project root dir 29 | # 30 | # https://github.com/ashleve/pyrootutils 31 | # ------------------------------------------------------------------------------------ # 32 | 33 | from typing import List, Optional, Tuple 34 | 35 | import hydra 36 | import pytorch_lightning as pl 37 | from omegaconf import DictConfig 38 | from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer 39 | from pytorch_lightning.loggers import LightningLoggerBase 40 | 41 | from src import utils 42 | 43 | log = utils.get_pylogger(__name__) 44 | 45 | 46 | @utils.task_wrapper 47 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 48 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 49 | training. 50 | 51 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities 52 | before and after the call. 53 | 54 | Args: 55 | cfg (DictConfig): Configuration composed by Hydra. 56 | 57 | Returns: 58 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects. 
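    Example (typical Hydra-style invocations; the config names follow the files
    shipped under "configs/" in this repo and may need adjusting to your setup):

        python src/train.py trainer=gpu
        python src/train.py experiment=sample trainer.max_epochs=10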
59 | """ 60 | 61 | # set seed for random number generators in pytorch, numpy and python.random 62 | if cfg.get("seed"): 63 | pl.seed_everything(cfg.seed, workers=True) 64 | 65 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") 66 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule) 67 | 68 | log.info(f"Instantiating model <{cfg.model._target_}>") 69 | model: LightningModule = hydra.utils.instantiate(cfg.model) 70 | 71 | log.info("Instantiating callbacks...") 72 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 73 | 74 | log.info("Instantiating loggers...") 75 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger")) 76 | 77 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") 78 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 79 | 80 | object_dict = { 81 | "cfg": cfg, 82 | "datamodule": datamodule, 83 | "model": model, 84 | "callbacks": callbacks, 85 | "logger": logger, 86 | "trainer": trainer, 87 | } 88 | 89 | if logger: 90 | log.info("Logging hyperparameters!") 91 | utils.log_hyperparameters(object_dict) 92 | 93 | # train 94 | if cfg.get("train"): 95 | log.info("Starting training!") 96 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 97 | 98 | train_metrics = trainer.callback_metrics 99 | 100 | if cfg.get("test"): 101 | log.info("Starting testing!") 102 | ckpt_path = trainer.checkpoint_callback.best_model_path 103 | if ckpt_path == "": 104 | log.warning("Best ckpt not found! Using current weights for testing...") 105 | ckpt_path = None 106 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 107 | log.info(f"Best ckpt path: {ckpt_path}") 108 | 109 | test_metrics = trainer.callback_metrics 110 | 111 | # merge train and test metrics 112 | metric_dict = {**train_metrics, **test_metrics} 113 | 114 | return metric_dict, object_dict 115 | 116 | 117 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="train.yaml") 118 | def main(cfg: DictConfig) -> Optional[float]: 119 | 120 | # train the model 121 | metric_dict, _ = train(cfg) 122 | 123 | # safely retrieve metric value for hydra-based hyperparameter optimization 124 | metric_value = utils.get_metric_value( 125 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") 126 | ) 127 | 128 | # return optimized metric 129 | return metric_value 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /src/utils/Calib.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for handling calibration file 3 | """ 4 | 5 | import numpy as np 6 | 7 | def get_P(calib_file): 8 | """ 9 | Get matrix P_rect_02 (camera 2 RGB) 10 | and transform to 3 x 4 matrix 11 | """ 12 | for line in open(calib_file, 'r'): 13 | if 'P_rect_02' in line: 14 | cam_P = line.strip().split(' ') 15 | cam_P = np.asarray([float(cam_P) for cam_P in cam_P[1:]]) 16 | matrix = np.zeros((3, 4)) 17 | matrix = cam_P.reshape((3, 4)) 18 | return matrix 19 | 20 | # TODO: understand this 21 | 22 | def get_calibration_cam_to_image(cab_f): 23 | for line in open(cab_f): 24 | if 'P2:' in line: 25 | cam_to_img = line.strip().split(' ') 26 | cam_to_img = np.asarray([float(number) for number in cam_to_img[1:]]) 27 | cam_to_img = np.reshape(cam_to_img, (3, 4)) 28 | return cam_to_img 29 | 30 | file_not_found(cab_f) 31 | 32 | def get_R0(cab_f): 
33 | for line in open(cab_f): 34 | if 'R0_rect:' in line: 35 | R0 = line.strip().split(' ') 36 | R0 = np.asarray([float(number) for number in R0[1:]]) 37 | R0 = np.reshape(R0, (3, 3)) 38 | 39 | R0_rect = np.zeros([4,4]) 40 | R0_rect[3,3] = 1 41 | R0_rect[:3,:3] = R0 42 | 43 | return R0_rect 44 | 45 | def get_tr_to_velo(cab_f): 46 | for line in open(cab_f): 47 | if 'Tr_velo_to_cam:' in line: 48 | Tr = line.strip().split(' ') 49 | Tr = np.asarray([float(number) for number in Tr[1:]]) 50 | Tr = np.reshape(Tr, (3, 4)) 51 | 52 | Tr_to_velo = np.zeros([4,4]) 53 | Tr_to_velo[3,3] = 1 54 | Tr_to_velo[:3,:4] = Tr 55 | 56 | return Tr_to_velo 57 | 58 | def file_not_found(filename): 59 | print("\nError! Can't read calibration file, does %s exist?"%filename) 60 | exit() 61 | 62 | -------------------------------------------------------------------------------- /src/utils/Math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # using this math: https://en.wikipedia.org/wiki/Rotation_matrix 4 | def rotation_matrix(yaw, pitch=0, roll=0): 5 | tx = roll 6 | ty = yaw 7 | tz = pitch 8 | 9 | Rx = np.array([[1,0,0], [0, np.cos(tx), -np.sin(tx)], [0, np.sin(tx), np.cos(tx)]]) 10 | Ry = np.array([[np.cos(ty), 0, np.sin(ty)], [0, 1, 0], [-np.sin(ty), 0, np.cos(ty)]]) 11 | Rz = np.array([[np.cos(tz), -np.sin(tz), 0], [np.sin(tz), np.cos(tz), 0], [0,0,1]]) 12 | 13 | 14 | return Ry.reshape([3,3]) 15 | # return np.dot(np.dot(Rz,Ry), Rx) 16 | 17 | # option to rotate and shift (for label info) 18 | def create_corners(dimension, location=None, R=None): 19 | dx = dimension[2] / 2 20 | dy = dimension[0] / 2 21 | dz = dimension[1] / 2 22 | 23 | x_corners = [] 24 | y_corners = [] 25 | z_corners = [] 26 | 27 | for i in [1, -1]: 28 | for j in [1,-1]: 29 | for k in [1,-1]: 30 | x_corners.append(dx*i) 31 | y_corners.append(dy*j) 32 | z_corners.append(dz*k) 33 | 34 | corners = [x_corners, y_corners, z_corners] 35 | 36 | # rotate if R is passed in 37 | if R is not None: 38 | corners = np.dot(R, corners) 39 | 40 | # shift if location is passed in 41 | if location is not None: 42 | for i,loc in enumerate(location): 43 | corners[i,:] = corners[i,:] + loc 44 | 45 | final_corners = [] 46 | for i in range(8): 47 | final_corners.append([corners[0][i], corners[1][i], corners[2][i]]) 48 | 49 | 50 | return final_corners 51 | 52 | # this is based on the paper. Math! 
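# A rough usage sketch, assuming the 2D box, regressed alpha/dimensions and the
# camera ray angle come from the upstream detector + regressor pipeline (none of
# these names are defined in this module):
#
#     proj_matrix = get_P('assets/global_calib.txt')   # 3x4 camera matrix, from src.utils.Calib
#     box_2d = [(xmin, ymin), (xmax, ymax)]            # 2D detection in pixels
#     dimensions = class_average + dim_residual        # often class mean + predicted residual
#     location, _ = calc_location(dimensions, proj_matrix, box_2d, alpha, theta_ray)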
53 | # calib is a 3x4 matrix, box_2d is [(xmin, ymin), (xmax, ymax)] 54 | # Math help: http://ywpkwon.github.io/pdf/bbox3d-study.pdf 55 | def calc_location(dimension, proj_matrix, box_2d, alpha, theta_ray): 56 | #global orientation 57 | orient = alpha + theta_ray 58 | R = rotation_matrix(orient) 59 | 60 | # format 2d corners 61 | try: 62 | xmin = box_2d[0][0] 63 | ymin = box_2d[0][1] 64 | xmax = box_2d[1][0] 65 | ymax = box_2d[1][1] 66 | except: 67 | xmin = box_2d[0] 68 | ymin = box_2d[1] 69 | xmax = box_2d[2] 70 | ymax = box_2d[3] 71 | 72 | # left top right bottom 73 | box_corners = [xmin, ymin, xmax, ymax] 74 | 75 | # get the point constraints 76 | constraints = [] 77 | 78 | left_constraints = [] 79 | right_constraints = [] 80 | top_constraints = [] 81 | bottom_constraints = [] 82 | 83 | # using a different coord system 84 | dx = dimension[2] / 2 85 | dy = dimension[0] / 2 86 | dz = dimension[1] / 2 87 | 88 | # below is very much based on trial and error 89 | 90 | # based on the relative angle, a different configuration occurs 91 | # negative is back of car, positive is front 92 | left_mult = 1 93 | right_mult = -1 94 | 95 | # about straight on but opposite way 96 | if alpha < np.deg2rad(92) and alpha > np.deg2rad(88): 97 | left_mult = 1 98 | right_mult = 1 99 | # about straight on and same way 100 | elif alpha < np.deg2rad(-88) and alpha > np.deg2rad(-92): 101 | left_mult = -1 102 | right_mult = -1 103 | # this works but doesnt make much sense 104 | elif alpha < np.deg2rad(90) and alpha > -np.deg2rad(90): 105 | left_mult = -1 106 | right_mult = 1 107 | 108 | # if the car is facing the oppositeway, switch left and right 109 | switch_mult = -1 110 | if alpha > 0: 111 | switch_mult = 1 112 | 113 | # left and right could either be the front of the car ot the back of the car 114 | # careful to use left and right based on image, no of actual car's left and right 115 | for i in (-1,1): 116 | left_constraints.append([left_mult * dx, i*dy, -switch_mult * dz]) 117 | for i in (-1,1): 118 | right_constraints.append([right_mult * dx, i*dy, switch_mult * dz]) 119 | 120 | # top and bottom are easy, just the top and bottom of car 121 | for i in (-1,1): 122 | for j in (-1,1): 123 | top_constraints.append([i*dx, -dy, j*dz]) 124 | for i in (-1,1): 125 | for j in (-1,1): 126 | bottom_constraints.append([i*dx, dy, j*dz]) 127 | 128 | # now, 64 combinations 129 | for left in left_constraints: 130 | for top in top_constraints: 131 | for right in right_constraints: 132 | for bottom in bottom_constraints: 133 | constraints.append([left, top, right, bottom]) 134 | 135 | # filter out the ones with repeats 136 | constraints = filter(lambda x: len(x) == len(set(tuple(i) for i in x)), constraints) 137 | 138 | # create pre M (the term with I and the R*X) 139 | pre_M = np.zeros([4,4]) 140 | # 1's down diagonal 141 | for i in range(0,4): 142 | pre_M[i][i] = 1 143 | 144 | best_loc = None 145 | best_error = [1e09] 146 | best_X = None 147 | 148 | # loop through each possible constraint, hold on to the best guess 149 | # constraint will be 64 sets of 4 corners 150 | count = 0 151 | for constraint in constraints: 152 | # each corner 153 | Xa = constraint[0] 154 | Xb = constraint[1] 155 | Xc = constraint[2] 156 | Xd = constraint[3] 157 | 158 | X_array = [Xa, Xb, Xc, Xd] 159 | 160 | # M: all 1's down diagonal, and upper 3x1 is Rotation_matrix * [x, y, z] 161 | Ma = np.copy(pre_M) 162 | Mb = np.copy(pre_M) 163 | Mc = np.copy(pre_M) 164 | Md = np.copy(pre_M) 165 | 166 | M_array = [Ma, Mb, Mc, Md] 167 | 168 | # create A, b 169 | 
A = np.zeros([4,3], dtype=np.float) 170 | b = np.zeros([4,1]) 171 | 172 | indicies = [0,1,0,1] 173 | for row, index in enumerate(indicies): 174 | X = X_array[row] 175 | M = M_array[row] 176 | 177 | # create M for corner Xx 178 | RX = np.dot(R, X) 179 | M[:3,3] = RX.reshape(3) 180 | 181 | M = np.dot(proj_matrix, M) 182 | 183 | A[row, :] = M[index,:3] - box_corners[row] * M[2,:3] 184 | b[row] = box_corners[row] * M[2,3] - M[index,3] 185 | 186 | # solve here with least squares, since over fit will get some error 187 | loc, error, rank, s = np.linalg.lstsq(A, b, rcond=None) 188 | 189 | # found a better estimation 190 | if error < best_error: 191 | count += 1 # for debugging 192 | best_loc = loc 193 | best_error = error 194 | best_X = X_array 195 | 196 | # return best_loc, [left_constraints, right_constraints] # for debugging 197 | best_loc = [best_loc[0][0], best_loc[1][0], best_loc[2][0]] 198 | return best_loc, best_X 199 | 200 | """ 201 | Code for generating new plot with bev and 3dbbox 202 | source: https://github.com/lzccccc/3d-bounding-box-estimation-for-autonomous-driving 203 | """ 204 | 205 | def get_new_alpha(alpha): 206 | """ 207 | change the range of orientation from [-pi, pi] to [0, 2pi] 208 | :param alpha: original orientation in KITTI 209 | :return: new alpha 210 | """ 211 | new_alpha = float(alpha) + np.pi / 2. 212 | if new_alpha < 0: 213 | new_alpha = new_alpha + 2. * np.pi 214 | # make sure angle lies in [0, 2pi] 215 | new_alpha = new_alpha - int(new_alpha / (2. * np.pi)) * (2. * np.pi) 216 | 217 | return new_alpha 218 | 219 | def recover_angle(bin_anchor, bin_confidence, bin_num): 220 | # select anchor from bins 221 | max_anc = np.argmax(bin_confidence) 222 | anchors = bin_anchor[max_anc] 223 | # compute the angle offset 224 | if anchors[1] > 0: 225 | angle_offset = np.arccos(anchors[0]) 226 | else: 227 | angle_offset = -np.arccos(anchors[0]) 228 | 229 | # add the angle offset to the center ray of each bin to obtain the local orientation 230 | wedge = 2 * np.pi / bin_num 231 | angle = angle_offset + max_anc * wedge 232 | 233 | # angle - 2pi, if exceed 2pi 234 | angle_l = angle % (2 * np.pi) 235 | 236 | # change to ray back to [-pi, pi] 237 | angle = angle_l + wedge / 2 - np.pi 238 | if angle > np.pi: 239 | angle -= 2 * np.pi 240 | angle = round(angle, 2) 241 | return angle 242 | 243 | 244 | def compute_orientaion(P2, obj): 245 | x = (obj.xmax + obj.xmin) / 2 246 | # compute camera orientation 247 | u_distance = x - P2[0, 2] 248 | focal_length = P2[0, 0] 249 | rot_ray = np.arctan(u_distance / focal_length) 250 | # global = alpha + ray 251 | rot_global = obj.alpha + rot_ray 252 | 253 | # local orientation, [0, 2 * pi] 254 | # rot_local = obj.alpha + np.pi / 2 255 | rot_local = get_new_alpha(obj.alpha) 256 | 257 | rot_global = round(rot_global, 2) 258 | return rot_global, rot_local 259 | 260 | 261 | def translation_constraints(P2, obj, rot_local): 262 | bbox = [obj.xmin, obj.ymin, obj.xmax, obj.ymax] 263 | # rotation matrix 264 | R = np.array([[ np.cos(obj.rot_global), 0, np.sin(obj.rot_global)], 265 | [ 0, 1, 0 ], 266 | [-np.sin(obj.rot_global), 0, np.cos(obj.rot_global)]]) 267 | A = np.zeros((4, 3)) 268 | b = np.zeros((4, 1)) 269 | I = np.identity(3) 270 | 271 | # object coordinate T, samply divide into xyz 272 | # bug1: h div 2 273 | xmin_candi, xmax_candi, ymin_candi, ymax_candi = obj.box3d_candidate(rot_local, soft_range=8) 274 | 275 | X = np.bmat([xmin_candi, xmax_candi, 276 | ymin_candi, ymax_candi]) 277 | # X: [x, y, z] in object coordinate 278 | X = X.reshape(4,3).T 279 
| 280 | # construct equation (3, 4) 281 | # object four point in bev 282 | for i in range(4): 283 | # X[:,i] sames as Ti 284 | # matrice = [R T] * Xo 285 | matrice = np.bmat([[I, np.matmul(R, X[:,i])], [np.zeros((1,3)), np.ones((1,1))]]) 286 | # M = K * [R T] * Xo 287 | M = np.matmul(P2, matrice) 288 | 289 | if i % 2 == 0: 290 | A[i, :] = M[0, 0:3] - bbox[i] * M[2, 0:3] 291 | b[i, :] = M[2, 3] * bbox[i] - M[0, 3] 292 | 293 | else: 294 | A[i, :] = M[1, 0:3] - bbox[i] * M[2, 0:3] 295 | b[i, :] = M[2, 3] * bbox[i] - M[1, 3] 296 | # solve x, y, z, using method of least square 297 | Tran = np.matmul(np.linalg.pinv(A), b) 298 | 299 | tx, ty, tz = [float(np.around(tran, 2)) for tran in Tran] 300 | return tx, ty, tz 301 | 302 | 303 | class detectionInfo(object): 304 | def __init__(self, line): 305 | self.name = line[0] 306 | 307 | self.truncation = float(line[1]) 308 | self.occlusion = int(line[2]) 309 | 310 | # local orientation = alpha + pi/2 311 | self.alpha = float(line[3]) 312 | 313 | # in pixel coordinate 314 | self.xmin = float(line[4]) 315 | self.ymin = float(line[5]) 316 | self.xmax = float(line[6]) 317 | self.ymax = float(line[7]) 318 | 319 | # height, weigh, length in object coordinate, meter 320 | self.h = float(line[8]) 321 | self.w = float(line[9]) 322 | self.l = float(line[10]) 323 | 324 | # x, y, z in camera coordinate, meter 325 | self.tx = float(line[11]) 326 | self.ty = float(line[12]) 327 | self.tz = float(line[13]) 328 | 329 | # global orientation [-pi, pi] 330 | self.rot_global = float(line[14]) 331 | 332 | def member_to_list(self): 333 | output_line = [] 334 | for name, value in vars(self).items(): 335 | output_line.append(value) 336 | return output_line 337 | 338 | def box3d_candidate(self, rot_local, soft_range): 339 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0] 340 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0] 341 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0] 342 | 343 | x_corners = [i - self.l / 2 for i in x_corners] 344 | y_corners = [i - self.h / 2 for i in y_corners] 345 | z_corners = [i - self.w / 2 for i in z_corners] 346 | 347 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners])) 348 | point1 = corners_3d[0, :] 349 | point2 = corners_3d[1, :] 350 | point3 = corners_3d[2, :] 351 | point4 = corners_3d[3, :] 352 | point5 = corners_3d[6, :] 353 | point6 = corners_3d[7, :] 354 | point7 = corners_3d[4, :] 355 | point8 = corners_3d[5, :] 356 | 357 | # set up projection relation based on local orientation 358 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0 359 | 360 | if 0 < rot_local < np.pi / 2: 361 | xmin_candi = point8 362 | xmax_candi = point2 363 | ymin_candi = point2 364 | ymax_candi = point5 365 | 366 | if np.pi / 2 <= rot_local <= np.pi: 367 | xmin_candi = point6 368 | xmax_candi = point4 369 | ymin_candi = point4 370 | ymax_candi = point1 371 | 372 | if np.pi < rot_local <= 3 / 2 * np.pi: 373 | xmin_candi = point2 374 | xmax_candi = point8 375 | ymin_candi = point8 376 | ymax_candi = point1 377 | 378 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi: 379 | xmin_candi = point4 380 | xmax_candi = point6 381 | ymin_candi = point6 382 | ymax_candi = point5 383 | 384 | # soft constraint 385 | div = soft_range * np.pi / 180 386 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi: 387 | xmin_candi = point8 388 | xmax_candi = point6 389 | ymin_candi = point6 390 | ymax_candi = point5 391 | 392 | if np.pi - div < rot_local < np.pi + div: 393 | xmin_candi = point2 394 | xmax_candi = point4 395 | 
ymin_candi = point8 396 | ymax_candi = point1 397 | 398 | return xmin_candi, xmax_candi, ymin_candi, ymax_candi -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.pylogger import get_pylogger 2 | from src.utils.rich_utils import enforce_tags, print_config_tree 3 | from src.utils.utils import ( 4 | close_loggers, 5 | extras, 6 | get_metric_value, 7 | instantiate_callbacks, 8 | instantiate_loggers, 9 | log_hyperparameters, 10 | save_file, 11 | task_wrapper, 12 | ) 13 | -------------------------------------------------------------------------------- /src/utils/averages.py: -------------------------------------------------------------------------------- 1 | """Average dimension class""" 2 | 3 | from typing import List 4 | import numpy as np 5 | import os 6 | import json 7 | 8 | class DimensionAverages: 9 | """ 10 | Class to calculate the average dimensions of the objects in the dataset. 11 | """ 12 | def __init__( 13 | self, 14 | categories: List[str] = ['car', 'pedestrian', 'cyclist'], 15 | save_file: str = 'dimension_averages.txt' 16 | ): 17 | self.dimension_map = {} 18 | self.filename = os.path.abspath(os.path.dirname(__file__)) + '/' + save_file 19 | self.categories = categories 20 | 21 | if len(self.categories) == 0: 22 | self.load_items_from_file() 23 | 24 | for det in self.categories: 25 | cat_ = det.lower() 26 | if cat_ in self.dimension_map.keys(): 27 | continue 28 | self.dimension_map[cat_] = {} 29 | self.dimension_map[cat_]['count'] = 0 30 | self.dimension_map[cat_]['total'] = np.zeros(3, dtype=np.float32) 31 | 32 | def add_items(self, items_path): 33 | for path in items_path: 34 | with open(path, "r") as f: 35 | for line in f: 36 | line = line.split(" ") 37 | if line[0].lower() in self.categories: 38 | self.add_item( 39 | line[0], 40 | np.array([float(line[8]), float(line[9]), float(line[10])]) 41 | ) 42 | 43 | def add_item(self, cat, dim): 44 | cat = cat.lower() 45 | self.dimension_map[cat]['count'] += 1 46 | self.dimension_map[cat]['total'] += dim 47 | 48 | def get_item(self, cat): 49 | cat = cat.lower() 50 | return self.dimension_map[cat]['total'] / self.dimension_map[cat]['count'] 51 | 52 | def load_items_from_file(self): 53 | f = open(self.filename, 'r') 54 | dimension_map = json.load(f) 55 | 56 | for cat in dimension_map: 57 | dimension_map[cat]['total'] = np.asarray(dimension_map[cat]['total']) 58 | 59 | self.dimension_map = dimension_map 60 | 61 | def dump_to_file(self): 62 | f = open(self.filename, "w") 63 | f.write(json.dumps(self.dimension_map, cls=NumpyEncoder)) 64 | f.close() 65 | 66 | def recognized_class(self, cat): 67 | return cat.lower() in self.dimension_map 68 | 69 | class ClassAverages: 70 | def __init__(self, classes=[]): 71 | self.dimension_map = {} 72 | self.filename = os.path.abspath(os.path.dirname(__file__)) + '/class_averages.txt' 73 | 74 | if len(classes) == 0: # eval mode 75 | self.load_items_from_file() 76 | 77 | for detection_class in classes: 78 | class_ = detection_class.lower() 79 | if class_ in self.dimension_map.keys(): 80 | continue 81 | self.dimension_map[class_] = {} 82 | self.dimension_map[class_]['count'] = 0 83 | self.dimension_map[class_]['total'] = np.zeros(3, dtype=np.double) 84 | 85 | 86 | def add_item(self, class_, dimension): 87 | class_ = class_.lower() 88 | self.dimension_map[class_]['count'] += 1 89 | self.dimension_map[class_]['total'] += dimension 90 | # 
self.dimension_map[class_]['total'] /= self.dimension_map[class_]['count'] 91 | 92 | def get_item(self, class_): 93 | class_ = class_.lower() 94 | return self.dimension_map[class_]['total'] / self.dimension_map[class_]['count'] 95 | 96 | def dump_to_file(self): 97 | f = open(self.filename, "w") 98 | f.write(json.dumps(self.dimension_map, cls=NumpyEncoder)) 99 | f.close() 100 | 101 | def load_items_from_file(self): 102 | f = open(self.filename, 'r') 103 | dimension_map = json.load(f) 104 | 105 | for class_ in dimension_map: 106 | dimension_map[class_]['total'] = np.asarray(dimension_map[class_]['total']) 107 | 108 | self.dimension_map = dimension_map 109 | 110 | def recognized_class(self, class_): 111 | return class_.lower() in self.dimension_map 112 | 113 | class NumpyEncoder(json.JSONEncoder): 114 | def default(self, obj): 115 | if isinstance(obj, np.ndarray): 116 | return obj.tolist() 117 | return json.JSONEncoder.default(self,obj) -------------------------------------------------------------------------------- /src/utils/class_averages-L4.txt: -------------------------------------------------------------------------------- 1 | {"car": {"count": 15939, "total": [25041.166112000075, 27306.989660000134, 68537.86737500015]}, "pedestrian": {"count": 1793, "total": [3018.555634000002, 974.4651910000001, 847.5994529999997]}, "cyclist": {"count": 116, "total": [169.2123550000001, 54.55659699999999, 187.05904800000002]}, "truck": {"count": 741, "total": [2219.4890700000014, 1835.661420999999, 6906.846059999999]}, "van": {"count": 632, "total": [1366.3955200000005, 1232.9152299999998, 3155.905800000001]}, "trafficcone": {"count": 0, "total": [0.0, 0.0, 0.0]}, "unknown": {"count": 4294, "total": [4924.141288999996, 3907.031903999998, 6185.184788000007]}} -------------------------------------------------------------------------------- /src/utils/class_averages-kitti6.txt: -------------------------------------------------------------------------------- 1 | {"car": {"count": 14385, "total": [21898.43999999967, 23568.289999999495, 55754.239999999765]}, "cyclist": {"count": 893, "total": [1561.5099999999982, 552.850000000001, 1569.5100000000007]}, "truck": {"count": 606, "total": [1916.6199999999872, 1554.710000000011, 6567.400000000018]}, "van": {"count": 1617, "total": [3593.439999999989, 3061.370000000014, 8122.769999999951]}, "pedestrian": {"count": 2280, "total": [3998.8900000000003, 1576.6400000000049, 1974.090000000009]}, "tram": {"count": 287, "total": [1012.7000000000005, 771.13, 4739.249999999991]}} -------------------------------------------------------------------------------- /src/utils/class_averages.txt: -------------------------------------------------------------------------------- 1 | {"car": {"count": 14385, "total": [21898.43999999967, 23568.289999999495, 55754.239999999765]}, "cyclist": {"count": 893, "total": [1561.5099999999982, 552.850000000001, 1569.5100000000007]}, "truck": {"count": 606, "total": [1916.6199999999872, 1554.710000000011, 6567.400000000018]}, "van": {"count": 1617, "total": [3593.439999999989, 3061.370000000014, 8122.769999999951]}, "pedestrian": {"count": 2280, "total": [3998.8900000000003, 1576.6400000000049, 1974.090000000009]}, "tram": {"count": 287, "total": [1012.7000000000005, 771.13, 4739.249999999991]}} -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pytorch_lightning.utilities import 
rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from rich.prompt import Prompt 11 | 12 | from src.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "datamodule", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints content of DictConfig using Rich library and its tree structure. 33 | 34 | Args: 35 | cfg (DictConfig): Configuration composed by Hydra. 36 | print_order (Sequence[str], optional): Determines in what order config components are printed. 37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 38 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 39 | """ 40 | 41 | style = "dim" 42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 43 | 44 | queue = [] 45 | 46 | # add fields from `print_order` to queue 47 | for field in print_order: 48 | queue.append(field) if field in cfg else log.warning( 49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 50 | ) 51 | 52 | # add all the other fields to queue (not specified in `print_order`) 53 | for field in cfg: 54 | if field not in queue: 55 | queue.append(field) 56 | 57 | # generate config tree from queue 58 | for field in queue: 59 | branch = tree.add(field, style=style, guide_style=style) 60 | 61 | config_group = cfg[field] 62 | if isinstance(config_group, DictConfig): 63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 64 | else: 65 | branch_content = str(config_group) 66 | 67 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 68 | 69 | # print config tree 70 | rich.print(tree) 71 | 72 | # save config tree to file 73 | if save_to_file: 74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 75 | rich.print(tree, file=file) 76 | 77 | 78 | @rank_zero_only 79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 80 | """Prompts user to input tags from command line if no tags are provided in config.""" 81 | 82 | if not cfg.get("tags"): 83 | if "id" in HydraConfig().cfg.hydra.job: 84 | raise ValueError("Specify tags before launching a multirun!") 85 | 86 | log.warning("No tags provided in config. 
Prompting user to input tags...") 87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 88 | tags = [t.strip() for t in tags.split(",") if t != ""] 89 | 90 | with open_dict(cfg): 91 | cfg.tags = tags 92 | 93 | log.info(f"Tags: {cfg.tags}") 94 | 95 | if save_to_file: 96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 97 | rich.print(cfg.tags, file=file) 98 | 99 | 100 | if __name__ == "__main__": 101 | from hydra import compose, initialize 102 | 103 | with initialize(version_base="1.2", config_path="../../configs"): 104 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[]) 105 | print_config_tree(cfg, resolve=False, save_to_file=False) 106 | -------------------------------------------------------------------------------- /src/utils/rotate_iou.py: -------------------------------------------------------------------------------- 1 | ##################### 2 | # Based on https://github.com/hongzhenwang/RRPN-revise 3 | # Licensed under The MIT License 4 | # Author: yanyan, scrin@foxmail.com 5 | ##################### 6 | import math 7 | 8 | import numba 9 | import numpy as np 10 | from numba import cuda 11 | 12 | @numba.jit(nopython=True) 13 | def div_up(m, n): 14 | return m // n + (m % n > 0) 15 | 16 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 17 | def trangle_area(a, b, c): 18 | return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) * 19 | (b[0] - c[0])) / 2.0 20 | 21 | 22 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 23 | def area(int_pts, num_of_inter): 24 | area_val = 0.0 25 | for i in range(num_of_inter - 2): 26 | area_val += abs( 27 | trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4], 28 | int_pts[2 * i + 4:2 * i + 6])) 29 | return area_val 30 | 31 | 32 | @cuda.jit('(float32[:], int32)', device=True, inline=True) 33 | def sort_vertex_in_convex_polygon(int_pts, num_of_inter): 34 | if num_of_inter > 0: 35 | center = cuda.local.array((2, ), dtype=numba.float32) 36 | center[:] = 0.0 37 | for i in range(num_of_inter): 38 | center[0] += int_pts[2 * i] 39 | center[1] += int_pts[2 * i + 1] 40 | center[0] /= num_of_inter 41 | center[1] /= num_of_inter 42 | v = cuda.local.array((2, ), dtype=numba.float32) 43 | vs = cuda.local.array((16, ), dtype=numba.float32) 44 | for i in range(num_of_inter): 45 | v[0] = int_pts[2 * i] - center[0] 46 | v[1] = int_pts[2 * i + 1] - center[1] 47 | d = math.sqrt(v[0] * v[0] + v[1] * v[1]) 48 | v[0] = v[0] / d 49 | v[1] = v[1] / d 50 | if v[1] < 0: 51 | v[0] = -2 - v[0] 52 | vs[i] = v[0] 53 | j = 0 54 | temp = 0 55 | for i in range(1, num_of_inter): 56 | if vs[i - 1] > vs[i]: 57 | temp = vs[i] 58 | tx = int_pts[2 * i] 59 | ty = int_pts[2 * i + 1] 60 | j = i 61 | while j > 0 and vs[j - 1] > temp: 62 | vs[j] = vs[j - 1] 63 | int_pts[j * 2] = int_pts[j * 2 - 2] 64 | int_pts[j * 2 + 1] = int_pts[j * 2 - 1] 65 | j -= 1 66 | 67 | vs[j] = temp 68 | int_pts[j * 2] = tx 69 | int_pts[j * 2 + 1] = ty 70 | 71 | 72 | @cuda.jit( 73 | '(float32[:], float32[:], int32, int32, float32[:])', 74 | device=True, 75 | inline=True) 76 | def line_segment_intersection(pts1, pts2, i, j, temp_pts): 77 | A = cuda.local.array((2, ), dtype=numba.float32) 78 | B = cuda.local.array((2, ), dtype=numba.float32) 79 | C = cuda.local.array((2, ), dtype=numba.float32) 80 | D = cuda.local.array((2, ), dtype=numba.float32) 81 | 82 | A[0] = pts1[2 * i] 83 | A[1] = pts1[2 * i + 1] 84 | 85 | B[0] = pts1[2 * ((i + 1) % 4)] 86 | B[1] = pts1[2 * ((i + 1) % 4) + 1] 87 | 88 | C[0] = pts2[2 * j] 89 
| C[1] = pts2[2 * j + 1] 90 | 91 | D[0] = pts2[2 * ((j + 1) % 4)] 92 | D[1] = pts2[2 * ((j + 1) % 4) + 1] 93 | BA0 = B[0] - A[0] 94 | BA1 = B[1] - A[1] 95 | DA0 = D[0] - A[0] 96 | CA0 = C[0] - A[0] 97 | DA1 = D[1] - A[1] 98 | CA1 = C[1] - A[1] 99 | acd = DA1 * CA0 > CA1 * DA0 100 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0]) 101 | if acd != bcd: 102 | abc = CA1 * BA0 > BA1 * CA0 103 | abd = DA1 * BA0 > BA1 * DA0 104 | if abc != abd: 105 | DC0 = D[0] - C[0] 106 | DC1 = D[1] - C[1] 107 | ABBA = A[0] * B[1] - B[0] * A[1] 108 | CDDC = C[0] * D[1] - D[0] * C[1] 109 | DH = BA1 * DC0 - BA0 * DC1 110 | Dx = ABBA * DC0 - BA0 * CDDC 111 | Dy = ABBA * DC1 - BA1 * CDDC 112 | temp_pts[0] = Dx / DH 113 | temp_pts[1] = Dy / DH 114 | return True 115 | return False 116 | 117 | 118 | @cuda.jit( 119 | '(float32[:], float32[:], int32, int32, float32[:])', 120 | device=True, 121 | inline=True) 122 | def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts): 123 | a = cuda.local.array((2, ), dtype=numba.float32) 124 | b = cuda.local.array((2, ), dtype=numba.float32) 125 | c = cuda.local.array((2, ), dtype=numba.float32) 126 | d = cuda.local.array((2, ), dtype=numba.float32) 127 | 128 | a[0] = pts1[2 * i] 129 | a[1] = pts1[2 * i + 1] 130 | 131 | b[0] = pts1[2 * ((i + 1) % 4)] 132 | b[1] = pts1[2 * ((i + 1) % 4) + 1] 133 | 134 | c[0] = pts2[2 * j] 135 | c[1] = pts2[2 * j + 1] 136 | 137 | d[0] = pts2[2 * ((j + 1) % 4)] 138 | d[1] = pts2[2 * ((j + 1) % 4) + 1] 139 | 140 | area_abc = trangle_area(a, b, c) 141 | area_abd = trangle_area(a, b, d) 142 | 143 | if area_abc * area_abd >= 0: 144 | return False 145 | 146 | area_cda = trangle_area(c, d, a) 147 | area_cdb = area_cda + area_abc - area_abd 148 | 149 | if area_cda * area_cdb >= 0: 150 | return False 151 | t = area_cda / (area_abd - area_abc) 152 | 153 | dx = t * (b[0] - a[0]) 154 | dy = t * (b[1] - a[1]) 155 | temp_pts[0] = a[0] + dx 156 | temp_pts[1] = a[1] + dy 157 | return True 158 | 159 | 160 | @cuda.jit('(float32, float32, float32[:])', device=True, inline=True) 161 | def point_in_quadrilateral(pt_x, pt_y, corners): 162 | ab0 = corners[2] - corners[0] 163 | ab1 = corners[3] - corners[1] 164 | 165 | ad0 = corners[6] - corners[0] 166 | ad1 = corners[7] - corners[1] 167 | 168 | ap0 = pt_x - corners[0] 169 | ap1 = pt_y - corners[1] 170 | 171 | abab = ab0 * ab0 + ab1 * ab1 172 | abap = ab0 * ap0 + ab1 * ap1 173 | adad = ad0 * ad0 + ad1 * ad1 174 | adap = ad0 * ap0 + ad1 * ap1 175 | 176 | return abab >= abap and abap >= 0 and adad >= adap and adap >= 0 177 | 178 | 179 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True) 180 | def quadrilateral_intersection(pts1, pts2, int_pts): 181 | num_of_inter = 0 182 | for i in range(4): 183 | if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2): 184 | int_pts[num_of_inter * 2] = pts1[2 * i] 185 | int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1] 186 | num_of_inter += 1 187 | if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1): 188 | int_pts[num_of_inter * 2] = pts2[2 * i] 189 | int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1] 190 | num_of_inter += 1 191 | temp_pts = cuda.local.array((2, ), dtype=numba.float32) 192 | for i in range(4): 193 | for j in range(4): 194 | has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts) 195 | if has_pts: 196 | int_pts[num_of_inter * 2] = temp_pts[0] 197 | int_pts[num_of_inter * 2 + 1] = temp_pts[1] 198 | num_of_inter += 1 199 | 200 | return num_of_inter 201 | 202 | 203 | @cuda.jit('(float32[:], float32[:])', 
device=True, inline=True) 204 | def rbbox_to_corners(corners, rbbox): 205 | # generate clockwise corners and rotate it clockwise 206 | angle = rbbox[4] 207 | a_cos = math.cos(angle) 208 | a_sin = math.sin(angle) 209 | center_x = rbbox[0] 210 | center_y = rbbox[1] 211 | x_d = rbbox[2] 212 | y_d = rbbox[3] 213 | corners_x = cuda.local.array((4, ), dtype=numba.float32) 214 | corners_y = cuda.local.array((4, ), dtype=numba.float32) 215 | corners_x[0] = -x_d / 2 216 | corners_x[1] = -x_d / 2 217 | corners_x[2] = x_d / 2 218 | corners_x[3] = x_d / 2 219 | corners_y[0] = -y_d / 2 220 | corners_y[1] = y_d / 2 221 | corners_y[2] = y_d / 2 222 | corners_y[3] = -y_d / 2 223 | for i in range(4): 224 | corners[2 * 225 | i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x 226 | corners[2 * i 227 | + 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y 228 | 229 | 230 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True) 231 | def inter(rbbox1, rbbox2): 232 | corners1 = cuda.local.array((8, ), dtype=numba.float32) 233 | corners2 = cuda.local.array((8, ), dtype=numba.float32) 234 | intersection_corners = cuda.local.array((16, ), dtype=numba.float32) 235 | 236 | rbbox_to_corners(corners1, rbbox1) 237 | rbbox_to_corners(corners2, rbbox2) 238 | 239 | num_intersection = quadrilateral_intersection(corners1, corners2, 240 | intersection_corners) 241 | sort_vertex_in_convex_polygon(intersection_corners, num_intersection) 242 | # print(intersection_corners.reshape([-1, 2])[:num_intersection]) 243 | 244 | return area(intersection_corners, num_intersection) 245 | 246 | 247 | @cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True) 248 | def devRotateIoUEval(rbox1, rbox2, criterion=-1): 249 | area1 = rbox1[2] * rbox1[3] 250 | area2 = rbox2[2] * rbox2[3] 251 | area_inter = inter(rbox1, rbox2) 252 | if criterion == -1: 253 | return area_inter / (area1 + area2 - area_inter) 254 | elif criterion == 0: 255 | return area_inter / area1 256 | elif criterion == 1: 257 | return area_inter / area2 258 | else: 259 | return area_inter 260 | 261 | @cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False) 262 | def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1): 263 | threadsPerBlock = 8 * 8 264 | row_start = cuda.blockIdx.x 265 | col_start = cuda.blockIdx.y 266 | tx = cuda.threadIdx.x 267 | row_size = min(N - row_start * threadsPerBlock, threadsPerBlock) 268 | col_size = min(K - col_start * threadsPerBlock, threadsPerBlock) 269 | block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 270 | block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32) 271 | 272 | dev_query_box_idx = threadsPerBlock * col_start + tx 273 | dev_box_idx = threadsPerBlock * row_start + tx 274 | if (tx < col_size): 275 | block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0] 276 | block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1] 277 | block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2] 278 | block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3] 279 | block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4] 280 | if (tx < row_size): 281 | block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0] 282 | block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1] 283 | block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2] 284 | block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3] 285 | block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4] 286 
| cuda.syncthreads() 287 | if tx < row_size: 288 | for i in range(col_size): 289 | offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i 290 | dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5], 291 | block_boxes[tx * 5:tx * 5 + 5], criterion) 292 | 293 | 294 | def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0): 295 | """rotated box iou running in gpu. 500x faster than cpu version 296 | (take 5ms in one example with numba.cuda code). 297 | convert from [this project]( 298 | https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation). 299 | 300 | Args: 301 | boxes (float tensor: [N, 5]): rbboxes. format: centers, dims, 302 | angles(clockwise when positive) 303 | query_boxes (float tensor: [K, 5]): [description] 304 | device_id (int, optional): Defaults to 0. [description] 305 | 306 | Returns: 307 | [type]: [description] 308 | """ 309 | box_dtype = boxes.dtype 310 | boxes = boxes.astype(np.float32) 311 | query_boxes = query_boxes.astype(np.float32) 312 | N = boxes.shape[0] 313 | K = query_boxes.shape[0] 314 | iou = np.zeros((N, K), dtype=np.float32) 315 | if N == 0 or K == 0: 316 | return iou 317 | threadsPerBlock = 8 * 8 318 | cuda.select_device(device_id) 319 | blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock)) 320 | 321 | stream = cuda.stream() 322 | with stream.auto_synchronize(): 323 | boxes_dev = cuda.to_device(boxes.reshape([-1]), stream) 324 | query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream) 325 | iou_dev = cuda.to_device(iou.reshape([-1]), stream) 326 | rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream]( 327 | N, K, boxes_dev, query_boxes_dev, iou_dev, criterion) 328 | iou_dev.copy_to_host(iou.reshape([-1]), stream=stream) 329 | return iou.astype(boxes.dtype) 330 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | from importlib.util import find_spec 4 | from pathlib import Path 5 | from typing import Any, Callable, Dict, List 6 | import numpy as np 7 | 8 | import hydra 9 | from omegaconf import DictConfig 10 | from pytorch_lightning import Callback 11 | from pytorch_lightning.loggers import LightningLoggerBase 12 | from pytorch_lightning.utilities import rank_zero_only 13 | 14 | from src.utils import pylogger, rich_utils 15 | 16 | log = pylogger.get_pylogger(__name__) 17 | 18 | 19 | def task_wrapper(task_func: Callable) -> Callable: 20 | """Optional decorator that wraps the task function in extra utilities. 21 | 22 | Makes multirun more resistant to failure. 
23 | 24 | Utilities: 25 | - Calling the `utils.extras()` before the task is started 26 | - Calling the `utils.close_loggers()` after the task is finished 27 | - Logging the exception if occurs 28 | - Logging the task total execution time 29 | - Logging the output dir 30 | """ 31 | 32 | def wrap(cfg: DictConfig): 33 | 34 | # apply extra utilities 35 | extras(cfg) 36 | 37 | # execute the task 38 | try: 39 | start_time = time.time() 40 | metric_dict, object_dict = task_func(cfg=cfg) 41 | except Exception as ex: 42 | log.exception("") # save exception to `.log` file 43 | raise ex 44 | finally: 45 | path = Path(cfg.paths.output_dir, "exec_time.log") 46 | content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" 47 | save_file(path, content) # save task execution time (even if exception occurs) 48 | close_loggers() # close loggers (even if exception occurs so multirun won't fail) 49 | 50 | log.info(f"Output dir: {cfg.paths.output_dir}") 51 | 52 | return metric_dict, object_dict 53 | 54 | return wrap 55 | 56 | 57 | def extras(cfg: DictConfig) -> None: 58 | """Applies optional utilities before the task is started. 59 | 60 | Utilities: 61 | - Ignoring python warnings 62 | - Setting tags from command line 63 | - Rich config printing 64 | """ 65 | 66 | # return if no `extras` config 67 | if not cfg.get("extras"): 68 | log.warning("Extras config not found! ") 69 | return 70 | 71 | # disable python warnings 72 | if cfg.extras.get("ignore_warnings"): 73 | log.info("Disabling python warnings! ") 74 | warnings.filterwarnings("ignore") 75 | 76 | # prompt user to input tags from command line if none are provided in the config 77 | if cfg.extras.get("enforce_tags"): 78 | log.info("Enforcing tags! ") 79 | rich_utils.enforce_tags(cfg, save_to_file=True) 80 | 81 | # pretty print config tree using Rich library 82 | if cfg.extras.get("print_config"): 83 | log.info("Printing config tree with Rich! 
") 84 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 85 | 86 | 87 | @rank_zero_only 88 | def save_file(path: str, content: str) -> None: 89 | """Save file in rank zero mode (only on one process in multi-GPU setup).""" 90 | with open(path, "w+") as file: 91 | file.write(content) 92 | 93 | 94 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 95 | """Instantiates callbacks from config.""" 96 | callbacks: List[Callback] = [] 97 | 98 | if not callbacks_cfg: 99 | log.warning("Callbacks config is empty.") 100 | return callbacks 101 | 102 | if not isinstance(callbacks_cfg, DictConfig): 103 | raise TypeError("Callbacks config must be a DictConfig!") 104 | 105 | for _, cb_conf in callbacks_cfg.items(): 106 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 107 | log.info(f"Instantiating callback <{cb_conf._target_}>") 108 | callbacks.append(hydra.utils.instantiate(cb_conf)) 109 | 110 | return callbacks 111 | 112 | 113 | def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]: 114 | """Instantiates loggers from config.""" 115 | logger: List[LightningLoggerBase] = [] 116 | 117 | if not logger_cfg: 118 | log.warning("Logger config is empty.") 119 | return logger 120 | 121 | if not isinstance(logger_cfg, DictConfig): 122 | raise TypeError("Logger config must be a DictConfig!") 123 | 124 | for _, lg_conf in logger_cfg.items(): 125 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 126 | log.info(f"Instantiating logger <{lg_conf._target_}>") 127 | logger.append(hydra.utils.instantiate(lg_conf)) 128 | 129 | return logger 130 | 131 | 132 | @rank_zero_only 133 | def log_hyperparameters(object_dict: dict) -> None: 134 | """Controls which config parts are saved by lightning loggers. 135 | 136 | Additionally saves: 137 | - Number of model parameters 138 | """ 139 | 140 | hparams = {} 141 | 142 | cfg = object_dict["cfg"] 143 | model = object_dict["model"] 144 | trainer = object_dict["trainer"] 145 | 146 | if not trainer.logger: 147 | log.warning("Logger not found! Skipping hyperparameter logging...") 148 | return 149 | 150 | hparams["model"] = cfg["model"] 151 | 152 | # save number of model parameters 153 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 154 | hparams["model/params/trainable"] = sum( 155 | p.numel() for p in model.parameters() if p.requires_grad 156 | ) 157 | hparams["model/params/non_trainable"] = sum( 158 | p.numel() for p in model.parameters() if not p.requires_grad 159 | ) 160 | 161 | hparams["datamodule"] = cfg["datamodule"] 162 | hparams["trainer"] = cfg["trainer"] 163 | 164 | hparams["callbacks"] = cfg.get("callbacks") 165 | hparams["extras"] = cfg.get("extras") 166 | 167 | hparams["task_name"] = cfg.get("task_name") 168 | hparams["tags"] = cfg.get("tags") 169 | hparams["ckpt_path"] = cfg.get("ckpt_path") 170 | hparams["seed"] = cfg.get("seed") 171 | 172 | # send hparams to all loggers 173 | trainer.logger.log_hyperparams(hparams) 174 | 175 | 176 | def get_metric_value(metric_dict: dict, metric_name: str) -> float: 177 | """Safely retrieves value of the metric logged in LightningModule.""" 178 | 179 | if not metric_name: 180 | log.info("Metric name is None! Skipping metric value retrieval...") 181 | return None 182 | 183 | if metric_name not in metric_dict: 184 | raise Exception( 185 | f"Metric value not found! \n" 186 | "Make sure metric name logged in LightningModule is correct!\n" 187 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 
188 | ) 189 | 190 | metric_value = metric_dict[metric_name].item() 191 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") 192 | 193 | return metric_value 194 | 195 | 196 | def close_loggers() -> None: 197 | """Makes sure all loggers closed properly (prevents logging failure during multirun).""" 198 | 199 | log.info("Closing loggers...") 200 | 201 | if find_spec("wandb"): # if wandb is installed 202 | import wandb 203 | 204 | if wandb.run: 205 | log.info("Closing wandb!") 206 | wandb.finish() 207 | 208 | class detectionInfo(object): 209 | """ 210 | utils for YOLO3D 211 | detectionInfo is a class that contains information about the detection 212 | """ 213 | def __init__(self, line): 214 | self.name = line[0] 215 | 216 | self.truncation = float(line[1]) 217 | self.occlusion = int(line[2]) 218 | 219 | # local orientation = alpha + pi/2 220 | self.alpha = float(line[3]) 221 | 222 | # in pixel coordinate 223 | self.xmin = float(line[4]) 224 | self.ymin = float(line[5]) 225 | self.xmax = float(line[6]) 226 | self.ymax = float(line[7]) 227 | 228 | # height, weigh, length in object coordinate, meter 229 | self.h = float(line[8]) 230 | self.w = float(line[9]) 231 | self.l = float(line[10]) 232 | 233 | # x, y, z in camera coordinate, meter 234 | self.tx = float(line[11]) 235 | self.ty = float(line[12]) 236 | self.tz = float(line[13]) 237 | 238 | # global orientation [-pi, pi] 239 | self.rot_global = float(line[14]) 240 | 241 | def member_to_list(self): 242 | output_line = [] 243 | for name, value in vars(self).items(): 244 | output_line.append(value) 245 | return output_line 246 | 247 | def box3d_candidate(self, rot_local, soft_range): 248 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0] 249 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0] 250 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0] 251 | 252 | x_corners = [i - self.l / 2 for i in x_corners] 253 | y_corners = [i - self.h / 2 for i in y_corners] 254 | z_corners = [i - self.w / 2 for i in z_corners] 255 | 256 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners])) 257 | point1 = corners_3d[0, :] 258 | point2 = corners_3d[1, :] 259 | point3 = corners_3d[2, :] 260 | point4 = corners_3d[3, :] 261 | point5 = corners_3d[6, :] 262 | point6 = corners_3d[7, :] 263 | point7 = corners_3d[4, :] 264 | point8 = corners_3d[5, :] 265 | 266 | # set up projection relation based on local orientation 267 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0 268 | 269 | if 0 < rot_local < np.pi / 2: 270 | xmin_candi = point8 271 | xmax_candi = point2 272 | ymin_candi = point2 273 | ymax_candi = point5 274 | 275 | if np.pi / 2 <= rot_local <= np.pi: 276 | xmin_candi = point6 277 | xmax_candi = point4 278 | ymin_candi = point4 279 | ymax_candi = point1 280 | 281 | if np.pi < rot_local <= 3 / 2 * np.pi: 282 | xmin_candi = point2 283 | xmax_candi = point8 284 | ymin_candi = point8 285 | ymax_candi = point1 286 | 287 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi: 288 | xmin_candi = point4 289 | xmax_candi = point6 290 | ymin_candi = point6 291 | ymax_candi = point5 292 | 293 | # soft constraint 294 | div = soft_range * np.pi / 180 295 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi: 296 | xmin_candi = point8 297 | xmax_candi = point6 298 | ymin_candi = point6 299 | ymax_candi = point5 300 | 301 | if np.pi - div < rot_local < np.pi + div: 302 | xmin_candi = point2 303 | xmax_candi = point4 304 | ymin_candi = point8 305 | ymax_candi = point1 306 | 307 | return xmin_candi, xmax_candi, 
ymin_candi, ymax_candi 308 | 309 | 310 | class KITTIObject(): 311 | """ 312 | utils for YOLO3D 313 | detectionInfo is a class that contains information about the detection 314 | """ 315 | def __init__(self, line = np.zeros(16)): 316 | self.name = line[0] 317 | 318 | self.truncation = float(line[1]) 319 | self.occlusion = int(line[2]) 320 | 321 | # local orientation = alpha + pi/2 322 | self.alpha = float(line[3]) 323 | 324 | # in pixel coordinate 325 | self.xmin = float(line[4]) 326 | self.ymin = float(line[5]) 327 | self.xmax = float(line[6]) 328 | self.ymax = float(line[7]) 329 | 330 | # height, weigh, length in object coordinate, meter 331 | self.h = float(line[8]) 332 | self.w = float(line[9]) 333 | self.l = float(line[10]) 334 | 335 | # x, y, z in camera coordinate, meter 336 | self.tx = float(line[11]) 337 | self.ty = float(line[12]) 338 | self.tz = float(line[13]) 339 | 340 | # global orientation [-pi, pi] 341 | self.rot_global = float(line[14]) 342 | 343 | # score 344 | self.score = float(line[15]) 345 | 346 | def member_to_list(self): 347 | output_line = [] 348 | for name, value in vars(self).items(): 349 | output_line.append(value) 350 | return output_line 351 | 352 | def box3d_candidate(self, rot_local, soft_range): 353 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0] 354 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0] 355 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0] 356 | 357 | x_corners = [i - self.l / 2 for i in x_corners] 358 | y_corners = [i - self.h / 2 for i in y_corners] 359 | z_corners = [i - self.w / 2 for i in z_corners] 360 | 361 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners])) 362 | point1 = corners_3d[0, :] 363 | point2 = corners_3d[1, :] 364 | point3 = corners_3d[2, :] 365 | point4 = corners_3d[3, :] 366 | point5 = corners_3d[6, :] 367 | point6 = corners_3d[7, :] 368 | point7 = corners_3d[4, :] 369 | point8 = corners_3d[5, :] 370 | 371 | # set up projection relation based on local orientation 372 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0 373 | 374 | if 0 < rot_local < np.pi / 2: 375 | xmin_candi = point8 376 | xmax_candi = point2 377 | ymin_candi = point2 378 | ymax_candi = point5 379 | 380 | if np.pi / 2 <= rot_local <= np.pi: 381 | xmin_candi = point6 382 | xmax_candi = point4 383 | ymin_candi = point4 384 | ymax_candi = point1 385 | 386 | if np.pi < rot_local <= 3 / 2 * np.pi: 387 | xmin_candi = point2 388 | xmax_candi = point8 389 | ymin_candi = point8 390 | ymax_candi = point1 391 | 392 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi: 393 | xmin_candi = point4 394 | xmax_candi = point6 395 | ymin_candi = point6 396 | ymax_candi = point5 397 | 398 | # soft constraint 399 | div = soft_range * np.pi / 180 400 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi: 401 | xmin_candi = point8 402 | xmax_candi = point6 403 | ymin_candi = point6 404 | ymax_candi = point5 405 | 406 | if np.pi - div < rot_local < np.pi + div: 407 | xmin_candi = point2 408 | xmax_candi = point4 409 | ymin_candi = point8 410 | ymax_candi = point1 411 | 412 | return xmin_candi, xmax_candi, ymin_candi, ymax_candi 413 | 414 | if __name__ == "__main__": 415 | pass -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tests/__init__.py 
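Editorial usage note for src/utils/utils.py above (not part of the repository): the KITTIObject class takes one line of a KITTI-format label file (class name followed by 15 numeric fields, the last one a detection score), and box3d_candidate returns the 3D-box corners that bound the 2D box for a given local orientation. A minimal sketch follows; the sample label values are invented for illustration and the call only runs inside this repository, where the class is defined.

from src.utils.utils import KITTIObject

# One KITTI label line: name, truncation, occlusion, alpha, 2D bbox (xmin ymin xmax ymax),
# h/w/l, camera-frame x/y/z, global rotation, score. Values below are made up.
label_line = "Car 0.00 0 -1.57 614.24 181.78 727.31 284.77 1.57 1.73 4.15 1.00 1.75 13.22 -1.62 0.99"
fields = label_line.split()

obj = KITTIObject(fields)
print(obj.name, obj.h, obj.w, obj.l)  # class name and 3D dimensions in meters

# Corner candidates used when projecting the 3D box into the image;
# rot_local and soft_range (degrees) are example values.
xmin_c, xmax_c, ymin_c, ymax_c = obj.box3d_candidate(rot_local=0.5, soft_range=8)
print(xmin_c, xmax_c, ymin_c, ymax_c)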
-------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pyrootutils 2 | import pytest 3 | from hydra import compose, initialize 4 | from hydra.core.global_hydra import GlobalHydra 5 | from omegaconf import DictConfig, open_dict 6 | 7 | 8 | @pytest.fixture(scope="package") 9 | def cfg_train_global() -> DictConfig: 10 | with initialize(version_base="1.2", config_path="../configs"): 11 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) 12 | 13 | # set defaults for all tests 14 | with open_dict(cfg): 15 | cfg.paths.root_dir = str(pyrootutils.find_root()) 16 | cfg.trainer.max_epochs = 1 17 | cfg.trainer.limit_train_batches = 0.01 18 | cfg.trainer.limit_val_batches = 0.1 19 | cfg.trainer.limit_test_batches = 0.1 20 | cfg.trainer.accelerator = "cpu" 21 | cfg.trainer.devices = 1 22 | cfg.datamodule.num_workers = 0 23 | cfg.datamodule.pin_memory = False 24 | cfg.extras.print_config = False 25 | cfg.extras.enforce_tags = False 26 | cfg.logger = None 27 | 28 | return cfg 29 | 30 | 31 | @pytest.fixture(scope="package") 32 | def cfg_eval_global() -> DictConfig: 33 | with initialize(version_base="1.2", config_path="../configs"): 34 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) 35 | 36 | # set defaults for all tests 37 | with open_dict(cfg): 38 | cfg.paths.root_dir = str(pyrootutils.find_root()) 39 | cfg.trainer.max_epochs = 1 40 | cfg.trainer.limit_test_batches = 0.1 41 | cfg.trainer.accelerator = "cpu" 42 | cfg.trainer.devices = 1 43 | cfg.datamodule.num_workers = 0 44 | cfg.datamodule.pin_memory = False 45 | cfg.extras.print_config = False 46 | cfg.extras.enforce_tags = False 47 | cfg.logger = None 48 | 49 | return cfg 50 | 51 | 52 | # this is called by each test which uses `cfg_train` arg 53 | # each test generates its own temporary logging path 54 | @pytest.fixture(scope="function") 55 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig: 56 | cfg = cfg_train_global.copy() 57 | 58 | with open_dict(cfg): 59 | cfg.paths.output_dir = str(tmp_path) 60 | cfg.paths.log_dir = str(tmp_path) 61 | 62 | yield cfg 63 | 64 | GlobalHydra.instance().clear() 65 | 66 | 67 | # this is called by each test which uses `cfg_eval` arg 68 | # each test generates its own temporary logging path 69 | @pytest.fixture(scope="function") 70 | def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig: 71 | cfg = cfg_eval_global.copy() 72 | 73 | with open_dict(cfg): 74 | cfg.paths.output_dir = str(tmp_path) 75 | cfg.paths.log_dir = str(tmp_path) 76 | 77 | yield cfg 78 | 79 | GlobalHydra.instance().clear() 80 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tests/helpers/__init__.py -------------------------------------------------------------------------------- /tests/helpers/package_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | 3 | import pkg_resources 4 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils 5 | 6 | 7 | def _package_available(package_name: str) -> bool: 8 | """Check if a package is available in your environment.""" 9 | try: 10 | return pkg_resources.require(package_name) 
is not None 11 | except pkg_resources.DistributionNotFound: 12 | return False 13 | 14 | 15 | _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() 16 | 17 | _IS_WINDOWS = platform.system() == "Windows" 18 | 19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") 20 | 21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") 22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") 23 | 24 | _WANDB_AVAILABLE = _package_available("wandb") 25 | _NEPTUNE_AVAILABLE = _package_available("neptune") 26 | _COMET_AVAILABLE = _package_available("comet_ml") 27 | _MLFLOW_AVAILABLE = _package_available("mlflow") 28 | -------------------------------------------------------------------------------- /tests/helpers/run_if.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | 3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 4 | """ 5 | 6 | import sys 7 | from typing import Optional 8 | 9 | import pytest 10 | import torch 11 | from packaging.version import Version 12 | from pkg_resources import get_distribution 13 | 14 | from tests.helpers.package_available import ( 15 | _COMET_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _MLFLOW_AVAILABLE, 20 | _NEPTUNE_AVAILABLE, 21 | _SH_AVAILABLE, 22 | _TPU_AVAILABLE, 23 | _WANDB_AVAILABLE, 24 | ) 25 | 26 | 27 | class RunIf: 28 | """RunIf wrapper for conditional skipping of tests. 29 | 30 | Fully compatible with `@pytest.mark`. 31 | 32 | Example: 33 | 34 | @RunIf(min_torch="1.8") 35 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 36 | def test_wrapper(arg1): 37 | assert arg1 > 0 38 | """ 39 | 40 | def __new__( 41 | self, 42 | min_gpus: int = 0, 43 | min_torch: Optional[str] = None, 44 | max_torch: Optional[str] = None, 45 | min_python: Optional[str] = None, 46 | skip_windows: bool = False, 47 | sh: bool = False, 48 | tpu: bool = False, 49 | fairscale: bool = False, 50 | deepspeed: bool = False, 51 | wandb: bool = False, 52 | neptune: bool = False, 53 | comet: bool = False, 54 | mlflow: bool = False, 55 | **kwargs, 56 | ): 57 | """ 58 | Args: 59 | min_gpus: min number of GPUs required to run test 60 | min_torch: minimum pytorch version to run test 61 | max_torch: maximum pytorch version to run test 62 | min_python: minimum python version required to run test 63 | skip_windows: skip test for Windows platform 64 | tpu: if TPU is available 65 | sh: if `sh` module is required to run the test 66 | fairscale: if `fairscale` module is required to run the test 67 | deepspeed: if `deepspeed` module is required to run the test 68 | wandb: if `wandb` module is required to run the test 69 | neptune: if `neptune` module is required to run the test 70 | comet: if `comet` module is required to run the test 71 | mlflow: if `mlflow` module is required to run the test 72 | kwargs: native pytest.mark.skipif keyword arguments 73 | """ 74 | conditions = [] 75 | reasons = [] 76 | 77 | if min_gpus: 78 | conditions.append(torch.cuda.device_count() < min_gpus) 79 | reasons.append(f"GPUs>={min_gpus}") 80 | 81 | if min_torch: 82 | torch_version = get_distribution("torch").version 83 | conditions.append(Version(torch_version) < Version(min_torch)) 84 | reasons.append(f"torch>={min_torch}") 85 | 86 | if max_torch: 87 | torch_version = get_distribution("torch").version 88 | conditions.append(Version(torch_version) >= Version(max_torch)) 89 | reasons.append(f"torch<{max_torch}") 90 | 91 | if min_python: 92 | 
py_version = ( 93 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 94 | ) 95 | conditions.append(Version(py_version) < Version(min_python)) 96 | reasons.append(f"python>={min_python}") 97 | 98 | if skip_windows: 99 | conditions.append(_IS_WINDOWS) 100 | reasons.append("does not run on Windows") 101 | 102 | if tpu: 103 | conditions.append(not _TPU_AVAILABLE) 104 | reasons.append("TPU") 105 | 106 | if sh: 107 | conditions.append(not _SH_AVAILABLE) 108 | reasons.append("sh") 109 | 110 | if fairscale: 111 | conditions.append(not _FAIRSCALE_AVAILABLE) 112 | reasons.append("fairscale") 113 | 114 | if deepspeed: 115 | conditions.append(not _DEEPSPEED_AVAILABLE) 116 | reasons.append("deepspeed") 117 | 118 | if wandb: 119 | conditions.append(not _WANDB_AVAILABLE) 120 | reasons.append("wandb") 121 | 122 | if neptune: 123 | conditions.append(not _NEPTUNE_AVAILABLE) 124 | reasons.append("neptune") 125 | 126 | if comet: 127 | conditions.append(not _COMET_AVAILABLE) 128 | reasons.append("comet") 129 | 130 | if mlflow: 131 | conditions.append(not _MLFLOW_AVAILABLE) 132 | reasons.append("mlflow") 133 | 134 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 135 | return pytest.mark.skipif( 136 | condition=any(conditions), 137 | reason=f"Requires: [{' + '.join(reasons)}]", 138 | **kwargs, 139 | ) 140 | -------------------------------------------------------------------------------- /tests/helpers/run_sh_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from tests.helpers.package_available import _SH_AVAILABLE 6 | 7 | if _SH_AVAILABLE: 8 | import sh 9 | 10 | 11 | def run_sh_command(command: List[str]): 12 | """Default method for executing shell commands with pytest and sh package.""" 13 | msg = None 14 | try: 15 | sh.python(command) 16 | except sh.ErrorReturnCode as e: 17 | msg = e.stderr.decode() 18 | if msg: 19 | pytest.fail(msg=msg) 20 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from hydra.core.hydra_config import HydraConfig 3 | from omegaconf import DictConfig 4 | 5 | 6 | def test_train_config(cfg_train: DictConfig): 7 | assert cfg_train 8 | assert cfg_train.datamodule 9 | assert cfg_train.model 10 | assert cfg_train.trainer 11 | 12 | HydraConfig().set_config(cfg_train) 13 | 14 | hydra.utils.instantiate(cfg_train.datamodule) 15 | hydra.utils.instantiate(cfg_train.model) 16 | hydra.utils.instantiate(cfg_train.trainer) 17 | 18 | 19 | def test_eval_config(cfg_eval: DictConfig): 20 | assert cfg_eval 21 | assert cfg_eval.datamodule 22 | assert cfg_eval.model 23 | assert cfg_eval.trainer 24 | 25 | HydraConfig().set_config(cfg_eval) 26 | 27 | hydra.utils.instantiate(cfg_eval.datamodule) 28 | hydra.utils.instantiate(cfg_eval.model) 29 | hydra.utils.instantiate(cfg_eval.trainer) 30 | -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.eval import evaluate 8 | from src.train import train 9 | 10 | 11 | @pytest.mark.slow 12 | def test_train_eval(tmp_path, cfg_train, cfg_eval): 13 | """Train for 1 epoch with `train.py` and evaluate with 
`eval.py`""" 14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir 15 | 16 | with open_dict(cfg_train): 17 | cfg_train.trainer.max_epochs = 1 18 | cfg_train.test = True 19 | 20 | HydraConfig().set_config(cfg_train) 21 | train_metric_dict, _ = train(cfg_train) 22 | 23 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") 24 | 25 | with open_dict(cfg_eval): 26 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") 27 | 28 | HydraConfig().set_config(cfg_eval) 29 | test_metric_dict, _ = evaluate(cfg_eval) 30 | 31 | assert test_metric_dict["test/acc"] > 0.0 32 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 33 | -------------------------------------------------------------------------------- /tests/test_mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | 6 | from src.datamodules.mnist_datamodule import MNISTDataModule 7 | 8 | 9 | @pytest.mark.parametrize("batch_size", [32, 128]) 10 | def test_mnist_datamodule(batch_size): 11 | data_dir = "data/" 12 | 13 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) 14 | dm.prepare_data() 15 | 16 | assert not dm.data_train and not dm.data_val and not dm.data_test 17 | assert Path(data_dir, "MNIST").exists() 18 | assert Path(data_dir, "MNIST", "raw").exists() 19 | 20 | dm.setup() 21 | assert dm.data_train and dm.data_val and dm.data_test 22 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() 23 | 24 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) 25 | assert num_datapoints == 70_000 26 | 27 | batch = next(iter(dm.train_dataloader())) 28 | x, y = batch 29 | assert len(x) == batch_size 30 | assert len(y) == batch_size 31 | assert x.dtype == torch.float32 32 | assert y.dtype == torch.int64 33 | -------------------------------------------------------------------------------- /tests/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_if import RunIf 4 | from tests.helpers.run_sh_command import run_sh_command 5 | 6 | startfile = "src/train.py" 7 | overrides = ["logger=[]"] 8 | 9 | 10 | @RunIf(sh=True) 11 | @pytest.mark.slow 12 | def test_experiments(tmp_path): 13 | """Test running all available experiment configs with fast_dev_run=True.""" 14 | command = [ 15 | startfile, 16 | "-m", 17 | "experiment=glob(*)", 18 | "hydra.sweep.dir=" + str(tmp_path), 19 | "++trainer.fast_dev_run=true", 20 | ] + overrides 21 | run_sh_command(command) 22 | 23 | 24 | @RunIf(sh=True) 25 | @pytest.mark.slow 26 | def test_hydra_sweep(tmp_path): 27 | """Test default hydra sweep.""" 28 | command = [ 29 | startfile, 30 | "-m", 31 | "hydra.sweep.dir=" + str(tmp_path), 32 | "model.optimizer.lr=0.005,0.01", 33 | "++trainer.fast_dev_run=true", 34 | ] + overrides 35 | 36 | run_sh_command(command) 37 | 38 | 39 | @RunIf(sh=True) 40 | @pytest.mark.slow 41 | def test_hydra_sweep_ddp_sim(tmp_path): 42 | """Test default hydra sweep with ddp sim.""" 43 | command = [ 44 | startfile, 45 | "-m", 46 | "hydra.sweep.dir=" + str(tmp_path), 47 | "trainer=ddp_sim", 48 | "trainer.max_epochs=3", 49 | "+trainer.limit_train_batches=0.01", 50 | "+trainer.limit_val_batches=0.1", 51 | "+trainer.limit_test_batches=0.1", 52 | "model.optimizer.lr=0.005,0.01,0.02", 53 | ] + overrides 54 | run_sh_command(command) 55 | 56 | 57 | @RunIf(sh=True) 
58 | @pytest.mark.slow 59 | def test_optuna_sweep(tmp_path): 60 | """Test optuna sweep.""" 61 | command = [ 62 | startfile, 63 | "-m", 64 | "hparams_search=mnist_optuna", 65 | "hydra.sweep.dir=" + str(tmp_path), 66 | "hydra.sweeper.n_trials=10", 67 | "hydra.sweeper.sampler.n_startup_trials=5", 68 | "++trainer.fast_dev_run=true", 69 | ] + overrides 70 | run_sh_command(command) 71 | 72 | 73 | @RunIf(wandb=True, sh=True) 74 | @pytest.mark.slow 75 | def test_optuna_sweep_ddp_sim_wandb(tmp_path): 76 | """Test optuna sweep with wandb and ddp sim.""" 77 | command = [ 78 | startfile, 79 | "-m", 80 | "hparams_search=mnist_optuna", 81 | "hydra.sweep.dir=" + str(tmp_path), 82 | "hydra.sweeper.n_trials=5", 83 | "trainer=ddp_sim", 84 | "trainer.max_epochs=3", 85 | "+trainer.limit_train_batches=0.01", 86 | "+trainer.limit_val_batches=0.1", 87 | "+trainer.limit_test_batches=0.1", 88 | "logger=wandb", 89 | ] 90 | run_sh_command(command) 91 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from hydra.core.hydra_config import HydraConfig 5 | from omegaconf import open_dict 6 | 7 | from src.train import train 8 | from tests.helpers.run_if import RunIf 9 | 10 | 11 | def test_train_fast_dev_run(cfg_train): 12 | """Run for 1 train, val and test step.""" 13 | HydraConfig().set_config(cfg_train) 14 | with open_dict(cfg_train): 15 | cfg_train.trainer.fast_dev_run = True 16 | cfg_train.trainer.accelerator = "cpu" 17 | train(cfg_train) 18 | 19 | 20 | @RunIf(min_gpus=1) 21 | def test_train_fast_dev_run_gpu(cfg_train): 22 | """Run for 1 train, val and test step on GPU.""" 23 | HydraConfig().set_config(cfg_train) 24 | with open_dict(cfg_train): 25 | cfg_train.trainer.fast_dev_run = True 26 | cfg_train.trainer.accelerator = "gpu" 27 | train(cfg_train) 28 | 29 | 30 | @RunIf(min_gpus=1) 31 | @pytest.mark.slow 32 | def test_train_epoch_gpu_amp(cfg_train): 33 | """Train 1 epoch on GPU with mixed-precision.""" 34 | HydraConfig().set_config(cfg_train) 35 | with open_dict(cfg_train): 36 | cfg_train.trainer.max_epochs = 1 37 | cfg_train.trainer.accelerator = "cpu" 38 | cfg_train.trainer.precision = 16 39 | train(cfg_train) 40 | 41 | 42 | @pytest.mark.slow 43 | def test_train_epoch_double_val_loop(cfg_train): 44 | """Train 1 epoch with validation loop twice per epoch.""" 45 | HydraConfig().set_config(cfg_train) 46 | with open_dict(cfg_train): 47 | cfg_train.trainer.max_epochs = 1 48 | cfg_train.trainer.val_check_interval = 0.5 49 | train(cfg_train) 50 | 51 | 52 | @pytest.mark.slow 53 | def test_train_ddp_sim(cfg_train): 54 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes.""" 55 | HydraConfig().set_config(cfg_train) 56 | with open_dict(cfg_train): 57 | cfg_train.trainer.max_epochs = 2 58 | cfg_train.trainer.accelerator = "cpu" 59 | cfg_train.trainer.devices = 2 60 | cfg_train.trainer.strategy = "ddp_spawn" 61 | train(cfg_train) 62 | 63 | 64 | @pytest.mark.slow 65 | def test_train_resume(tmp_path, cfg_train): 66 | """Run 1 epoch, finish, and resume for another epoch.""" 67 | with open_dict(cfg_train): 68 | cfg_train.trainer.max_epochs = 1 69 | 70 | HydraConfig().set_config(cfg_train) 71 | metric_dict_1, _ = train(cfg_train) 72 | 73 | files = os.listdir(tmp_path / "checkpoints") 74 | assert "last.ckpt" in files 75 | assert "epoch_000.ckpt" in files 76 | 77 | with open_dict(cfg_train): 78 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / 
"last.ckpt") 79 | cfg_train.trainer.max_epochs = 2 80 | 81 | metric_dict_2, _ = train(cfg_train) 82 | 83 | files = os.listdir(tmp_path / "checkpoints") 84 | assert "epoch_001.ckpt" in files 85 | assert "epoch_002.ckpt" not in files 86 | 87 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] 88 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"] 89 | -------------------------------------------------------------------------------- /tmp/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tmp/.gitkeep -------------------------------------------------------------------------------- /weights/get_regressor_weights.py: -------------------------------------------------------------------------------- 1 | """Get checkpoint from W&B""" 2 | 3 | import wandb 4 | 5 | run = wandb.init() 6 | artifact = run.use_artifact('3ddetection/yolo3d-regressor/experiment-ckpts:v11', type='checkpoints') 7 | artifact_dir = artifact.download() --------------------------------------------------------------------------------