├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── assets
│   └── global_calib.txt
├── configs
│   ├── augmentation
│   │   └── inference_preprocessing.yaml
│   ├── callbacks
│   │   ├── default.yaml
│   │   ├── early_stopping.yaml
│   │   ├── model_checkpoint.yaml
│   │   ├── model_summary.yaml
│   │   ├── none.yaml
│   │   ├── rich_progress_bar.yaml
│   │   └── wandb.yaml
│   ├── convert.yaml
│   ├── datamodule
│   │   └── kitti_datamodule.yaml
│   ├── debug
│   │   ├── default.yaml
│   │   ├── fdr.yaml
│   │   ├── limit.yaml
│   │   ├── overfit.yaml
│   │   └── profiler.yaml
│   ├── detector
│   │   ├── yolov5.yaml
│   │   └── yolov5_kitti.yaml
│   ├── eval.yaml
│   ├── evaluate.yaml
│   ├── experiment
│   │   └── sample.yaml
│   ├── extras
│   │   └── default.yaml
│   ├── hparams_search
│   │   ├── mnist_optuna.yaml
│   │   └── optuna.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── inference.yaml
│   ├── local
│   │   └── .gitkeep
│   ├── logger
│   │   ├── comet.yaml
│   │   ├── csv.yaml
│   │   ├── many_loggers.yaml
│   │   ├── mlflow.yaml
│   │   ├── neptune.yaml
│   │   ├── tensorboard.yaml
│   │   └── wandb.yaml
│   ├── model
│   │   └── regressor.yaml
│   ├── paths
│   │   └── default.yaml
│   ├── train.yaml
│   └── trainer
│       ├── cpu.yaml
│       ├── ddp.yaml
│       ├── ddp_sim.yaml
│       ├── default.yaml
│       ├── dgx.yaml
│       ├── gpu.yaml
│       ├── kaggle.yaml
│       └── mps.yaml
├── convert.py
├── data
│   └── datasplit.py
├── docs
│   ├── assets
│   │   ├── demo.gif
│   │   ├── logo.png
│   │   └── show.png
│   ├── command.md
│   ├── index.md
│   └── javascripts
│       └── mathjax.js
├── inference.py
├── kitti_object_eval
│   ├── LICENSE
│   ├── README.md
│   ├── eval.py
│   ├── evaluate.py
│   ├── kitti_common.py
│   ├── rotate_iou.py
│   └── run.sh
├── logs
│   └── .gitkeep
├── mkdocs.yml
├── notebooks
│   └── .gitkeep
├── pyproject.toml
├── requirements.txt
├── scripts
│   ├── frames_to_video.py
│   ├── generate_sets.py
│   ├── get_weights.py
│   ├── kitti_to_yolo.py
│   ├── post_weights.py
│   ├── schedule.sh
│   ├── video_to_frame.py
│   └── video_to_gif.py
├── setup.py
├── src
│   ├── __init__.py
│   ├── datamodules
│   │   ├── __init__.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   └── kitti_dataset.py
│   │   └── kitti_datamodule.py
│   ├── eval.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   └── regressor.py
│   ├── train.py
│   └── utils
│       ├── Calib.py
│       ├── Math.py
│       ├── Plotting.py
│       ├── __init__.py
│       ├── averages.py
│       ├── class_averages-L4.txt
│       ├── class_averages-kitti6.txt
│       ├── class_averages.txt
│       ├── eval.py
│       ├── kitti_common.py
│       ├── pylogger.py
│       ├── rich_utils.py
│       ├── rotate_iou.py
│       └── utils.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── helpers
│   │   ├── __init__.py
│   │   ├── package_available.py
│   │   ├── run_if.py
│   │   └── run_sh_command.py
│   ├── test_configs.py
│   ├── test_eval.py
│   ├── test_mnist_datamodule.py
│   ├── test_sweeps.py
│   └── test_train.py
├── tmp
│   └── .gitkeep
└── weights
    └── get_regressor_weights.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Virtualenv
2 | /.venv/
3 | /venv/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | /bin/
14 | /build/
15 | /develop-eggs/
16 | /dist/
17 | /eggs/
18 | /lib/
19 | /lib64/
20 | /output/
21 | /parts/
22 | /sdist/
23 | /var/
24 | /*.egg-info/
25 | /.installed.cfg
26 | /*.egg
27 | /.eggs
28 |
29 | # AUTHORS and ChangeLog will be generated while packaging
30 | /AUTHORS
31 | /ChangeLog
32 |
33 | # BCloud / BuildSubmitter
34 | /build_submitter.*
35 | /logger_client_log
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | .tox/
43 | .coverage
44 | .cache
45 | .pytest_cache
46 | nosetests.xml
47 | coverage.xml
48 |
49 | # Translations
50 | *.mo
51 |
52 | # Sphinx documentation
53 | /docs/_build/
54 |
55 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright (c) 2021-2022 Megvii Inc. All rights reserved.
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | help: ## Show help
3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
4 |
5 | clean: ## Clean autogenerated files
6 | rm -rf dist
7 | find . -type f -name "*.DS_Store" -ls -delete
8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
9 | find . | grep -E ".pytest_cache" | xargs rm -rf
10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
11 | rm -f .coverage
12 |
13 | clean-logs: ## Clean logs
14 | rm -rf logs/**
15 |
16 | format: ## Run pre-commit hooks
17 | pre-commit run -a
18 |
19 | sync: ## Merge changes from main branch to your current branch
20 | git pull
21 | git pull origin main
22 |
23 | test: ## Run not slow tests
24 | pytest -k "not slow"
25 |
26 | test-full: ## Run all tests
27 | pytest
28 |
29 | train: ## Train the model
30 | python src/train.py
31 |
32 | debug: ## Enter debugging mode with pdb
33 | #
34 | # tips:
35 | # - use "import pdb; pdb.set_trace()" to set breakpoint
36 | # - use "h" to print all commands
37 | # - use "n" to execute the next line
38 | # - use "c" to run until the breakpoint is hit
39 | # - use "l" to print src code around current line, "ll" for full function code
40 | # - docs: https://docs.python.org/3/library/pdb.html
41 | #
42 | python -m pdb src/train.py debug=default
43 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # YOLO3D: 3D Object Detection with YOLO
4 |
5 |
6 | ## Introduction
7 |
8 | YOLO3D is inspired by [Mousavian et al.](https://arxiv.org/abs/1612.00496) and their paper **3D Bounding Box Estimation Using Deep Learning and Geometry**. YOLO3D takes a different approach: the 2D ground-truth labels are used in place of the first-stage detector's output, and those 2D boxes are then fed to the regressor model.
9 |
10 | ## Quickstart
11 | ```bash
12 | git clone git@github.com:ApolloAuto/apollo-model-yolo3d.git
13 | ```
14 |
15 | ### create env for YOLO3D
16 | ```shell
17 | cd apollo-model-yolo3d
18 |
19 | conda create -n apollo_yolo3d python=3.8 numpy
20 | conda activate apollo_yolo3d
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ### datasets
25 | Here we use the KITTI dataset for training. You can download it from the [official website](http://www.cvlibs.net/datasets/kitti/). After that, extract the dataset to `data/KITTI`, or symlink it:
26 |
27 | ```shell
28 | ln -s /your/KITTI/path data/KITTI
29 | ```
30 |
31 | ```bash
32 | ├── data
33 | │ └── KITTI
34 | │ ├── calib
35 | │ ├── images_2
36 | │ └── labels_2
37 | ```
38 | Modify the [datasplit](data/datasplit.py) script to split the train and val data as you like.
39 |
40 | ```shell
41 | cd data
42 | python datasplit.py
43 | ```
44 |
45 | ### train
46 | Modify [train.yaml](configs/train.yaml) to configure how your model is trained.
47 |
48 | ```shell
49 | python src/train.py experiment=sample
50 | ```
51 | > log path: /logs \
52 | > model path: /weights
53 |
54 | ### convert
55 | Modify [convert.yaml](configs/convert.yaml) to convert the .ckpt checkpoint into a .pt model.
56 |
57 | ```shell
58 | python convert.py
59 | ```
60 |
61 | ### inference
62 | To reflect the model's real inference ability, we crop each image according to the ground-truth 2D boxes and use the crops as YOLO3D input. You can use the following command to plot the 3D results.
63 |
64 | Modify [inference.yaml](configs/inference.yaml) to change the .pt model path.
65 | Setting **export_onnx=True** exports an ONNX model.
66 |
67 | ```shell
68 | python inference.py \
69 | source_dir=./data/KITTI \
70 | detector.classes=6 \
71 | regressor_weights=./weights/pytorch-kitti.pt \
72 | export_onnx=False \
73 | func=image
74 | ```
75 |
76 | - source_dir: path to the dataset, containing the /image_2 and /label_2 folders
77 | - detector.classes: number of KITTI classes
78 | - regressor_weights: path to your regressor model
79 | - export_onnx: export an ONNX model for Apollo
80 |
81 | > result path: /outputs
82 |
83 | ### evaluate
84 | Generate labels for the 3D results:
85 | ```shell
86 | python inference.py \
87 | source_dir=./data/KITTI \
88 | detector.classes=6 \
89 | regressor_weights=./weights/pytorch-kitti.pt \
90 | export_onnx=False \
91 | func=label
92 | ```
93 | > result path: /data/KITTI/result
94 |
95 | ```bash
96 | ├── data
97 | │ └── KITTI
98 | │ ├── calib
99 | │ ├── images_2
100 | │ ├── labels_2
101 | │ └── result
102 | ```
103 |
104 | Modify `label_path`, `result_path`, and `label_split_file` in run.sh inside the [kitti_object_eval](kitti_object_eval) folder; with its help we can calculate mAP:
105 | ```shell
106 | cd kitti_object_eval
107 | sh run.sh
108 | ```
109 |
110 | ## Acknowledgement
111 | - [yolo3d-lightning](https://github.com/ruhyadi/yolo3d-lightning)
112 | - [skhadem/3D-BoundingBox](https://github.com/skhadem/3D-BoundingBox)
113 | - [Mousavian et al.](https://arxiv.org/abs/1612.00496)
114 | ```
115 | @misc{mousavian20173d,
116 | title={3D Bounding Box Estimation Using Deep Learning and Geometry},
117 | author={Arsalan Mousavian and Dragomir Anguelov and John Flynn and Jana Kosecka},
118 | year={2017},
119 | eprint={1612.00496},
120 | archivePrefix={arXiv},
121 | primaryClass={cs.CV}
122 | }
123 | ```
--------------------------------------------------------------------------------
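
The README's inference step crops each image to its ground-truth 2D boxes before passing the crops to the regressor. Below is a minimal sketch of that cropping step, assuming OpenCV is available and the standard KITTI label layout (class name in column 0, 2D box as left/top/right/bottom in columns 4-7); the helper name is illustrative rather than the repo's actual API:

```python
import cv2


def crop_gt_boxes(image_path: str, label_path: str):
    """Yield (class_name, crop) pairs for every object in a KITTI label file."""
    image = cv2.imread(image_path)
    with open(label_path) as f:
        for line in f:
            fields = line.split()
            name = fields[0]
            # KITTI 2D box: left, top, right, bottom (columns 4-7)
            left, top, right, bottom = map(float, fields[4:8])
            yield name, image[int(top):int(bottom), int(left):int(right)]
```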
/assets/global_calib.txt:
--------------------------------------------------------------------------------
1 | # KITTI
2 | P_rect_02: 7.188560e+02 0.000000e+00 6.071928e+02 4.538225e+01 0.000000e+00 7.188560e+02 1.852157e+02 -1.130887e-01 0.000000e+00 0.000000e+00 1.000000e+00 3.779761e-03
3 |
4 | calib_time: 09-Jan-2012 14:00:15
5 | corner_dist: 9.950000e-02
6 | S_00: 1.392000e+03 5.120000e+02
7 | K_00: 9.799200e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.741183e+02 2.486443e+02 0.000000e+00 0.000000e+00 1.000000e+00
8 | D_00: -3.745594e-01 2.049385e-01 1.110145e-03 1.379375e-03 -7.084798e-02
9 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00
10 | T_00: -9.251859e-17 8.326673e-17 -7.401487e-17
11 | S_rect_00: 1.241000e+03 3.760000e+02
12 | R_rect_00: 9.999454e-01 7.259129e-03 -7.519551e-03 -7.292213e-03 9.999638e-01 -4.381729e-03 7.487471e-03 4.436324e-03 9.999621e-01
13 | P_rect_00: 7.188560e+02 0.000000e+00 6.071928e+02 0.000000e+00 0.000000e+00 7.188560e+02 1.852157e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
14 | S_01: 1.392000e+03 5.120000e+02
15 | K_01: 9.903522e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.855674e+02 2.607319e+02 0.000000e+00 0.000000e+00 1.000000e+00
16 | D_01: -3.712084e-01 1.978723e-01 -3.709831e-05 -3.440494e-04 -6.724045e-02
17 | R_01: 9.993440e-01 1.814887e-02 -3.134011e-02 -1.842595e-02 9.997935e-01 -8.575221e-03 3.117801e-02 9.147067e-03 9.994720e-01
18 | T_01: -5.370000e-01 5.964270e-03 -1.274584e-02
19 | S_rect_01: 1.241000e+03 3.760000e+02
20 | R_rect_01: 9.996568e-01 -1.110284e-02 2.372712e-02 1.099810e-02 9.999292e-01 4.539964e-03 -2.377585e-02 -4.277453e-03 9.997082e-01
21 | P_rect_01: 7.188560e+02 0.000000e+00 6.071928e+02 -3.861448e+02 0.000000e+00 7.188560e+02 1.852157e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
22 | S_02: 1.392000e+03 5.120000e+02
23 | K_02: 9.601149e+02 0.000000e+00 6.947923e+02 0.000000e+00 9.548911e+02 2.403547e+02 0.000000e+00 0.000000e+00 1.000000e+00
24 | D_02: -3.685917e-01 1.928022e-01 4.069233e-04 7.247536e-04 -6.276909e-02
25 | R_02: 9.999788e-01 -5.008404e-03 -4.151018e-03 4.990516e-03 9.999783e-01 -4.308488e-03 4.172506e-03 4.287682e-03 9.999821e-01
26 | T_02: 5.954406e-02 -7.675338e-04 3.582565e-03
27 | S_rect_02: 1.241000e+03 3.760000e+02
28 | R_rect_02: 9.999191e-01 1.228161e-02 -3.316013e-03 -1.228209e-02 9.999246e-01 -1.245511e-04 3.314233e-03 1.652686e-04 9.999945e-01
29 | S_03: 1.392000e+03 5.120000e+02
30 | K_03: 9.049931e+02 0.000000e+00 6.957698e+02 0.000000e+00 9.004945e+02 2.389820e+02 0.000000e+00 0.000000e+00 1.000000e+00
31 | D_03: -3.735725e-01 2.066816e-01 -6.133284e-04 -1.193269e-04 -7.600861e-02
32 | R_03: 9.995578e-01 1.656369e-02 -2.469315e-02 -1.663353e-02 9.998582e-01 -2.625576e-03 2.464616e-02 3.035149e-03 9.996916e-01
33 | T_03: -4.738786e-01 5.991982e-03 -3.215069e-03
34 | S_rect_03: 1.241000e+03 3.760000e+02
35 | R_rect_03: 9.998092e-01 -9.354781e-03 1.714961e-02 9.382303e-03 9.999548e-01 -1.525064e-03 -1.713457e-02 1.685675e-03 9.998518e-01
36 | P_rect_03: 7.188560e+02 0.000000e+00 6.071928e+02 -3.372877e+02 0.000000e+00 7.188560e+02 1.852157e+02 2.369057e+00 0.000000e+00 0.000000e+00 1.000000e+00 4.915215e-03
37 |
--------------------------------------------------------------------------------
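
`P_rect_02` above is the 3x4 projection matrix of the left color camera in the rectified KITTI frame. A minimal sketch of parsing such a line into a matrix, assuming NumPy; the function name is illustrative (the repo's own calibration handling lives in `src/utils/Calib.py`):

```python
import numpy as np


def read_projection_matrix(calib_file: str, key: str = "P_rect_02") -> np.ndarray:
    """Return the 3x4 projection matrix stored under `key` in a KITTI-style calib file."""
    with open(calib_file) as f:
        for line in f:
            if line.startswith(key + ":"):
                values = [float(v) for v in line.split(":", 1)[1].split()]
                return np.array(values).reshape(3, 4)
    raise KeyError(f"{key} not found in {calib_file}")


# Example: P = read_projection_matrix("assets/global_calib.txt")
```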
/configs/augmentation/inference_preprocessing.yaml:
--------------------------------------------------------------------------------
1 | to_tensor:
2 | _target_: torchvision.transforms.ToTensor
3 | normalize:
4 | _target_: torchvision.transforms.Normalize
5 | mean: [0.406, 0.456, 0.485]
6 | std: [0.225, 0.224, 0.229]
--------------------------------------------------------------------------------
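
Each entry above is a Hydra-instantiable torchvision transform. A minimal sketch of turning this file into a single preprocessing pipeline, assuming it is loaded standalone with OmegaConf rather than through Hydra's compose machinery:

```python
import hydra
from omegaconf import OmegaConf
from torchvision import transforms

# Assumption: the YAML is loaded directly; inference.py composes it via Hydra instead.
cfg = OmegaConf.load("configs/augmentation/inference_preprocessing.yaml")
preprocess = transforms.Compose([hydra.utils.instantiate(t) for t in cfg.values()])
# preprocess(pil_image) -> normalized tensor ready for the regressor
```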
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model_checkpoint.yaml
3 | - early_stopping.yaml
4 | - model_summary.yaml
5 | - rich_progress_bar.yaml
6 | - _self_
7 |
8 | # model save config
9 | model_checkpoint:
10 | dirpath: "weights"
11 | filename: "epoch_{epoch:03d}"
12 | monitor: "val/loss"
13 | mode: "min"
14 | save_last: True
15 | save_top_k: 1
16 | auto_insert_metric_name: False
17 |
18 | early_stopping:
19 | monitor: "val/loss"
20 | patience: 100
21 | mode: "min"
22 |
23 | model_summary:
24 | max_depth: -1
25 |
--------------------------------------------------------------------------------
/configs/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.EarlyStopping.html
2 |
3 | # Monitor a metric and stop training when it stops improving.
4 | # Look at the above link for more detailed information.
5 | early_stopping:
6 | _target_: pytorch_lightning.callbacks.EarlyStopping
7 | monitor: ??? # quantity to be monitored, must be specified !!!
8 | min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
9 | patience: 3 # number of checks with no improvement after which training will be stopped
10 | verbose: False # verbosity mode
11 | mode: "min" # "max" means higher metric value is better, can be also "min"
12 | strict: True # whether to crash the training if monitor is not found in the validation metrics
13 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
14 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
15 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
16 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
17 | # log_rank_zero_only: False # this keyword argument isn't available in stable version
18 |
--------------------------------------------------------------------------------
/configs/callbacks/model_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ModelCheckpoint.html
2 |
3 | # Save the model periodically by monitoring a quantity.
4 | # Look at the above link for more detailed information.
5 | model_checkpoint:
6 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
7 | dirpath: null # directory to save the model file
8 | filename: null # checkpoint filename
9 | monitor: null # name of the logged metric which determines when model is improving
10 | verbose: False # verbosity mode
11 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
12 | save_top_k: 1 # save k best models (determined by above metric)
13 | mode: "min" # "max" means higher metric value is better, can be also "min"
14 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
15 | save_weights_only: False # if True, then only the model’s weights will be saved
16 | every_n_train_steps: null # number of training steps between checkpoints
17 | train_time_interval: null # checkpoints are monitored at the specified time interval
18 | every_n_epochs: null # number of epochs between checkpoints
19 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
20 |
--------------------------------------------------------------------------------
/configs/callbacks/model_summary.yaml:
--------------------------------------------------------------------------------
1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichModelSummary.html
2 |
3 | # Generates a summary of all layers in a LightningModule with rich text formatting.
4 | # Look at the above link for more detailed information.
5 | model_summary:
6 | _target_: pytorch_lightning.callbacks.RichModelSummary
7 | max_depth: 1 # the maximum depth of layer nesting that the summary will include
8 |
--------------------------------------------------------------------------------
/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/configs/callbacks/rich_progress_bar.yaml:
--------------------------------------------------------------------------------
1 | # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.RichProgressBar.html
2 |
3 | # Create a progress bar with rich text formatting.
4 | # Look at the above link for more detailed information.
5 | rich_progress_bar:
6 | _target_: pytorch_lightning.callbacks.RichProgressBar
7 |
--------------------------------------------------------------------------------
/configs/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | watch_model:
5 | _target_: src.callbacks.wandb_callbacks.WatchModel
6 | log: "all"
7 | log_freq: 100
8 |
9 | upload_code_as_artifact:
10 | _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
11 | code_dir: ${original_work_dir}/src
12 |
13 | upload_ckpts_as_artifact:
14 | _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
15 | ckpt_dir: "checkpoints/"
16 | upload_best_only: True
17 |
18 | # log_f1_precision_recall_heatmap:
19 | # _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
20 |
21 | # log_confusion_matrix:
22 | # _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
23 |
24 | # log_image_predictions:
25 | # _target_: src.callbacks.wandb_callbacks.LogImagePredictions
26 | # num_samples: 8
--------------------------------------------------------------------------------
/configs/convert.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - _self_
6 | - model: regressor.yaml
7 |
8 | # enable color logging
9 | - override hydra/hydra_logging: colorlog
10 | - override hydra/job_logging: colorlog
11 |
12 | # pretty print config at the start of the run using Rich library
13 | print_config: True
14 |
15 | # disable python warnings if they annoy you
16 | ignore_warnings: True
17 |
18 | # root
19 | root: ${hydra:runtime.cwd}
20 |
21 | # TODO: change to your checkpoint file
22 | checkpoint_dir: ${root}/weights/last.ckpt
23 |
24 | # dump dir
25 | dump_dir: ${root}/weights
26 |
27 | # input sample shape
28 | input_sample:
29 | __target__: torch.randn
30 | size: (1, 3, 224, 224)
31 |
32 | # convert to
33 | convert_to: "pytorch" # [pytorch, onnx, tensorrt]
34 |
35 | # TODO: model name without extension
36 | name: ${dump_dir}/pytorch-kitti
37 |
38 | # convert_to: "onnx" # [pytorch, onnx, tensorrt]
39 | # name: ${dump_dir}/onnx-3d-0817-5
40 |
--------------------------------------------------------------------------------
/configs/datamodule/kitti_datamodule.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datamodules.kitti_datamodule.KITTIDataModule
2 |
3 | dataset_path: ${paths.data_dir} # data_dir is specified in config.yaml
4 | train_sets: ${paths.data_dir}/train_80.txt
5 | val_sets: ${paths.data_dir}/val_80.txt
6 | test_sets: ${paths.data_dir}/test_80.txt
7 | batch_size: 64
8 | num_worker: 32
--------------------------------------------------------------------------------
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup, runs 1 full epoch
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | # disable callbacks and loggers during debugging
10 | callbacks: null
11 | logger: null
12 |
13 | extras:
14 | ignore_warnings: False
15 | enforce_tags: False
16 |
17 | # sets level of all command line loggers to 'DEBUG'
18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
19 | hydra:
20 | job_logging:
21 | root:
22 | level: DEBUG
23 |
24 | # use this to also set hydra loggers to 'DEBUG'
25 | # verbose: True
26 |
27 | trainer:
28 | max_epochs: 1
29 | accelerator: cpu # debuggers don't like gpus
30 | devices: 1 # debuggers don't like multiprocessing
31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
32 |
33 | datamodule:
34 | num_workers: 0 # debuggers don't like multiprocessing
35 | pin_memory: False # disable gpu memory pin
36 |
--------------------------------------------------------------------------------
/configs/debug/fdr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs 1 train, 1 validation and 1 test step
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | fast_dev_run: true
10 |
--------------------------------------------------------------------------------
/configs/debug/limit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # uses only 1% of the training data and 5% of validation/test data
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 3
10 | limit_train_batches: 0.01
11 | limit_val_batches: 0.05
12 | limit_test_batches: 0.05
13 |
--------------------------------------------------------------------------------
/configs/debug/overfit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # overfits to 3 batches
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 20
10 | overfit_batches: 3
11 |
12 | # model ckpt and early stopping need to be disabled during overfitting
13 | callbacks: null
14 |
--------------------------------------------------------------------------------
/configs/debug/profiler.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs with execution time profiling
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 1
10 | profiler: "simple"
11 | # profiler: "advanced"
12 | # profiler: "pytorch"
13 |
--------------------------------------------------------------------------------
/configs/detector/yolov5.yaml:
--------------------------------------------------------------------------------
1 | _target_: inference.detector_yolov5
2 |
3 | model_path: ${root}/weights/detector_yolov5s.pt
4 | cfg_path: ${root}/yolov5/models/yolov5s.yaml
5 | classes: 5
6 | device: 'cpu'
--------------------------------------------------------------------------------
/configs/detector/yolov5_kitti.yaml:
--------------------------------------------------------------------------------
1 | # KITTI to YOLO
2 |
3 | path: ../data/KITTI/ # dataset root dir
4 | train: train_yolo.txt # train images (relative to 'path') 3712 images
5 | val: val_yolo.txt # val images (relative to 'path') 3768 images
6 |
7 | # Classes
8 | nc: 5 # number of classes
9 | names: ['car', 'van', 'truck', 'pedestrian', 'cyclist']
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - datamodule: mnist.yaml # choose datamodule with `test_dataloader()` for evaluation
6 | - model: mnist.yaml
7 | - logger: null
8 | - trainer: default.yaml
9 | - paths: default.yaml
10 | - extras: default.yaml
11 | - hydra: default.yaml
12 |
13 | task_name: "eval"
14 |
15 | tags: ["dev"]
16 |
17 | # passing checkpoint path is necessary for evaluation
18 | ckpt_path: ???
19 |
--------------------------------------------------------------------------------
/configs/evaluate.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - _self_
6 | - detector: yolov5.yaml
7 | - model: regressor.yaml
8 | - augmentation: inference_preprocessing.yaml
9 |
10 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
11 | - debug: null
12 |
13 | # enable color logging
14 | - override hydra/hydra_logging: colorlog
15 | - override hydra/job_logging: colorlog
16 |
17 | # run name
18 | name: evaluate
19 |
20 | # directory
21 | root: ${hydra:runtime.cwd}
22 |
23 | # predictions/output directory
24 | # pred_dir: ${root}/${hydra:run.dir}/${name}
25 |
26 | # calib_file
27 | calib_file: ${root}/assets/global_calib.txt
28 |
29 | # regressor weights
30 | regressor_weights: ${root}/weights/regressor_resnet18.pt
31 |
32 | # validation images directory
33 | val_images_path: ${root}/data/KITTI/images_2
34 |
35 | # validation sets directory
36 | val_sets: ${root}/data/KITTI/ImageSets/val.txt
37 |
38 | # classes to evaluate
39 | classes: 6
40 |
41 | # class_to_name = {
42 | # 0: 'Car',
43 | # 1: 'Cyclist',
44 | # 2: 'Truck',
45 | # 3: 'Van',
46 | # 4: 'Pedestrian',
47 | # 5: 'Tram',
48 | # }
49 |
50 | # gt label path
51 | gt_dir: ${root}/data/KITTI/label_2
52 |
53 | # dt label path
54 | pred_dir: ${root}/data/KITTI/result
55 |
56 | # device to inference
57 | device: 'cuda:0'
--------------------------------------------------------------------------------
/configs/experiment/sample.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py experiment=example
5 |
6 | defaults:
7 | - override /datamodule: kitti_datamodule.yaml
8 | - override /model: regressor.yaml
9 | - override /callbacks: default.yaml
10 | - override /logger: wandb.yaml
11 | - override /trainer: dgx.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | seed: 42069
17 |
18 | # name of the run determines folder name in logs
19 | name: "new_network"
20 |
21 | datamodule:
22 | train_sets: ${paths.data_dir}/ImageSets/train.txt
23 | val_sets: ${paths.data_dir}/ImageSets/val.txt
24 | test_sets: ${paths.data_dir}/ImageSets/test.txt
25 |
26 | trainer:
27 | min_epochs: 1
28 | max_epochs: 200
29 | # limit_train_batches: 1.0
30 | # limit_val_batches: 1.0
31 | gpus: [0]
32 | strategy: ddp
--------------------------------------------------------------------------------
/configs/extras/default.yaml:
--------------------------------------------------------------------------------
1 | # disable python warnings if they annoy you
2 | ignore_warnings: False
3 |
4 | # ask user for tags if none are provided in the config
5 | enforce_tags: True
6 |
7 | # pretty print config tree at the start of the run using Rich library
8 | print_config: True
9 |
--------------------------------------------------------------------------------
/configs/hparams_search/mnist_optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python train.py -m hparams_search=mnist_optuna experiment=example
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/acc_best"
12 |
13 | # here we define Optuna hyperparameter search
14 | # it optimizes for value returned from function with @hydra.main decorator
15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16 | hydra:
17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
18 |
19 | sweeper:
20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21 |
22 | # storage URL to persist optimization results
23 | # for example, you can use SQLite if you set 'sqlite:///example.db'
24 | storage: null
25 |
26 | # name of the study to persist optimization results
27 | study_name: null
28 |
29 | # number of parallel workers
30 | n_jobs: 1
31 |
32 | # 'minimize' or 'maximize' the objective
33 | direction: maximize
34 |
35 | # total number of runs that will be executed
36 | n_trials: 20
37 |
38 | # choose Optuna hyperparameter sampler
39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
41 | sampler:
42 | _target_: optuna.samplers.TPESampler
43 | seed: 1234
44 | n_startup_trials: 10 # number of random sampling runs before optimization starts
45 |
46 | # define hyperparameter search space
47 | params:
48 | model.optimizer.lr: interval(0.0001, 0.1)
49 | datamodule.batch_size: choice(32, 64, 128, 256)
50 | model.net.lin1_size: choice(64, 128, 256)
51 | model.net.lin2_size: choice(64, 128, 256)
52 | model.net.lin3_size: choice(32, 64, 128, 256)
53 |
--------------------------------------------------------------------------------
/configs/hparams_search/optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python train.py -m hparams_search=mnist_optuna experiment=example
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/loss"
12 |
13 | # here we define Optuna hyperparameter search
14 | # it optimizes for value returned from function with @hydra.main decorator
15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16 | hydra:
17 | mode: "MULTIRUN"
18 |
19 | sweeper:
20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21 |
22 | # storage URL to persist optimization results
23 | # for example, you can use SQLite if you set 'sqlite:///example.db'
24 | storage: null
25 |
26 | # name of the study to persist optimization results
27 | study_name: null
28 |
29 | # number of parallel workers
30 | n_jobs: 2
31 |
32 | # 'minimize' or 'maximize' the objective
33 | direction: 'minimize'
34 |
35 | # total number of runs that will be executed
36 | n_trials: 10
37 |
38 | # choose Optuna hyperparameter sampler
39 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
40 | sampler:
41 | _target_: optuna.samplers.TPESampler
42 | seed: 42069
43 | n_startup_trials: 10 # number of random sampling runs before optimization starts
44 |
45 | # define range of hyperparameters
46 | params:
47 | model.lr: interval(0.0001, 0.001)
48 | datamodule.batch_size: choice(32, 64, 128)
49 | model.optimizer: choice(adam, sgd)
--------------------------------------------------------------------------------
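
For reference, the Hydra Optuna sweeper maps `interval(...)` to a float suggestion and `choice(...)` to a categorical one. A rough plain-Optuna equivalent of the search space above (the objective body is illustrative; in this repo the sweeper drives `src/train.py` instead):

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Rough equivalent of the `params` block above.
    lr = trial.suggest_float("model.lr", 0.0001, 0.001)
    batch_size = trial.suggest_categorical("datamodule.batch_size", [32, 64, 128])
    optimizer = trial.suggest_categorical("model.optimizer", ["adam", "sgd"])
    # ...train with these overrides and return the val/loss to minimize...
    return 0.0  # placeholder


study = optuna.create_study(
    direction="minimize", sampler=optuna.samplers.TPESampler(seed=42069)
)
# study.optimize(objective, n_trials=10)
```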
/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11 | sweep:
12 | dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13 | subdir: ${hydra.job.num}
14 |
--------------------------------------------------------------------------------
/configs/inference.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - _self_
6 | - detector: yolov5.yaml
7 | - model: regressor.yaml
8 | - augmentation: inference_preprocessing.yaml
9 |
10 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
11 | - debug: null
12 |
13 | # enable color logging
14 | - override hydra/hydra_logging: colorlog
15 | - override hydra/job_logging: colorlog
16 |
17 | # run name
18 | name: inference
19 |
20 | # directory
21 | root: ${hydra:runtime.cwd}
22 | output_dir: ${root}/${hydra:run.dir}/inference
23 |
24 | # calib_file
25 | calib_file: ${root}/assets/global_calib.txt
26 |
27 | # save 2D bounding box
28 | save_det2d: False
29 |
30 | # show and save result
31 | save_result: True
32 |
33 | # save result in txt
34 | # save_txt: True
35 |
36 | # regressor weights
37 | regressor_weights: ${root}/weights/regressor_resnet18.pt
38 | # regressor_weights: ${root}/weights/mobilenetv3-best.pt
39 |
40 | # inference type
41 | inference_type: pytorch # [pytorch, onnx, openvino, tensorrt]
42 |
43 | # source directory
44 | # source_dir: ${root}/tmp/kitti/
45 | source_dir: ${root}/tmp/video_001
46 |
47 | # device to inference
48 | device: 'cpu'
49 |
50 | export_onnx: False
51 |
52 | func: "label" # image/label
53 |
--------------------------------------------------------------------------------
/configs/local/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/configs/local/.gitkeep
--------------------------------------------------------------------------------
/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 | _target_: pytorch_lightning.loggers.comet.CometLogger
5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
6 | save_dir: "${paths.output_dir}"
7 | project_name: "lightning-hydra-template"
8 | rest_api_key: null
9 | # experiment_name: ""
10 | experiment_key: null # set to resume experiment
11 | offline: False
12 | prefix: ""
13 |
--------------------------------------------------------------------------------
/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # csv logger built in lightning
2 |
3 | csv:
4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 | save_dir: "${paths.output_dir}"
6 | name: "csv/"
7 | prefix: ""
8 |
--------------------------------------------------------------------------------
/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | # - comet.yaml
5 | - csv.yaml
6 | # - mlflow.yaml
7 | # - neptune.yaml
8 | - tensorboard.yaml
9 | - wandb.yaml
10 |
--------------------------------------------------------------------------------
/configs/logger/mlflow.yaml:
--------------------------------------------------------------------------------
1 | # https://mlflow.org
2 |
3 | mlflow:
4 | _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
5 | # experiment_name: ""
6 | # run_name: ""
7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI
8 | tags: null
9 | # save_dir: "./mlruns"
10 | prefix: ""
11 | artifact_location: null
12 | # run_id: ""
13 |
--------------------------------------------------------------------------------
/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
6 | project: username/lightning-hydra-template
7 | # name: ""
8 | log_model_checkpoints: True
9 | prefix: ""
10 |
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # https://www.tensorflow.org/tensorboard/
2 |
3 | tensorboard:
4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "${paths.output_dir}/tensorboard/"
6 | name: null
7 | log_graph: False
8 | default_hp_metric: True
9 | prefix: ""
10 | # version: ""
11 |
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
5 | # name: "" # name of the run (normally generated by wandb)
6 | save_dir: "${paths.output_dir}"
7 | offline: False
8 | id: null # pass correct id to resume experiment!
9 | anonymous: null # enable anonymous logging
10 | project: "yolo3d-regressor"
11 | log_model: True # upload lightning ckpts
12 | prefix: "" # a string to put at the beginning of metric keys
13 | # entity: "" # set to name of your wandb team
14 | group: ""
15 | tags: []
16 | job_type: ""
17 |
--------------------------------------------------------------------------------
/configs/model/regressor.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.regressor.RegressorModel
2 |
3 | net:
4 | _target_: src.models.components.base.RegressorNet
5 | backbone:
6 |     _target_: torchvision.models.resnet18 # change the backbone model here
7 | pretrained: True
8 | bins: 2
9 |
10 | optimizer: adam
11 |
12 | lr: 0.0001
13 | momentum: 0.9
14 | w: 0.8
15 | alpha: 0.2
--------------------------------------------------------------------------------
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to root directory
2 | # this requires PROJECT_ROOT environment variable to exist
3 | # PROJECT_ROOT is inferred and set by pyrootutils package in `train.py` and `eval.py`
4 | root_dir: ${oc.env:PROJECT_ROOT}
5 |
6 | # path to data directory
7 | data_dir: ${paths.root_dir}/data/KITTI
8 |
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs/
11 |
12 | # path to output directory, created dynamically by hydra
13 | # path generation pattern is specified in `configs/hydra/default.yaml`
14 | # use it to store all files generated during the run, like ckpts and metrics
15 | output_dir: ${hydra:runtime.output_dir}
16 |
17 | # path to working directory
18 | work_dir: ${hydra:runtime.cwd}
19 |
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default configuration
4 | # order of defaults determines the order in which configs override each other
5 | defaults:
6 | - _self_
7 | - datamodule: kitti_datamodule.yaml
8 | - model: regressor.yaml
9 | - callbacks: default.yaml
10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11 | - trainer: dgx.yaml
12 | - paths: default.yaml
13 | - extras: default.yaml
14 | - hydra: default.yaml
15 |
16 | # experiment configs allow for version control of specific hyperparameters
17 | # e.g. best hyperparameters for given model and datamodule
18 | - experiment: null
19 |
20 | # config for hyperparameter optimization
21 | - hparams_search: null
22 |
23 | # optional local config for machine/user specific settings
24 | # it's optional since it doesn't need to exist and is excluded from version control
25 | - optional local: default.yaml
26 |
27 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
28 | - debug: null
29 |
30 | # task name, determines output directory path
31 | task_name: "train"
32 |
33 | # tags to help you identify your experiments
34 | # you can overwrite this in experiment configs
35 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
36 | # appending lists from command line is currently not supported :(
37 | # https://github.com/facebookresearch/hydra/issues/1547
38 | tags: ["dev"]
39 |
40 | # set False to skip model training
41 | train: True
42 |
43 | # evaluate on test set, using best model weights achieved during training
44 | # lightning chooses best weights based on the metric specified in checkpoint callback
45 | test: False
46 |
47 | # simply provide checkpoint path to resume training
48 | # ckpt_path: weights/last.ckpt
49 | ckpt_path: null
50 |
51 | # seed for random number generators in pytorch, numpy and python.random
52 | seed: null
53 |
--------------------------------------------------------------------------------
/configs/trainer/cpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: cpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # use "ddp_spawn" instead of "ddp",
5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra
6 | # https://github.com/facebookresearch/hydra/issues/2070
7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
8 | strategy: ddp_spawn
9 |
10 | accelerator: gpu
11 | devices: 4
12 | num_nodes: 1
13 | sync_batchnorm: True
14 |
--------------------------------------------------------------------------------
/configs/trainer/ddp_sim.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # simulate DDP on CPU, useful for debugging
5 | accelerator: cpu
6 | devices: 2
7 | strategy: ddp_spawn
8 |
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | default_root_dir: ${paths.output_dir}
4 |
5 | min_epochs: 1 # prevents early stopping
6 | max_epochs: 25
7 |
8 | accelerator: cpu
9 | devices: 1
10 |
11 | # mixed precision for extra speed-up
12 | # precision: 16
13 |
14 | # set True to ensure deterministic results
15 | # makes training slower but gives more reproducibility than just setting seeds
16 | deterministic: False
17 |
--------------------------------------------------------------------------------
/configs/trainer/dgx.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # strategy: ddp
5 |
6 | accelerator: gpu
7 | devices: [0]
8 | num_nodes: 1
9 | sync_batchnorm: True
--------------------------------------------------------------------------------
/configs/trainer/gpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: gpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/kaggle.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | gpus: 0
4 |
5 | min_epochs: 1
6 | max_epochs: 10
7 |
8 | # number of validation steps to execute at the beginning of the training
9 | # num_sanity_val_steps: 0
10 |
11 | # ckpt path
12 | resume_from_checkpoint: null
13 |
14 | # disable progress_bar
15 | enable_progress_bar: False
--------------------------------------------------------------------------------
/configs/trainer/mps.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: mps
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | """ Conver checkpoint to model (.pt/.pth/.onnx) """
2 |
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | from pytorch_lightning import LightningModule
6 | from src import utils
7 |
8 | import dotenv
9 | import hydra
10 | from omegaconf import DictConfig
11 | import os
12 |
13 | # load environment variables from `.env` file if it exists
14 | # recursively searches for `.env` in all folders starting from work dir
15 | dotenv.load_dotenv(override=True)
16 | log = utils.get_pylogger(__name__)
17 |
18 | @hydra.main(config_path="configs/", config_name="convert.yaml")
19 | def convert(config: DictConfig):
20 |
21 |     # assert conversion target
22 |     assert config.get('convert_to') in ['pytorch', 'torchscript', 'onnx', 'tensorrt'], \
23 |         "Please choose one of [pytorch, torchscript, onnx, tensorrt]"
24 |
25 | # Init lightning model
26 | log.info(f"Instantiating model <{config.model._target_}>")
27 | model: LightningModule = hydra.utils.instantiate(config.model)
28 | # regressor: LightningModule = hydra.utils.instantiate(config.model)
29 | # regressor.load_state_dict(torch.load(config.get("regressor_weights"), map_location="cpu"))
30 | # regressor.eval().to(config.get("device"))
31 |
32 | # Convert relative ckpt path to absolute path if necessary
33 | log.info(f"Load checkpoint <{config.get('checkpoint_dir')}>")
34 | ckpt_path = config.get("checkpoint_dir")
35 | if ckpt_path and not os.path.isabs(ckpt_path):
36 |         ckpt_path = os.path.join(hydra.utils.get_original_cwd(), ckpt_path)
37 |
38 | # load model checkpoint
39 | model = model.load_from_checkpoint(ckpt_path)
40 | model.cuda()
41 |
42 | # input sample
43 | input_sample = config.get('input_sample')
44 |
45 | # Convert
46 | if config.get('convert_to') == 'pytorch':
47 | log.info("Convert to Pytorch (.pt)")
48 | torch.save(model.state_dict(), f'{config.get("name")}.pt')
49 | log.info(f"Saved model {config.get('name')}.pt")
50 | if config.get('convert_to') == 'onnx':
51 | log.info("Convert to ONNX (.onnx)")
52 | model.cuda()
53 | input_sample = torch.rand((1, 3, 224, 224), device=torch.device('cuda'))
54 | model.to_onnx(f'{config.get("name")}.onnx', input_sample, export_params=True)
55 | log.info(f"Saved model {config.get('name')}.onnx")
56 |
57 | if __name__ == '__main__':
58 |
59 | convert()
--------------------------------------------------------------------------------
/data/datasplit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Baidu apollo, Inc.
3 | # All Rights Reserved
4 |
5 | import os
6 | import random
7 |
8 | # TODO: change this to your own data path
9 | pnglabelfilepath = r'./KITTI/label_2'
10 | savePath = r"./KITTI/ImageSets/"
11 |
12 | target_png = os.listdir(pnglabelfilepath)
13 | total_png = []
14 | for t in target_png:
15 | if t.endswith(".txt"):
16 | id = str(int(t.split('.')[0])).zfill(6)
17 | total_png.append(id + '.png')
18 |
19 | print("--- iter for image finished ---")
20 |
21 | # TODO: change this ratio to your own
22 | train_percent = 0.85
23 | val_percent = 0.1
24 | test_percent = 0.05
25 |
26 | num = len(total_png)
27 | # train = random.sample(num, 0.9 * num)
28 | indices = list(range(num))  # avoid shadowing the built-in `list`
29 |
30 | num_train = int(num * train_percent)
31 | num_val = int(num * val_percent)
32 |
33 |
34 | # draw the training split, then the validation split from the remaining indices
35 | train = random.sample(indices, num_train)
36 | for idx in train:
37 |     indices.remove(idx)
38 |
39 | val = random.sample(indices, num_val)
40 | for idx in val:
41 |     indices.remove(idx)
42 |
43 |
44 |
45 |
46 | def mkdir(path):
47 | folder = os.path.exists(path)
48 | if not folder:
49 | os.makedirs(path)
50 | print("--- creating new folder... ---")
51 | print("--- finished ---")
52 | else:
53 | print("--- pass to create new folder ---")
54 |
55 |
56 | mkdir(savePath)
57 |
58 | ftrain = open(os.path.join(savePath, 'train.txt'), 'w')
59 | fval = open(os.path.join(savePath, 'val.txt'), 'w')
60 | ftest = open(os.path.join(savePath, 'test.txt'), 'w')
61 |
62 | for i in train:
63 | name = total_png[i][:-4]+ '\n'
64 | ftrain.write(name)
65 |
66 |
67 | for i in val:
68 | name = total_png[i][:-4] + '\n'
69 | fval.write(name)
70 |
71 |
72 | # whatever remains after the train/val draws becomes the test split
73 | for i in indices:
74 |     name = total_png[i][:-4] + '\n'
75 |     ftest.write(name)
76 |
77 | ftrain.close()
78 | fval.close()
79 | ftest.close()
80 |
--------------------------------------------------------------------------------
/docs/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/demo.gif
--------------------------------------------------------------------------------
/docs/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/logo.png
--------------------------------------------------------------------------------
/docs/assets/show.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/docs/assets/show.png
--------------------------------------------------------------------------------
/docs/command.md:
--------------------------------------------------------------------------------
1 | # Quick Command
2 |
3 | ## Train Regressor Model
4 |
5 | - Train original
6 | ```bash
7 | python src/train.py
8 | ```
9 |
10 | - With experiment
11 | ```bash
12 | python src/train.py \
13 | experiment=sample
14 | ```
15 |
16 | ## Train Detector Model
17 | ### Yolov5
18 |
19 | - Multi GPU Training
20 | ```bash
21 | cd yolov5
22 | python -m torch.distributed.launch \
23 | --nproc_per_node 4 train.py \
24 | --epochs 10 \
25 | --batch 64 \
26 | --data ../configs/detector/yolov5_kitti.yaml \
27 | --weights yolov5s.pt \
28 | --device 0,1,2,3
29 | ```
30 |
31 | - Single GPU Training
32 | ```bash
33 | cd yolov5
34 | python train.py \
35 | --data ../configs/detector/yolov5_kitti.yaml \
36 | --weights yolov5s.pt \
37 | --img 640
38 | ```
39 |
40 | ## Hyperparameter Tuning with Hydra
41 |
42 | ```bash
43 | python src/train.py -m \
44 |     hparams_search=optuna \
45 |     experiment=sample
46 | ```
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # YOLO3D: 3D Object Detection with YOLO
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## ⚠️ Cautions
16 | > This repository is currently under development.
17 |
18 | ## 📼 Demo
19 |
20 |
21 | 
22 |
23 |
24 |
25 | ## 📌 Introduction
26 |
27 | Unofficial implementation of [Mousavian et al.](https://arxiv.org/abs/1612.00496) in their paper **3D Bounding Box Estimation Using Deep Learning and Geometry**. YOLO3D takes a different approach: the detector is **YOLOv5** (instead of the original Faster-RCNN), and the regressor is **ResNet18/VGG11** (instead of the original VGG19).
28 |
29 | ## 🚀 Quickstart
30 | > We use hydra as the config manager; if you are unfamiliar with hydra, you can visit the official website or see the tutorial on this site.
31 |
32 | ### 🍿 Inference
33 | You can use pretrained weights from the [Releases](https://github.com/ruhyadi/yolo3d-lightning/releases) page; download them with the script `get_weights.py`:
34 | ```bash
35 | # download pretrained model
36 | python scripts/get_weights.py \
37 | --tag v0.1 \
38 | --dir ./weights
39 | ```
40 | Inference with `inference.py`:
41 | ```bash
42 | python inference.py \
43 | source_dir="./data/demo/images" \
44 | detector.model_path="./weights/detector_yolov5s.pt" \
45 | regressor_weights="./weights/regressor_resnet18.pt"
46 | ```
47 |
48 | ### ⚔️ Training
49 | Two models are trained here: the **detector** and the **regressor**. For now, the only supported detector is **YOLOv5**, while the regressor can use any backbone supported by **Torchvision**.
50 |
51 | #### 🧭 Training YOLOv5 Detector
52 | The first step is to convert the KITTI `label_2` annotations to YOLO format. You can use `scripts/kitti_to_yolo.py` as follows.
53 |
54 | ```bash
55 | cd yolo3d-lightning/scripts
56 | python kitti_to_yolo.py \
57 |       --dataset_path ../data/KITTI/training/ \
58 |       --classes car van truck pedestrian cyclist \
59 |       --img_width 1224 \
60 |       --img_height 370
61 | ```
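
For reference, the converted labels are written to a `labels/` folder next to `label_2/`, one `.txt` file per image. Each line follows the YOLO convention `class_id x_center y_center width height`, normalized by the image size (the numbers below are purely illustrative):

```
0 0.512 0.631 0.084 0.097
3 0.275 0.702 0.031 0.118
```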
62 |
63 | The next step is to follow the [wiki provided by ultralytics](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data). **Note:** *the readme will be updated in the future*.
64 |
65 | #### 🪀 Training Regressor
66 | Next, you can train the regressor model. The regressor can be any of the backbones available in `torchvision`, or you can build a custom one.
67 |
68 | The first step is to create the train and validation sets. You can use `scripts/generate_sets.py` (the images folder may be named `images` or `image_2`):
69 |
70 | ```bash
71 | cd yolo3d-lightning/scripts
72 | python generate_sets.py \
73 |   --images_path ../data/KITTI/training/images \
74 |   --dump_dir ../data/KITTI/training \
75 |   --postfix _80 \
76 |   --train_size 0.8
77 | ```
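
This writes `train_80.txt` and `val_80.txt` into the dump directory, each listing one KITTI frame id per line (the ids below are purely illustrative):

```
000000
000001
000004
```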
78 |
79 | In this step we will stick to the models available in `torchvision`. The easiest way is to edit the configuration in `configs/model/regressor.yaml`, as shown below:
80 |
81 | ```yaml
82 | _target_: src.models.regressor.RegressorModel
83 |
84 | net:
85 | _target_: src.models.components.base.RegressorNet
86 | backbone:
87 | _target_: torchvision.models.resnet18 # edit this
88 | pretrained: True # maybe this too
89 | bins: 2
90 |
91 | lr: 0.001
92 | momentum: 0.9
93 | w: 0.4
94 | alpha: 0.6
95 | ```
96 |
97 | The next step is to create an experiment configuration in `configs/experiment/your_exp.yaml`. If in doubt, you can refer to [`configs/experiment/sample.yaml`](./configs/experiment/sample.yaml).
98 |
99 | Once the experiment configuration has been created, you can run `train.py` as follows:
100 |
101 | ```bash
102 | cd yolo3d-lightning
103 | python src/train.py \
104 |     experiment=sample
105 | ```
106 |
107 |
108 |
109 |
110 |
111 |
112 | ## ❤️ Acknowledgement
113 |
114 | - [YOLOv5 by Ultralytics](https://github.com/ultralytics/yolov5)
115 | - [skhadem/3D-BoundingBox](https://github.com/skhadem/3D-BoundingBox)
116 | - [Mousavian et al.](https://arxiv.org/abs/1612.00496)
117 | ```
118 | @misc{mousavian20173d,
119 | title={3D Bounding Box Estimation Using Deep Learning and Geometry},
120 | author={Arsalan Mousavian and Dragomir Anguelov and John Flynn and Jana Kosecka},
121 | year={2017},
122 | eprint={1612.00496},
123 | archivePrefix={arXiv},
124 | primaryClass={cs.CV}
125 | }
126 | ```
--------------------------------------------------------------------------------
/docs/javascripts/mathjax.js:
--------------------------------------------------------------------------------
1 | window.MathJax = {
2 | tex: {
3 | inlineMath: [["\\(", "\\)"]],
4 | displayMath: [["\\[", "\\]"]],
5 | processEscapes: true,
6 | processEnvironments: true
7 | },
8 | options: {
9 | ignoreHtmlClass: ".*|",
10 | processHtmlClass: "arithmatex"
11 | }
12 | };
13 |
14 | document$.subscribe(() => {
15 |
16 |
17 | MathJax.typesetPromise()
18 | })
--------------------------------------------------------------------------------
/kitti_object_eval/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/kitti_object_eval/README.md:
--------------------------------------------------------------------------------
1 | # Note
2 |
3 | This code is from [traveller59/kitti-object-eval-python](https://github.com/traveller59/kitti-object-eval-python)
4 |
5 | # kitti-object-eval-python
6 | Fast KITTI object detection evaluation in Python (finishes in less than 10 seconds); supports 2D/BEV/3D/AOS and coco-style AP. If you use the command line interface, numba needs some time to compile the JIT functions.
7 |
8 | _WARNING_: The "coco" isn't official metrics. Only "AP(Average Precision)" is.
9 | ## Dependencies
10 | Only Python 3.6+ is supported; requires `numpy`, `skimage`, `numba`, `fire`, and `scipy`. If you have Anaconda, just install `cudatoolkit` in Anaconda. Otherwise, please refer to this [page](https://github.com/numba/numba#custom-python-environments) to set up LLVM and CUDA for numba.
11 | * Install by conda:
12 | ```
13 | conda install -c numba cudatoolkit=x.x (8.0, 9.0, 10.0, depend on your environment)
14 | ```
15 | ## Usage
16 | * commandline interface:
17 | ```
18 | python evaluate.py evaluate --label_path=/path/to/your_gt_label_folder --result_path=/path/to/your_result_folder --label_split_file=/path/to/val.txt --current_class=0 --coco=False
19 | ```
20 | * python interface:
21 | ```Python
22 | import kitti_common as kitti
23 | from eval import get_official_eval_result, get_coco_eval_result
24 | def _read_imageset_file(path):
25 | with open(path, 'r') as f:
26 | lines = f.readlines()
27 | return [int(line) for line in lines]
28 | det_path = "/path/to/your_result_folder"
29 | dt_annos = kitti.get_label_annos(det_path)
30 | gt_path = "/path/to/your_gt_label_folder"
31 | gt_split_file = "/path/to/val.txt" # from https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz
32 | val_image_ids = _read_imageset_file(gt_split_file)
33 | gt_annos = kitti.get_label_annos(gt_path, val_image_ids)
34 | print(get_official_eval_result(gt_annos, dt_annos, 0)) # 6s in my computer
35 | print(get_coco_eval_result(gt_annos, dt_annos, 0)) # 18s in my computer
36 | ```
37 |
--------------------------------------------------------------------------------
/kitti_object_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | import time
2 | import fire
3 | import kitti_common as kitti
4 | from eval import get_official_eval_result, get_coco_eval_result
5 |
6 |
7 | def _read_imageset_file(path):
8 | with open(path, 'r') as f:
9 | lines = f.readlines()
10 | return [int(line) for line in lines]
11 |
12 |
13 | def evaluate(label_path, # gt
14 | result_path, # dt
15 | label_split_file,
16 |              current_class=0, # 0: car, 1: pedestrian, 2: cyclist
17 | coco=False,
18 | score_thresh=-1):
19 | dt_annos = kitti.get_label_annos(result_path)
20 | # print("dt_annos[0] is ", dt_annos[0], " shape is ", len(dt_annos))
21 |
22 | # if score_thresh > 0:
23 | # dt_annos = kitti.filter_annos_low_score(dt_annos, score_thresh)
24 | # val_image_ids = _read_imageset_file(label_split_file)
25 |
26 | gt_annos = kitti.get_label_annos(label_path)
27 | # print("gt_annos[0] is ", gt_annos[0], " shape is ", len(gt_annos))
28 |
29 | if coco:
30 | print(get_coco_eval_result(gt_annos, dt_annos, current_class))
31 | else:
32 | print("not coco")
33 | print(get_official_eval_result(gt_annos, dt_annos, current_class))
34 |
35 |
36 | if __name__ == '__main__':
37 | fire.Fire()
38 |
--------------------------------------------------------------------------------
/kitti_object_eval/rotate_iou.py:
--------------------------------------------------------------------------------
1 | #####################
2 | # Based on https://github.com/hongzhenwang/RRPN-revise
3 | # Licensed under The MIT License
4 | # Author: yanyan, scrin@foxmail.com
5 | #####################
6 | import math
7 |
8 | import numba
9 | import numpy as np
10 | from numba import cuda
11 |
12 | @numba.jit(nopython=True)
13 | def div_up(m, n):
14 | return m // n + (m % n > 0)
15 |
16 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
17 | def trangle_area(a, b, c):
18 | return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) *
19 | (b[0] - c[0])) / 2.0
20 |
21 |
22 | @cuda.jit('(float32[:], int32)', device=True, inline=True)
23 | def area(int_pts, num_of_inter):
24 | area_val = 0.0
25 | for i in range(num_of_inter - 2):
26 | area_val += abs(
27 | trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4],
28 | int_pts[2 * i + 4:2 * i + 6]))
29 | return area_val
30 |
31 |
32 | @cuda.jit('(float32[:], int32)', device=True, inline=True)
33 | def sort_vertex_in_convex_polygon(int_pts, num_of_inter):
34 | if num_of_inter > 0:
35 | center = cuda.local.array((2, ), dtype=numba.float32)
36 | center[:] = 0.0
37 | for i in range(num_of_inter):
38 | center[0] += int_pts[2 * i]
39 | center[1] += int_pts[2 * i + 1]
40 | center[0] /= num_of_inter
41 | center[1] /= num_of_inter
42 | v = cuda.local.array((2, ), dtype=numba.float32)
43 | vs = cuda.local.array((16, ), dtype=numba.float32)
44 | for i in range(num_of_inter):
45 | v[0] = int_pts[2 * i] - center[0]
46 | v[1] = int_pts[2 * i + 1] - center[1]
47 | d = math.sqrt(v[0] * v[0] + v[1] * v[1])
48 | v[0] = v[0] / d
49 | v[1] = v[1] / d
50 | if v[1] < 0:
51 | v[0] = -2 - v[0]
52 | vs[i] = v[0]
53 | j = 0
54 | temp = 0
55 | for i in range(1, num_of_inter):
56 | if vs[i - 1] > vs[i]:
57 | temp = vs[i]
58 | tx = int_pts[2 * i]
59 | ty = int_pts[2 * i + 1]
60 | j = i
61 | while j > 0 and vs[j - 1] > temp:
62 | vs[j] = vs[j - 1]
63 | int_pts[j * 2] = int_pts[j * 2 - 2]
64 | int_pts[j * 2 + 1] = int_pts[j * 2 - 1]
65 | j -= 1
66 |
67 | vs[j] = temp
68 | int_pts[j * 2] = tx
69 | int_pts[j * 2 + 1] = ty
70 |
71 |
72 | @cuda.jit(
73 | '(float32[:], float32[:], int32, int32, float32[:])',
74 | device=True,
75 | inline=True)
76 | def line_segment_intersection(pts1, pts2, i, j, temp_pts):
77 | A = cuda.local.array((2, ), dtype=numba.float32)
78 | B = cuda.local.array((2, ), dtype=numba.float32)
79 | C = cuda.local.array((2, ), dtype=numba.float32)
80 | D = cuda.local.array((2, ), dtype=numba.float32)
81 |
82 | A[0] = pts1[2 * i]
83 | A[1] = pts1[2 * i + 1]
84 |
85 | B[0] = pts1[2 * ((i + 1) % 4)]
86 | B[1] = pts1[2 * ((i + 1) % 4) + 1]
87 |
88 | C[0] = pts2[2 * j]
89 | C[1] = pts2[2 * j + 1]
90 |
91 | D[0] = pts2[2 * ((j + 1) % 4)]
92 | D[1] = pts2[2 * ((j + 1) % 4) + 1]
93 | BA0 = B[0] - A[0]
94 | BA1 = B[1] - A[1]
95 | DA0 = D[0] - A[0]
96 | CA0 = C[0] - A[0]
97 | DA1 = D[1] - A[1]
98 | CA1 = C[1] - A[1]
99 | acd = DA1 * CA0 > CA1 * DA0
100 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
101 | if acd != bcd:
102 | abc = CA1 * BA0 > BA1 * CA0
103 | abd = DA1 * BA0 > BA1 * DA0
104 | if abc != abd:
105 | DC0 = D[0] - C[0]
106 | DC1 = D[1] - C[1]
107 | ABBA = A[0] * B[1] - B[0] * A[1]
108 | CDDC = C[0] * D[1] - D[0] * C[1]
109 | DH = BA1 * DC0 - BA0 * DC1
110 | Dx = ABBA * DC0 - BA0 * CDDC
111 | Dy = ABBA * DC1 - BA1 * CDDC
112 | temp_pts[0] = Dx / DH
113 | temp_pts[1] = Dy / DH
114 | return True
115 | return False
116 |
117 |
118 | @cuda.jit(
119 | '(float32[:], float32[:], int32, int32, float32[:])',
120 | device=True,
121 | inline=True)
122 | def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts):
123 | a = cuda.local.array((2, ), dtype=numba.float32)
124 | b = cuda.local.array((2, ), dtype=numba.float32)
125 | c = cuda.local.array((2, ), dtype=numba.float32)
126 | d = cuda.local.array((2, ), dtype=numba.float32)
127 |
128 | a[0] = pts1[2 * i]
129 | a[1] = pts1[2 * i + 1]
130 |
131 | b[0] = pts1[2 * ((i + 1) % 4)]
132 | b[1] = pts1[2 * ((i + 1) % 4) + 1]
133 |
134 | c[0] = pts2[2 * j]
135 | c[1] = pts2[2 * j + 1]
136 |
137 | d[0] = pts2[2 * ((j + 1) % 4)]
138 | d[1] = pts2[2 * ((j + 1) % 4) + 1]
139 |
140 | area_abc = trangle_area(a, b, c)
141 | area_abd = trangle_area(a, b, d)
142 |
143 | if area_abc * area_abd >= 0:
144 | return False
145 |
146 | area_cda = trangle_area(c, d, a)
147 | area_cdb = area_cda + area_abc - area_abd
148 |
149 | if area_cda * area_cdb >= 0:
150 | return False
151 | t = area_cda / (area_abd - area_abc)
152 |
153 | dx = t * (b[0] - a[0])
154 | dy = t * (b[1] - a[1])
155 | temp_pts[0] = a[0] + dx
156 | temp_pts[1] = a[1] + dy
157 | return True
158 |
159 |
160 | @cuda.jit('(float32, float32, float32[:])', device=True, inline=True)
161 | def point_in_quadrilateral(pt_x, pt_y, corners):
162 | ab0 = corners[2] - corners[0]
163 | ab1 = corners[3] - corners[1]
164 |
165 | ad0 = corners[6] - corners[0]
166 | ad1 = corners[7] - corners[1]
167 |
168 | ap0 = pt_x - corners[0]
169 | ap1 = pt_y - corners[1]
170 |
171 | abab = ab0 * ab0 + ab1 * ab1
172 | abap = ab0 * ap0 + ab1 * ap1
173 | adad = ad0 * ad0 + ad1 * ad1
174 | adap = ad0 * ap0 + ad1 * ap1
175 |
176 | return abab >= abap and abap >= 0 and adad >= adap and adap >= 0
177 |
178 |
179 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
180 | def quadrilateral_intersection(pts1, pts2, int_pts):
181 | num_of_inter = 0
182 | for i in range(4):
183 | if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2):
184 | int_pts[num_of_inter * 2] = pts1[2 * i]
185 | int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1]
186 | num_of_inter += 1
187 | if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1):
188 | int_pts[num_of_inter * 2] = pts2[2 * i]
189 | int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1]
190 | num_of_inter += 1
191 | temp_pts = cuda.local.array((2, ), dtype=numba.float32)
192 | for i in range(4):
193 | for j in range(4):
194 | has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts)
195 | if has_pts:
196 | int_pts[num_of_inter * 2] = temp_pts[0]
197 | int_pts[num_of_inter * 2 + 1] = temp_pts[1]
198 | num_of_inter += 1
199 |
200 | return num_of_inter
201 |
202 |
203 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True)
204 | def rbbox_to_corners(corners, rbbox):
205 | # generate clockwise corners and rotate it clockwise
206 | angle = rbbox[4]
207 | a_cos = math.cos(angle)
208 | a_sin = math.sin(angle)
209 | center_x = rbbox[0]
210 | center_y = rbbox[1]
211 | x_d = rbbox[2]
212 | y_d = rbbox[3]
213 | corners_x = cuda.local.array((4, ), dtype=numba.float32)
214 | corners_y = cuda.local.array((4, ), dtype=numba.float32)
215 | corners_x[0] = -x_d / 2
216 | corners_x[1] = -x_d / 2
217 | corners_x[2] = x_d / 2
218 | corners_x[3] = x_d / 2
219 | corners_y[0] = -y_d / 2
220 | corners_y[1] = y_d / 2
221 | corners_y[2] = y_d / 2
222 | corners_y[3] = -y_d / 2
223 | for i in range(4):
224 | corners[2 *
225 | i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x
226 | corners[2 * i
227 | + 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y
228 |
229 |
230 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True)
231 | def inter(rbbox1, rbbox2):
232 | corners1 = cuda.local.array((8, ), dtype=numba.float32)
233 | corners2 = cuda.local.array((8, ), dtype=numba.float32)
234 | intersection_corners = cuda.local.array((16, ), dtype=numba.float32)
235 |
236 | rbbox_to_corners(corners1, rbbox1)
237 | rbbox_to_corners(corners2, rbbox2)
238 |
239 | num_intersection = quadrilateral_intersection(corners1, corners2,
240 | intersection_corners)
241 | sort_vertex_in_convex_polygon(intersection_corners, num_intersection)
242 | # print(intersection_corners.reshape([-1, 2])[:num_intersection])
243 |
244 | return area(intersection_corners, num_intersection)
245 |
246 |
247 | @cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True)
248 | def devRotateIoUEval(rbox1, rbox2, criterion=-1):
249 | area1 = rbox1[2] * rbox1[3]
250 | area2 = rbox2[2] * rbox2[3]
251 | area_inter = inter(rbox1, rbox2)
252 | if criterion == -1:
253 | return area_inter / (area1 + area2 - area_inter)
254 | elif criterion == 0:
255 | return area_inter / area1
256 | elif criterion == 1:
257 | return area_inter / area2
258 | else:
259 | return area_inter
260 |
261 | @cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False)
262 | def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1):
263 | threadsPerBlock = 8 * 8
264 | row_start = cuda.blockIdx.x
265 | col_start = cuda.blockIdx.y
266 | tx = cuda.threadIdx.x
267 | row_size = min(N - row_start * threadsPerBlock, threadsPerBlock)
268 | col_size = min(K - col_start * threadsPerBlock, threadsPerBlock)
269 | block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
270 | block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
271 |
272 | dev_query_box_idx = threadsPerBlock * col_start + tx
273 | dev_box_idx = threadsPerBlock * row_start + tx
274 | if (tx < col_size):
275 | block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0]
276 | block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1]
277 | block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2]
278 | block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3]
279 | block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4]
280 | if (tx < row_size):
281 | block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0]
282 | block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1]
283 | block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2]
284 | block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3]
285 | block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4]
286 | cuda.syncthreads()
287 | if tx < row_size:
288 | for i in range(col_size):
289 | offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i
290 | dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5],
291 | block_boxes[tx * 5:tx * 5 + 5], criterion)
292 |
293 |
294 | def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0):
295 | """rotated box iou running in gpu. 500x faster than cpu version
296 | (take 5ms in one example with numba.cuda code).
297 | convert from [this project](
298 | https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation).
299 |
300 | Args:
301 | boxes (float tensor: [N, 5]): rbboxes. format: centers, dims,
302 | angles(clockwise when positive)
303 | query_boxes (float tensor: [K, 5]): [description]
304 | device_id (int, optional): Defaults to 0. [description]
305 |
306 | Returns:
307 | [type]: [description]
308 | """
309 | box_dtype = boxes.dtype
310 | boxes = boxes.astype(np.float32)
311 | query_boxes = query_boxes.astype(np.float32)
312 | N = boxes.shape[0]
313 | K = query_boxes.shape[0]
314 | iou = np.zeros((N, K), dtype=np.float32)
315 | if N == 0 or K == 0:
316 | return iou
317 | threadsPerBlock = 8 * 8
318 | cuda.select_device(device_id)
319 | blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock))
320 |
321 | stream = cuda.stream()
322 | with stream.auto_synchronize():
323 | boxes_dev = cuda.to_device(boxes.reshape([-1]), stream)
324 | query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream)
325 | iou_dev = cuda.to_device(iou.reshape([-1]), stream)
326 | rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream](
327 | N, K, boxes_dev, query_boxes_dev, iou_dev, criterion)
328 | iou_dev.copy_to_host(iou.reshape([-1]), stream=stream)
329 | return iou.astype(boxes.dtype)
--------------------------------------------------------------------------------
/kitti_object_eval/run.sh:
--------------------------------------------------------------------------------
1 | python evaluate.py evaluate \
2 | --label_path=/home/your/path/data/KITTI/label_2 \
3 | --result_path=/home/your/path/data/KITTI/result \
4 | --label_split_file=/home/your/path/data/KITTI/ImageSets/val.txt \
5 | --current_class=0,1,2
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/logs/.gitkeep
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | # Project information
2 | site_name: YOLO3D
3 | site_url: https://ruhyadi.github.io/yolo3d-lightning
4 | site_author: Didi Ruhyadi
5 | site_description: >-
6 | YOLO3D: 3D Object Detection with YOLO
7 |
8 | # Repository
9 | repo_name: ruhyadi/yolo3d-lightning
10 | repo_url: https://github.com/ruhyadi/yolo3d-lightning
11 | edit_uri: ""
12 |
13 | # Copyright
14 | copyright: Copyright © 2020 - 2022 Didi Ruhyadi
15 |
16 | # Configuration
17 | theme:
18 | name: material
19 | language: en
20 |
21 | # Don't include MkDocs' JavaScript
22 | include_search_page: false
23 | search_index_only: true
24 |
25 | features:
26 | - content.code.annotate
27 | # - content.tabs.link
28 | # - header.autohide
29 | # - navigation.expand
30 | - navigation.indexes
31 | # - navigation.instant
32 | - navigation.sections
33 | - navigation.tabs
34 | # - navigation.tabs.sticky
35 | - navigation.top
36 | - navigation.tracking
37 | - search.highlight
38 | - search.share
39 | - search.suggest
40 | # - toc.integrate
41 | palette:
42 | - scheme: default
43 | primary: white
44 | accent: indigo
45 | toggle:
46 | icon: material/weather-night
47 | name: Vampire Mode
48 | - scheme: slate
49 | primary: indigo
50 | accent: blue
51 | toggle:
52 | icon: material/weather-sunny
53 | name: Beware of Your Eyes
54 | font:
55 | text: Noto Serif
56 | code: Noto Mono
57 | favicon: assets/logo.png
58 | logo: assets/logo.png
59 | icon:
60 | repo: fontawesome/brands/github
61 |
62 | # Plugins
63 | plugins:
64 |
65 | # Customization
66 | extra:
67 | social:
68 | - icon: fontawesome/brands/github
69 | link: https://github.com/ruhyadi
70 | - icon: fontawesome/brands/docker
71 | link: https://hub.docker.com/r/ruhyadi
72 | - icon: fontawesome/brands/twitter
73 | link: https://twitter.com/
74 | - icon: fontawesome/brands/linkedin
75 | link: https://linkedin.com/in/didiruhyadi
76 | - icon: fontawesome/brands/instagram
77 | link: https://instagram.com/didiir_
78 |
79 | extra_javascript:
80 | - javascripts/mathjax.js
81 | - https://polyfill.io/v3/polyfill.min.js?features=es6
82 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
83 |
84 | # Extensions
85 | markdown_extensions:
86 | - admonition
87 | - abbr
88 | - pymdownx.snippets
89 | - attr_list
90 | - def_list
91 | - footnotes
92 | - meta
93 | - md_in_html
94 | - toc:
95 | permalink: true
96 | - pymdownx.arithmatex:
97 | generic: true
98 | - pymdownx.betterem:
99 | smart_enable: all
100 | - pymdownx.caret
101 | - pymdownx.details
102 | - pymdownx.emoji:
103 | emoji_index: !!python/name:materialx.emoji.twemoji
104 | emoji_generator: !!python/name:materialx.emoji.to_svg
105 | - pymdownx.highlight:
106 | anchor_linenums: true
107 | - pymdownx.inlinehilite
108 | - pymdownx.keys
109 | - pymdownx.magiclink:
110 | repo_url_shorthand: true
111 | user: squidfunk
112 | repo: mkdocs-material
113 | - pymdownx.mark
114 | - pymdownx.smartsymbols
115 | - pymdownx.superfences:
116 | custom_fences:
117 | - name: mermaid
118 | class: mermaid
119 | format: !!python/name:pymdownx.superfences.fence_code_format
120 | - pymdownx.tabbed:
121 | alternate_style: true
122 | - pymdownx.tasklist:
123 | custom_checkbox: true
124 | - pymdownx.tilde
125 |
126 | # Page tree
127 | nav:
128 | - Home:
129 | - Home: index.md
--------------------------------------------------------------------------------
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/notebooks/.gitkeep
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | addopts = [
3 | "--color=yes",
4 | "--durations=0",
5 | "--strict-markers",
6 | "--doctest-modules",
7 | ]
8 | filterwarnings = [
9 | "ignore::DeprecationWarning",
10 | "ignore::UserWarning",
11 | ]
12 | log_cli = "True"
13 | markers = [
14 | "slow: slow tests",
15 | ]
16 | minversion = "6.0"
17 | testpaths = "tests/"
18 |
19 | [tool.coverage.report]
20 | exclude_lines = [
21 | "pragma: nocover",
22 | "raise NotImplementedError",
23 | "raise NotImplementedError()",
24 | "if __name__ == .__main__.:",
25 | ]
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # --------- pytorch --------- #
2 | torch>=1.8.0
3 | torchvision>=0.9.1
4 | pytorch-lightning==1.6.5
5 | torchmetrics==0.9.2
6 |
7 | # --------- hydra --------- #
8 | hydra-core==1.2.0
9 | hydra-colorlog==1.2.0
10 | hydra-optuna-sweeper==1.2.0
11 |
12 | # --------- loggers --------- #
13 | # wandb
14 | # neptune-client
15 | # mlflow
16 | # comet-ml
17 |
18 | # --------- others --------- #
19 | pyrootutils # standardizing the project root setup
20 | pre-commit # hooks for applying linters on commit
21 | rich # beautiful text formatting in terminal
22 | pytest # tests
23 | sh # for running bash commands in some tests
24 |
--------------------------------------------------------------------------------
/scripts/frames_to_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate frames to vid
3 | Usage:
4 | python scripts/frames_to_video.py \
5 | --imgs_path /path/to/imgs \
6 | --vid_path /path/to/vid \
7 | --fps 24 \
8 | --frame_size 1242 375 \
9 | --resize
10 |
11 | python scripts/frames_to_video.py \
12 | --imgs_path outputs/2023-05-13/22-51-34/inference \
13 | --vid_path tmp/output_videos/001.mp4 \
14 | --fps 3 \
15 | --frame_size 1550 387 \
16 | --resize
17 | """
18 |
19 | import argparse
20 | import cv2
21 | from glob import glob
22 | import os
23 | from tqdm import tqdm
24 |
25 | def generate(imgs_path, vid_path, fps=30, frame_size=(1242, 375), resize=True):
26 | """Generate frames to vid"""
27 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
28 | vid_writer = cv2.VideoWriter(vid_path, fourcc, fps, frame_size)
29 | imgs_glob = sorted(glob(os.path.join(imgs_path, "*.png")))
30 | if resize:
31 | for img_path in tqdm(imgs_glob):
32 | img = cv2.imread(img_path)
33 | img = cv2.resize(img, frame_size)
34 | vid_writer.write(img)
35 | else:
36 | for img_path in imgs_glob:
37 | img = cv2.imread(img_path, cv2.IMREAD_COLOR)
38 | vid_writer.write(img)
39 | vid_writer.release()
40 | print('[INFO] Video saved to {}'.format(vid_path))
41 |
42 | if __name__ == "__main__":
43 | # create argparser
44 | parser = argparse.ArgumentParser(description="Generate frames to vid")
45 | parser.add_argument("--imgs_path", type=str, default="outputs/2022-10-23/21-03-50/inference", help="path to imgs")
46 | parser.add_argument("--vid_path", type=str, default="outputs/videos/004.mp4", help="path to vid")
47 | parser.add_argument("--fps", type=int, default=24, help="fps")
48 | parser.add_argument("--frame_size", type=int, nargs=2, default=(int(1242), int(375)), help="frame size")
49 | parser.add_argument("--resize", action="store_true", help="resize")
50 | args = parser.parse_args()
51 |
52 | # generate vid
53 |     generate(args.imgs_path, args.vid_path, args.fps, tuple(args.frame_size), args.resize)
--------------------------------------------------------------------------------
/scripts/generate_sets.py:
--------------------------------------------------------------------------------
1 | """Create training and validation sets"""
2 |
3 | from glob import glob
4 | import os
5 | import argparse
6 |
7 |
8 | def generate_sets(
9 | images_path: str,
10 | dump_dir: str,
11 | postfix: str = "",
12 | train_size: float = 0.8,
13 | is_yolo: bool = False,
14 | ):
15 | images = glob(os.path.join(images_path, "*.png"))
16 | ids = [id_.split("/")[-1].split(".")[0] for id_ in images]
17 |
18 | train_sets = sorted(ids[: int(len(ids) * train_size)])
19 | val_sets = sorted(ids[int(len(ids) * train_size) :])
20 |
21 | for name, sets in zip(["train", "val"], [train_sets, val_sets]):
22 | name = os.path.join(dump_dir, f"{name}{postfix}.txt")
23 | with open(name, "w") as f:
24 | for id in sets:
25 | if is_yolo:
26 | f.write(f"./images/{id}.png\n")
27 | else:
28 | f.write(f"{id}\n")
29 |
30 | print(f"[INFO] Training set: {len(train_sets)}")
31 | print(f"[INFO] Validation set: {len(val_sets)}")
32 | print(f"[INFO] Total: {len(train_sets) + len(val_sets)}")
33 | print(f"[INFO] Success Generate Sets")
34 |
35 |
36 | if __name__ == "__main__":
37 | parser = argparse.ArgumentParser(description="Create training and validation sets")
38 | parser.add_argument("--images_path", type=str, default="./data/KITTI/images")
39 | parser.add_argument("--dump_dir", type=str, default="./data/KITTI")
40 | parser.add_argument("--postfix", type=str, default="_95")
41 | parser.add_argument("--train_size", type=float, default=0.95)
42 | parser.add_argument("--is_yolo", action="store_true")
43 | args = parser.parse_args()
44 |
45 | generate_sets(
46 | images_path=args.images_path,
47 | dump_dir=args.dump_dir,
48 | postfix=args.postfix,
49 | train_size=args.train_size,
50 |         is_yolo=args.is_yolo,
51 | )
52 |
--------------------------------------------------------------------------------
/scripts/get_weights.py:
--------------------------------------------------------------------------------
1 | """Download pretrained weights from github release"""
2 |
3 | from pprint import pprint
4 | import requests
5 | import os
6 | import shutil
7 | import argparse
8 | from zipfile import ZipFile
9 |
10 | def get_assets(tag):
11 | """Get release assets by tag name"""
12 | url = 'https://api.github.com/repos/ruhyadi/yolo3d-lightning/releases/tags/' + tag
13 | response = requests.get(url)
14 | return response.json()['assets']
15 |
16 | def download_assets(assets, dir):
17 | """Download assets to dir"""
18 | for asset in assets:
19 | url = asset['browser_download_url']
20 | filename = asset['name']
21 | print('[INFO] Downloading {}'.format(filename))
22 | response = requests.get(url, stream=True)
23 | with open(os.path.join(dir, filename), 'wb') as f:
24 | shutil.copyfileobj(response.raw, f)
25 | del response
26 |
27 | with ZipFile(os.path.join(dir, filename), 'r') as zip_file:
28 | zip_file.extractall(dir)
29 | os.remove(os.path.join(dir, filename))
30 |
31 | if __name__ == "__main__":
32 | parser = argparse.ArgumentParser(description='Download pretrained weights')
33 | parser.add_argument('--tag', type=str, default='v0.1', help='tag name')
34 | parser.add_argument('--dir', type=str, default='./', help='directory to save weights')
35 | args = parser.parse_args()
36 |
37 | assets = get_assets(args.tag)
38 | download_assets(assets, args.dir)
39 |
--------------------------------------------------------------------------------
/scripts/kitti_to_yolo.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert KITTI format to YOLO format.
3 | """
4 |
5 | import os
6 | import numpy as np
7 | from glob import glob
8 | from tqdm import tqdm
9 | import argparse
10 |
11 | from typing import Tuple
12 |
13 |
14 | class KITTI2YOLO:
15 | def __init__(
16 | self,
17 | dataset_path: str = "../data/KITTI",
18 | classes: Tuple = ["car", "van", "truck", "pedestrian", "cyclist"],
19 | img_width: int = 1224,
20 | img_height: int = 370,
21 | ):
22 |
23 | self.dataset_path = dataset_path
24 | self.img_width = img_width
25 | self.img_height = img_height
26 | self.classes = classes
27 | self.ids = {self.classes[i]: i for i in range(len(self.classes))}
28 |
29 | # create new directory
30 | self.label_path = os.path.join(self.dataset_path, "labels")
31 | if not os.path.isdir(self.label_path):
32 | os.makedirs(self.label_path)
33 | else:
34 | print("[INFO] Directory already exist...")
35 |
36 | def convert(self):
37 | files = glob(os.path.join(self.dataset_path, "label_2", "*.txt"))
38 | for file in tqdm(files):
39 | with open(file, "r") as f:
40 | filename = os.path.join(self.label_path, file.split("/")[-1])
41 | dump_txt = open(filename, "w")
42 | for line in f:
43 | parse_line = self.parse_line(line)
44 | if parse_line["name"].lower() not in self.classes:
45 | continue
46 |
47 | xmin, ymin, xmax, ymax = parse_line["bbox_camera"]
48 | xcenter = ((xmax - xmin) / 2 + xmin) / self.img_width
49 | if xcenter > 1.0:
50 | xcenter = 1.0
51 | ycenter = ((ymax - ymin) / 2 + ymin) / self.img_height
52 | if ycenter > 1.0:
53 | ycenter = 1.0
54 | width = (xmax - xmin) / self.img_width
55 | if width > 1.0:
56 | width = 1.0
57 | height = (ymax - ymin) / self.img_height
58 | if height > 1.0:
59 | height = 1.0
60 |
61 | bbox_yolo = f"{self.ids[parse_line['name'].lower()]} {xcenter:.3f} {ycenter:.3f} {width:.3f} {height:.3f}"
62 | dump_txt.write(bbox_yolo + "\n")
63 |
64 | dump_txt.close()
65 |
66 | def parse_line(self, line):
67 | parts = line.split(" ")
68 | output = {
69 | "name": parts[0].strip(),
70 | "xyz_camera": (float(parts[11]), float(parts[12]), float(parts[13])),
71 | "wlh": (float(parts[9]), float(parts[10]), float(parts[8])),
72 | "yaw_camera": float(parts[14]),
73 | "bbox_camera": (
74 | float(parts[4]),
75 | float(parts[5]),
76 | float(parts[6]),
77 | float(parts[7]),
78 | ),
79 | "truncation": float(parts[1]),
80 | "occlusion": float(parts[2]),
81 | "alpha": float(parts[3]),
82 | }
83 |
84 | # Add score if specified
85 | if len(parts) > 15:
86 | output["score"] = float(parts[15])
87 | else:
88 | output["score"] = np.nan
89 |
90 | return output
91 |
92 |
93 | if __name__ == "__main__":
94 |
95 | # argparser
96 | parser = argparse.ArgumentParser(description="KITTI to YOLO Convertion")
97 | parser.add_argument("--dataset_path", type=str, default="../data/KITTI")
98 | parser.add_argument(
99 | "--classes",
100 |         nargs="+", type=str,
101 | default=["car", "van", "truck", "pedestrian", "cyclist"],
102 | )
103 | parser.add_argument("--img_width", type=int, default=1224)
104 | parser.add_argument("--img_height", type=int, default=370)
105 | args = parser.parse_args()
106 |
107 |     kitti2yolo = KITTI2YOLO(
108 | dataset_path=args.dataset_path,
109 | classes=args.classes,
110 | img_width=args.img_width,
111 | img_height=args.img_height,
112 | )
113 |     kitti2yolo.convert()
114 |
--------------------------------------------------------------------------------
/scripts/post_weights.py:
--------------------------------------------------------------------------------
1 | """Upload weights to github release"""
2 |
3 | from pprint import pprint
4 | import requests
5 | import os
6 | import dotenv
7 | import argparse
8 | from zipfile import ZipFile
9 |
10 | dotenv.load_dotenv()
11 |
12 |
13 | def create_release(tag, name, description, target="main"):
14 | """Create release"""
15 | token = os.environ.get("GITHUB_TOKEN")
16 | headers = {
17 | "Accept": "application/vnd.github.v3+json",
18 | "Authorization": f"token {token}",
19 | "Content-Type": "application/zip"
20 | }
21 | url = "https://api.github.com/repos/ruhyadi/yolo3d-lightning/releases"
22 | payload = {
23 | "tag_name": tag,
24 | "target_commitish": target,
25 | "name": name,
26 | "body": description,
27 | "draft": True,
28 | "prerelease": False,
29 | "generate_release_notes": True,
30 | }
31 | print("[INFO] Creating release {}".format(tag))
32 | response = requests.post(url, json=payload, headers=headers)
33 | print("[INFO] Release created id: {}".format(response.json()["id"]))
34 |
35 | return response.json()
36 |
37 |
38 | def post_assets(assets, release_id):
39 | """Post assets to release"""
40 | token = os.environ.get("GITHUB_TOKEN")
41 | headers = {
42 | "Accept": "application/vnd.github.v3+json",
43 | "Authorization": f"token {token}",
44 | "Content-Type": "application/zip"
45 | }
46 | for asset in assets:
47 | asset_path = os.path.join(os.getcwd(), asset)
48 | with ZipFile(f"{asset_path}.zip", "w") as zip_file:
49 | zip_file.write(asset)
50 | asset_path = f"{asset_path}.zip"
51 | filename = asset_path.split("/")[-1]
52 | url = (
53 | "https://uploads.github.com/repos/ruhyadi/yolo3d-lightning/releases/"
54 | + str(release_id)
55 | + f"/assets?name={filename}"
56 | )
57 | print("[INFO] Uploading {}".format(filename))
58 | response = requests.post(url, files={"name": open(asset_path, "rb")}, headers=headers)
59 | pprint(response.json())
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser(description="Upload weights to github release")
64 | parser.add_argument("--tag", type=str, default="v0.6", help="tag name")
65 | parser.add_argument("--name", type=str, default="Release v0.6", help="release name")
66 | parser.add_argument("--description", type=str, default="v0.6", help="release description")
67 | parser.add_argument("--assets", type=tuple, default=["weights/mobilenetv3-best.pt", "weights/mobilenetv3-last.pt", "logs/train/runs/2022-09-28_10-36-08/checkpoints/epoch_007.ckpt", "logs/train/runs/2022-09-28_10-36-08/checkpoints/last.ckpt"], help="directory to save weights",)
68 | args = parser.parse_args()
69 |
70 | release_id = create_release(args.tag, args.name, args.description)["id"]
71 | post_assets(args.assets, release_id)
72 |
--------------------------------------------------------------------------------
/scripts/schedule.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Schedule execution of many runs
3 | # Run from root folder with: bash scripts/schedule.sh
4 |
5 | python src/train.py trainer.max_epochs=5 logger=csv
6 |
7 | python src/train.py trainer.max_epochs=10 logger=csv
8 |
--------------------------------------------------------------------------------
/scripts/video_to_frame.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert video to frame
3 | Usage:
4 | python video_to_frame.py \
5 | --video_path /path/to/video \
6 | --output_path /path/to/output/folder \
7 | --fps 24
8 |
9 | python scripts/video_to_frame.py \
10 | --video_path tmp/video/20230513_100429.mp4 \
11 | --output_path tmp/video_001 \
12 | --fps 20
13 | """
14 |
15 | import argparse
16 |
17 | import os
18 | import cv2
19 |
20 |
21 | def video_to_frame(video_path: str, output_path: str, fps: int = 5):
22 | """
23 | Convert video to frame
24 |
25 | Args:
26 | video_path: path to video
27 | output_path: path to output folder
28 | fps: how many frames per second to save
29 | """
30 | if not os.path.exists(output_path):
31 | os.makedirs(output_path)
32 |
33 | cap = cv2.VideoCapture(video_path)
34 | frame_count = 0
35 | while cap.isOpened():
36 | ret, frame = cap.read()
37 | if not ret:
38 | break
39 | if frame_count % fps == 0:
40 | cv2.imwrite(os.path.join(output_path, f"{frame_count:06d}.jpg"), frame)
41 | frame_count += 1
42 |
43 | cap.release()
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument("--video_path", type=str, required=True)
48 | parser.add_argument("--output_path", type=str, required=True)
49 | parser.add_argument("--fps", type=int, default=30)
50 | args = parser.parse_args()
51 |
52 | video_to_frame(args.video_path, args.output_path, args.fps)
--------------------------------------------------------------------------------
/scripts/video_to_gif.py:
--------------------------------------------------------------------------------
1 | """Convert video to gif with moviepy"""
2 |
3 | import argparse
4 | import moviepy.editor as mpy
5 |
6 | def generate(video_path, gif_path, fps):
7 | """Generate gif from video"""
8 | clip = mpy.VideoFileClip(video_path)
9 | clip.write_gif(gif_path, fps=fps)
10 | clip.close()
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser(description="Convert video to gif")
14 | parser.add_argument("--video_path", type=str, default="outputs/videos/004.mp4", help="Path to video")
15 | parser.add_argument("--gif_path", type=str, default="outputs/gif/002.gif", help="Path to gif")
16 | parser.add_argument("--fps", type=int, default=5, help="GIF fps")
17 | args = parser.parse_args()
18 |
19 | # generate gif
20 | generate(args.video_path, args.gif_path, args.fps)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | setup(
6 | name="src",
7 | version="0.0.1",
8 | description="Describe Your Cool Project",
9 | author="",
10 | author_email="",
11 | url="https://github.com/user/project", # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
12 | install_requires=["pytorch-lightning", "hydra-core"],
13 | packages=find_packages(),
14 | )
15 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/__init__.py
--------------------------------------------------------------------------------
/src/datamodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/datamodules/__init__.py
--------------------------------------------------------------------------------
/src/datamodules/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/datamodules/components/__init__.py
--------------------------------------------------------------------------------
/src/datamodules/kitti_datamodule.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset lightning class
3 | """
4 |
5 | from pytorch_lightning import LightningDataModule
6 | from torch.utils.data import DataLoader
7 | from torchvision.transforms import transforms
8 |
9 | from src.datamodules.components.kitti_dataset import KITTIDataset, KITTIDataset2, KITTIDataset3
10 |
11 | class KITTIDataModule(LightningDataModule):
12 | def __init__(
13 | self,
14 | dataset_path: str = './data/KITTI',
15 | train_sets: str = './data/KITTI/train.txt',
16 | val_sets: str = './data/KITTI/val.txt',
17 | test_sets: str = './data/KITTI/test.txt',
18 | batch_size: int = 32,
19 | num_worker: int = 4,
20 | ):
21 | super().__init__()
22 |
23 | # save hyperparameters
24 | self.save_hyperparameters(logger=False)
25 |
26 | # transforms
27 | # TODO: using albumentations
28 | self.dataset_transforms = transforms.Compose([
29 | transforms.ToTensor(),
30 | transforms.Normalize((0.1307,), (0.3081,))
31 | ])
32 |
33 | def setup(self, stage=None):
34 | """ Split dataset to training and validation """
35 | self.KITTI_train = KITTIDataset(self.hparams.dataset_path, self.hparams.train_sets)
36 | self.KITTI_val = KITTIDataset(self.hparams.dataset_path, self.hparams.val_sets)
37 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets)
38 |         # TODO: add test dataset and test sets
39 |
40 | def train_dataloader(self):
41 | return DataLoader(
42 | dataset=self.KITTI_train,
43 | batch_size=self.hparams.batch_size,
44 | num_workers=self.hparams.num_worker,
45 | shuffle=True
46 | )
47 |
48 | def val_dataloader(self):
49 | return DataLoader(
50 | dataset=self.KITTI_val,
51 | batch_size=self.hparams.batch_size,
52 | num_workers=self.hparams.num_worker,
53 | shuffle=False
54 | )
55 |
56 | # def test_dataloader(self):
57 | # return DataLoader(
58 | # dataset=self.KITTI_test,
59 | # batch_size=self.hparams.batch_size,
60 | # num_workers=self.hparams.num_worker,
61 | # shuffle=False
62 | # )
63 |
64 | class KITTIDataModule2(LightningDataModule):
65 | def __init__(
66 | self,
67 | dataset_path: str = './data/KITTI',
68 | train_sets: str = './data/KITTI/train.txt',
69 | val_sets: str = './data/KITTI/val.txt',
70 | test_sets: str = './data/KITTI/test.txt',
71 | batch_size: int = 32,
72 | num_worker: int = 4,
73 | ):
74 | super().__init__()
75 |
76 | # save hyperparameters
77 | self.save_hyperparameters(logger=False)
78 |
79 | def setup(self, stage=None):
80 | """ Split dataset to training and validation """
81 | self.KITTI_train = KITTIDataset2(self.hparams.dataset_path, self.hparams.train_sets)
82 | self.KITTI_val = KITTIDataset2(self.hparams.dataset_path, self.hparams.val_sets)
83 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets)
84 |         # TODO: add test dataset and test sets
85 |
86 | def train_dataloader(self):
87 | return DataLoader(
88 | dataset=self.KITTI_train,
89 | batch_size=self.hparams.batch_size,
90 | num_workers=self.hparams.num_worker,
91 | shuffle=True
92 | )
93 |
94 | def val_dataloader(self):
95 | return DataLoader(
96 | dataset=self.KITTI_val,
97 | batch_size=self.hparams.batch_size,
98 | num_workers=self.hparams.num_worker,
99 | shuffle=False
100 | )
101 |
102 | class KITTIDataModule3(LightningDataModule):
103 | def __init__(
104 | self,
105 | dataset_path: str = './data/KITTI',
106 | train_sets: str = './data/KITTI/train.txt',
107 | val_sets: str = './data/KITTI/val.txt',
108 | test_sets: str = './data/KITTI/test.txt',
109 | batch_size: int = 32,
110 | num_worker: int = 4,
111 | ):
112 | super().__init__()
113 |
114 | # save hyperparameters
115 | self.save_hyperparameters(logger=False)
116 |
117 | # transforms
118 | # TODO: using albumentations
119 | self.dataset_transforms = transforms.Compose([
120 | transforms.ToTensor(),
121 | transforms.Normalize((0.1307,), (0.3081,))
122 | ])
123 |
124 | def setup(self, stage=None):
125 | """ Split dataset to training and validation """
126 | self.KITTI_train = KITTIDataset3(self.hparams.dataset_path, self.hparams.train_sets)
127 | self.KITTI_val = KITTIDataset3(self.hparams.dataset_path, self.hparams.val_sets)
128 | # self.KITTI_test = KITTIDataset(self.hparams.dataset_path, self.hparams.test_sets)
129 |         # TODO: add test dataset and test sets
130 |
131 | def train_dataloader(self):
132 | return DataLoader(
133 | dataset=self.KITTI_train,
134 | batch_size=self.hparams.batch_size,
135 | num_workers=self.hparams.num_worker,
136 | shuffle=True
137 | )
138 |
139 | def val_dataloader(self):
140 | return DataLoader(
141 | dataset=self.KITTI_val,
142 | batch_size=self.hparams.batch_size,
143 | num_workers=self.hparams.num_worker,
144 | shuffle=False
145 | )
146 |
147 |
148 | if __name__ == '__main__':
149 |
150 | from time import time
151 |
152 | start1 = time()
153 | datamodule1 = KITTIDataModule(
154 | dataset_path='./data/KITTI',
155 | train_sets='./data/KITTI/train_95.txt',
156 | val_sets='./data/KITTI/val_95.txt',
157 | test_sets='./data/KITTI/test_95.txt',
158 | batch_size=5,
159 | )
160 | datamodule1.setup()
161 | trainloader = datamodule1.val_dataloader()
162 |
163 | for img, label in trainloader:
164 | print(label["Orientation"])
165 | break
166 |
167 | results1 = (time() - start1) * 1000
168 |
169 | start2 = time()
170 | datamodule2 = KITTIDataModule3(
171 | dataset_path='./data/KITTI',
172 | train_sets='./data/KITTI/train_95.txt',
173 | val_sets='./data/KITTI/val_95.txt',
174 | test_sets='./data/KITTI/test_95.txt',
175 | batch_size=5,
176 | )
177 | datamodule2.setup()
178 | trainloader = datamodule2.val_dataloader()
179 |
180 | for img, label in trainloader:
181 | print(label["orientation"])
182 | break
183 |
184 | results2 = (time() - start2) * 1000
185 |
186 | print(f'Time taken for datamodule1: {results1} ms')
187 | print(f'Time taken for datamodule2: {results2} ms')
188 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import pyrootutils
2 |
3 | root = pyrootutils.setup_root(
4 | search_from=__file__,
5 | indicator=[".git", "pyproject.toml"],
6 | pythonpath=True,
7 | dotenv=True,
8 | )
9 |
10 | # ------------------------------------------------------------------------------------ #
11 | # `pyrootutils.setup_root(...)` is recommended at the top of each start file
12 | # to make the environment more robust and consistent
13 | #
14 | # the line above searches for ".git" or "pyproject.toml" in present and parent dirs
15 | # to determine the project root dir
16 | #
17 | # adds root dir to the PYTHONPATH (if `pythonpath=True`)
18 | # so this file can be run from any place without installing project as a package
19 | #
20 | # sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
21 | # this makes all paths relative to the project root
22 | #
23 | # additionally loads environment variables from ".env" file (if `dotenv=True`)
24 | #
25 | # you can get away without using `pyrootutils.setup_root(...)` if you:
26 | # 1. move this file to the project root dir or install project as a package
27 | # 2. modify paths in "configs/paths/default.yaml" to not use PROJECT_ROOT
28 | # 3. always run this file from the project root dir
29 | #
30 | # https://github.com/ashleve/pyrootutils
31 | # ------------------------------------------------------------------------------------ #
32 |
33 | from typing import List, Tuple
34 |
35 | import hydra
36 | from omegaconf import DictConfig
37 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer
38 | from pytorch_lightning.loggers import LightningLoggerBase
39 |
40 | from src import utils
41 |
42 | log = utils.get_pylogger(__name__)
43 |
44 |
45 | @utils.task_wrapper
46 | def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
47 | """Evaluates given checkpoint on a datamodule testset.
48 |
49 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities
50 | before and after the call.
51 |
52 | Args:
53 | cfg (DictConfig): Configuration composed by Hydra.
54 |
55 | Returns:
56 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
57 | """
58 |
59 | assert cfg.ckpt_path
60 |
61 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
62 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)
63 |
64 | log.info(f"Instantiating model <{cfg.model._target_}>")
65 | model: LightningModule = hydra.utils.instantiate(cfg.model)
66 |
67 | log.info("Instantiating loggers...")
68 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger"))
69 |
70 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
71 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
72 |
73 | object_dict = {
74 | "cfg": cfg,
75 | "datamodule": datamodule,
76 | "model": model,
77 | "logger": logger,
78 | "trainer": trainer,
79 | }
80 |
81 | if logger:
82 | log.info("Logging hyperparameters!")
83 | utils.log_hyperparameters(object_dict)
84 |
85 | log.info("Starting testing!")
86 | trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
87 |
88 | # for predictions use trainer.predict(...)
89 | # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
90 |
91 | metric_dict = trainer.callback_metrics
92 |
93 | return metric_dict, object_dict
94 |
95 |
96 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="eval.yaml")
97 | def main(cfg: DictConfig) -> None:
98 | evaluate(cfg)
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
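# A minimal sketch of what the `setup_root` call at the top of this file provides; it mirrors
# the exact call above and assumes it is run from somewhere inside the repository:
import os
import pyrootutils

project_root = pyrootutils.setup_root(
    search_from=".",                        # any path inside the repository works
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)
print(project_root)                         # resolved project root directory
print(os.environ.get("PROJECT_ROOT"))       # set by setup_root, used in configs/paths/default.yaml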
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/src/models/components/__init__.py
--------------------------------------------------------------------------------
/src/models/components/base.py:
--------------------------------------------------------------------------------
1 | """
2 | KITTI Regressor Model
3 | """
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torchvision import models
8 |
9 | class RegressorNet(nn.Module):
10 | def __init__(
11 | self,
12 | backbone: nn.Module,
13 | bins: int,
14 | ):
15 | super().__init__()
16 | # init model
17 | self.in_features = self._get_in_features(backbone)
18 | self.model = nn.Sequential(*(list(backbone.children())[:-2]))
19 | self.bins = bins
20 |
21 | # # orientation head, for orientation estimation
22 | # self.orientation = nn.Sequential(
23 | # nn.Linear(self.in_features, 256),
24 | # nn.ReLU(True),
25 | # nn.Dropout(),
26 | # nn.Linear(256, 256),
27 | # nn.ReLU(True),
28 | # nn.Dropout(),
29 | # nn.Linear(256, self.bins*2) # 4 bins
30 | # )
31 |
32 | # # confident head, for orientation estimation
33 | # self.confidence = nn.Sequential(
34 | # nn.Linear(self.in_features, 256),
35 | # nn.ReLU(True),
36 | # nn.Dropout(),
37 | # nn.Linear(256, 256),
38 | # nn.ReLU(True),
39 | # nn.Dropout(),
40 | # nn.Linear(256, self.bins),
41 | # nn.Sigmoid()
42 | # )
43 | self.orientation = nn.Sequential(
44 | nn.Linear(self.in_features, 1024),
45 | nn.ReLU(True),
46 | nn.Dropout(),
47 | nn.Linear(1024, 1024),
48 | nn.ReLU(True),
49 | nn.Dropout(),
50 | nn.Linear(1024, self.bins*2) # 4 bins
51 | )
52 |
53 | # confident head, for orientation estimation
54 | self.confidence = nn.Sequential(
55 | nn.Linear(self.in_features, 512),
56 | nn.ReLU(True),
57 | nn.Dropout(),
58 | nn.Linear(512, 512),
59 | nn.ReLU(True),
60 | nn.Dropout(),
61 | nn.Linear(512, self.bins),
62 | nn.Sigmoid()
63 | )
64 | # dimension head
65 | self.dimension = nn.Sequential(
66 | nn.Linear(self.in_features, 512),
67 | nn.ReLU(True),
68 | nn.Dropout(),
69 | nn.Linear(512, 512),
70 | nn.ReLU(True),
71 | nn.Dropout(),
72 | nn.Linear(512, 3) # x, y, z
73 | )
74 |
75 | def forward(self, x):
76 | x = self.model(x)
77 | x = x.view(-1, self.in_features)
78 |
79 | orientation = self.orientation(x)
80 | orientation = orientation.view(-1, self.bins, 2)
81 | # orientation = F.normalize(orientation, dim=2)
82 | # TODO: use this normalization when exporting the model
83 | orientation = orientation.div(orientation.norm(dim=2, keepdim=True))
84 |
85 | confidence = self.confidence(x)
86 |
87 | dimension = self.dimension(x)
88 |
89 | return orientation, confidence, dimension
90 |
91 | def _get_in_features(self, net: nn.Module):
92 |
93 | # TODO: add more models
94 | in_features = {
95 | 'resnet': (lambda: net.fc.in_features * 7 * 7), # 512 * 7 * 7 = 25088
96 | 'vgg': (lambda: net.classifier[0].in_features), # 512 * 7 * 7 = 25088
97 | # 'mobilenetv3_large': (lambda: (net.classifier[0].in_features) * 7 * 7), # 960 * 7 * 7 = 47040
98 | 'mobilenetv3': (lambda: (net.classifier[0].in_features) * 7 * 7), # 576 * 7 * 7 = 28224
99 | }
100 |
101 | return in_features[(net.__class__.__name__).lower()]()
102 |
103 |
104 | class RegressorNet2(nn.Module):
105 | def __init__(
106 | self,
107 | backbone: nn.Module,
108 | bins: int,
109 | ):
110 | super().__init__()
111 |
112 | # init model
113 | self.in_features = self._get_in_features(backbone)
114 | self.model = nn.Sequential(*(list(backbone.children())[:-2]))
115 | self.bins = bins
116 |
117 | # orientation head, for orientation estimation
118 | # TODO: improve 256 to 1024
119 | self.orientation = nn.Sequential(
120 | nn.Linear(self.in_features, 256),
121 | nn.LeakyReLU(0.1),
122 | nn.Dropout(),
123 | nn.Linear(256, self.bins*2), # 4 bins
124 | nn.LeakyReLU(0.1)
125 | )
126 |
127 | # confident head, for orientation estimation
128 | self.confidence = nn.Sequential(
129 | nn.Linear(self.in_features, 256),
130 | nn.LeakyReLU(0.1),
131 | nn.Dropout(),
132 | nn.Linear(256, self.bins),
133 | nn.LeakyReLU(0.1)
134 | )
135 |
136 | # dimension head
137 | self.dimension = nn.Sequential(
138 | nn.Linear(self.in_features, 512),
139 | nn.LeakyReLU(0.1),
140 | nn.Dropout(),
141 | nn.Linear(512, 3), # x, y, z
142 | nn.LeakyReLU(0.1)
143 | )
144 |
145 | def forward(self, x):
146 | x = self.model(x)
147 | x = x.view(-1, self.in_features)
148 |
149 | orientation = self.orientation(x)
150 | orientation = orientation.view(-1, self.bins, 2)
151 | # TODO: use this normalization when exporting the model
152 | orientation = orientation.div(orientation.norm(dim=2, keepdim=True))
153 |
154 | confidence = self.confidence(x)
155 |
156 | dimension = self.dimension(x)
157 |
158 | return orientation, confidence, dimension
159 |
160 | def _get_in_features(self, net: nn.Module):
161 |
162 | # TODO: add more models
163 | in_features = {
164 | 'resnet': (lambda: net.fc.in_features * 7 * 7),
165 | 'vgg': (lambda: net.classifier[0].in_features)
166 | }
167 |
168 | return in_features[(net.__class__.__name__).lower()]()
169 |
170 |
171 | def OrientationLoss(orient_batch, orientGT_batch, confGT_batch):
172 | """
173 | Orientation loss function
174 | """
175 | batch_size = orient_batch.size()[0]
176 | indexes = torch.max(confGT_batch, dim=1)[1]
177 |
178 | # extract important bin
179 | orientGT_batch = orientGT_batch[torch.arange(batch_size), indexes]
180 | orient_batch = orient_batch[torch.arange(batch_size), indexes]
181 |
182 | theta_diff = torch.atan2(orientGT_batch[:,1], orientGT_batch[:,0])
183 | estimated_theta_diff = torch.atan2(orient_batch[:,1], orient_batch[:,0])
184 |
185 | return 2 - 2 * torch.cos(theta_diff - estimated_theta_diff).mean()
186 | # return -torch.cos(theta_diff - estimated_theta_diff).mean()
187 |
188 |
189 | def orientation_loss2(y_pred, y_true):
190 | """
191 | Orientation loss function
192 | input: y_true -- (batch_size, bin, 2) ground truth orientation value in cos and sin form.
193 | y_pred -- (batch_size, bin, 2) estimated orientation value from the ConvNet
194 | output: loss -- loss values for orientation
195 | """
196 |
197 | # sin^2 + cos^2
198 | anchors = torch.sum(y_true ** 2, dim=2)
199 | # check which bins are valid
200 | anchors = torch.gt(anchors, 0.5)
201 | # count the valid bins
202 | anchors = torch.sum(anchors.type(torch.float32), dim=1)
203 |
204 | # cos(true)cos(estimate) + sin(true)sin(estimate)
205 | loss = (y_true[:, : ,0] * y_pred[:, :, 0] + y_true[:, :, 1] * y_pred[:, :, 1])
206 | # mean over the valid bins
207 | loss = torch.sum(loss, dim=1) / anchors
208 | # mean over the batch
209 | loss = torch.mean(loss)
210 | loss = 2 - 2 * loss
211 |
212 | return loss
213 |
214 | def get_model(backbone: str):
215 | """
216 | Get truncated model and in_features
217 | """
218 |
219 | # list of supported model names
220 | # TODO: add more models
221 | list_model = ['resnet18', 'vgg11']
222 | # model_name = str(backbone.__class__.__name__).lower()
223 | assert backbone in list_model, f"Model not supported, please choose one of {list_model}"
224 |
225 | # TODO: replace the if/else with an attribute lookup
226 | in_features = None
227 | model = None
228 | if backbone == 'resnet18':
229 | backbone = models.resnet18(pretrained=True)
230 | in_features = backbone.fc.in_features * 7 * 7
231 | model = nn.Sequential(*(list(backbone.children())[:-2]))
232 | elif backbone == 'vgg11':
233 | backbone = models.vgg11(pretrained=True)
234 | in_features = backbone.classifier[0].in_features
235 | model = backbone.features
236 |
237 | return [model, in_features]
238 |
239 |
240 | if __name__ == '__main__':
241 |
242 | # from torchvision.models import resnet18
243 | # from torchsummary import summary
244 |
245 | # backbone = resnet18(pretrained=False)
246 | # model = RegressorNet(backbone, 2)
247 |
248 | # input_size = (3, 224, 224)
249 | # summary(model, input_size, device='cpu')
250 |
251 | # test orientation loss
252 | y_true = torch.tensor([[[0.0, 0.0], [0.9362, 0.3515]]])
253 | y_pred = torch.tensor([[[0.0, 0.0], [0.9362, 0.3515]]])
254 |
255 | print(y_true, "\n", y_pred)
256 | print(orientation_loss2(y_pred, y_true))
257 |
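# A self-contained shape sanity check for RegressorNet (a sketch, not part of the training code;
# assumes torchvision is available and the repo is on PYTHONPATH):
import torch
from torchvision.models import resnet18
from src.models.components.base import RegressorNet

net = RegressorNet(backbone=resnet18(pretrained=False), bins=2)
x = torch.randn(4, 3, 224, 224)                  # a dummy batch of 224 x 224 crops
orient, conf, dim = net(x)
print(orient.shape, conf.shape, dim.shape)       # (4, 2, 2), (4, 2), (4, 3)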
--------------------------------------------------------------------------------
/src/models/regressor.py:
--------------------------------------------------------------------------------
1 | """
2 | KITTI Regressor Model
3 | """
4 | import torch
5 | from torch import nn
6 | from pytorch_lightning import LightningModule
7 |
8 | from src.models.components.base import OrientationLoss, orientation_loss2
9 |
10 |
11 | class RegressorModel(LightningModule):
12 | def __init__(
13 | self,
14 | net: nn.Module,
15 | optimizer: str = "adam",
16 | lr: float = 0.0001,
17 | momentum: float = 0.9,
18 | w: float = 0.4,
19 | alpha: float = 0.6,
20 | ):
21 | super().__init__()
22 |
23 | # save hyperparameters
24 | self.save_hyperparameters(logger=False)
25 |
26 | # init model
27 | self.net = net
28 |
29 | # loss functions
30 | self.conf_loss_func = nn.CrossEntropyLoss()
31 | self.dim_loss_func = nn.MSELoss()
32 | self.orient_loss_func = OrientationLoss
33 |
34 | # TODO: use this forward when exporting the model
35 | def forward(self, x):
36 | output = self.net(x)
37 | orient = output[0]
38 | conf = output[1]
39 | dim = output[2]
40 | return [orient, conf, dim]
41 |
42 | # def forward(self, x):
43 | # return self.net(x)
44 |
45 | def on_train_start(self):
46 | # by default lightning executes validation step sanity checks before training starts,
47 | # so we need to make sure val_acc_best doesn't store accuracy from these checks
48 | # self.val_acc_best.reset()
49 | pass
50 |
51 | def step(self, batch):
52 | x, y = batch
53 |
54 | # convert to float
55 | x = x.float()
56 | truth_orient = y["Orientation"].float()
57 | truth_conf = y["Confidence"].float() # front or back
58 | truth_dim = y["Dimensions"].float()
59 |
60 | # predict y_hat
61 | preds = self(x)
62 | [orient, conf, dim] = preds
63 |
64 | # compute loss
65 | orient_loss = self.orient_loss_func(orient, truth_orient, truth_conf)
66 | dim_loss = self.dim_loss_func(dim, truth_dim)
67 | # truth_conf = torch.max(truth_conf, dim=1)[1]
68 | conf_loss = self.conf_loss_func(conf, truth_conf)
69 |
70 | loss_theta = conf_loss + 1.5 * self.hparams.w * orient_loss
71 | loss = self.hparams.alpha * dim_loss + loss_theta
72 |
73 | return [loss, loss_theta, orient_loss, dim_loss, conf_loss], preds, y
74 |
75 | def training_step(self, batch, batch_idx):
76 | loss, preds, targets = self.step(batch)
77 |
78 | # logging
79 | self.log_dict(
80 | {
81 | "train/loss": loss[0],
82 | "train/theta_loss": loss[1],
83 | "train/orient_loss": loss[2],
84 | "train/dim_loss": loss[3],
85 | "train/conf_loss": loss[4],
86 | },
87 | on_step=False,
88 | on_epoch=True,
89 | prog_bar=False,
90 | )
91 | return {"loss": loss[0], "preds": preds, "targets": targets}
92 |
93 | def training_epoch_end(self, outputs):
94 | # `outputs` is a list of dicts returned from `training_step()`
95 | pass
96 |
97 | def validation_step(self, batch, batch_idx):
98 | loss, preds, targets = self.step(batch)
99 |
100 | # logging
101 | self.log_dict(
102 | {
103 | "val/loss": loss[0],
104 | "val/theta_loss": loss[1],
105 | "val/orient_loss": loss[2],
106 | "val/dim_loss": loss[3],
107 | "val/conf_loss": loss[4],
108 | },
109 | on_step=False,
110 | on_epoch=True,
111 | prog_bar=False,
112 | )
113 | return {"loss": loss[0], "preds": preds, "targets": targets}
114 |
115 | def validation_epoch_end(self, outputs):
116 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean()
117 |
118 | # log to tensorboard
119 | self.log("val/avg_loss", avg_val_loss)
120 | return {"loss": avg_val_loss}
121 |
122 | def on_epoch_end(self):
123 | # reset metrics at the end of every epoch
124 | pass
125 |
126 | def configure_optimizers(self):
127 | if self.hparams.optimizer.lower() == "adam":
128 | optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr)
129 | elif self.hparams.optimizer.lower() == "sgd":
130 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr,
131 | momentum=self.hparams.momentum
132 | )
133 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
134 | return optimizer
135 |
136 | class RegressorModel2(LightningModule):
137 | def __init__(
138 | self,
139 | net: nn.Module,
140 | lr: float = 0.0001,
141 | momentum: float = 0.9,
142 | w: float = 0.4,
143 | alpha: float = 0.6,
144 | ):
145 | super().__init__()
146 |
147 | # save hyperparameters
148 | self.save_hyperparameters(logger=False)
149 |
150 | # init model
151 | self.net = net
152 |
153 | # loss functions
154 | self.conf_loss_func = nn.CrossEntropyLoss()
155 | self.dim_loss_func = nn.MSELoss()
156 | self.orient_loss_func = orientation_loss2
157 |
158 | def forward(self, x):
159 | return self.net(x)
160 |
161 | def on_train_start(self):
162 | # by default lightning executes validation step sanity checks before training starts,
163 | # so we need to make sure val_acc_best doesn't store accuracy from these checks
164 | # self.val_acc_best.reset()
165 | pass
166 |
167 | def step(self, batch):
168 | x, y = batch
169 |
170 | # convert to float
171 | x = x.float()
172 | gt_orient = y["orientation"].float()
173 | gt_conf = y["confidence"].float()
174 | gt_dims = y["dimensions"].float()
175 |
176 | # predict y_hat
177 | predictions = self(x)
178 | [pred_orient, pred_conf, pred_dims] = predictions
179 |
180 | # compute loss
181 | loss_orient = self.orient_loss_func(pred_orient, gt_orient)
182 | loss_dims = self.dim_loss_func(pred_dims, gt_dims)
183 | gt_conf = torch.max(gt_conf, dim=1)[1]
184 | loss_conf = self.conf_loss_func(pred_conf, gt_conf)
185 | # weighting loss => see paper
186 | loss_theta = loss_conf + (self.hparams.w * loss_orient)
187 | loss = (self.hparams.alpha * loss_dims) + loss_theta
188 |
189 | return [loss, loss_theta, loss_orient, loss_conf, loss_dims], predictions, y
190 |
191 | def training_step(self, batch, batch_idx):
192 | loss, preds, targets = self.step(batch)
193 |
194 | # logging
195 | self.log_dict(
196 | {
197 | "train/loss": loss[0],
198 | "train/theta_loss": loss[1],
199 | "train/orient_loss": loss[2],
200 | "train/conf_loss": loss[3],
201 | "train/dim_loss": loss[4],
202 | },
203 | on_step=False,
204 | on_epoch=True,
205 | prog_bar=False,
206 | )
207 | return {"loss": loss[0], "preds": preds, "targets": targets}
208 |
209 | def training_epoch_end(self, outputs):
210 | # `outputs` is a list of dicts returned from `training_step()`
211 | pass
212 |
213 | def validation_step(self, batch, batch_idx):
214 | loss, preds, targets = self.step(batch)
215 |
216 | # logging
217 | self.log_dict(
218 | {
219 | "val/loss": loss[0],
220 | "val/theta_loss": loss[1],
221 | "val/orient_loss": loss[2],
222 | "val/conf_loss": loss[3],
223 | "val/dim_loss": loss[4],
224 | },
225 | on_step=False,
226 | on_epoch=True,
227 | prog_bar=False,
228 | )
229 | return {"loss": loss[0], "preds": preds, "targets": targets}
230 |
231 | def validation_epoch_end(self, outputs):
232 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean()
233 |
234 | # log to tensorboard
235 | self.log("val/avg_loss", avg_val_loss)
236 | return {"loss": avg_val_loss}
237 |
238 | def on_epoch_end(self):
239 | # reset metrics at the end of every epoch
240 | pass
241 |
242 | def configure_optimizers(self):
243 | # optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr)
244 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr,
245 | momentum=self.hparams.momentum
246 | )
247 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)
248 |
249 | return optimizer
250 |
251 | class RegressorModel3(LightningModule):
252 | def __init__(
253 | self,
254 | net: nn.Module,
255 | optimizer: str = "adam",
256 | lr: float = 0.0001,
257 | momentum: float = 0.9,
258 | w: float = 0.4,
259 | alpha: float = 0.6,
260 | ):
261 | super().__init__()
262 |
263 | # save hyperparameters
264 | self.save_hyperparameters(logger=False)
265 |
266 | # init model
267 | self.net = net
268 |
269 | # loss functions
270 | self.conf_loss_func = nn.CrossEntropyLoss()
271 | self.dim_loss_func = nn.MSELoss()
272 | self.orient_loss_func = OrientationLoss
273 |
274 | def forward(self, x):
275 | return self.net(x)
276 |
277 | def on_train_start(self):
278 | # by default lightning executes validation step sanity checks before training starts,
279 | # so we need to make sure val_acc_best doesn't store accuracy from these checks
280 | # self.val_acc_best.reset()
281 | pass
282 |
283 | def step(self, batch):
284 | x, y = batch
285 |
286 | # convert to float
287 | x = x.float()
288 | gt_orient = y["orientation"].float()
289 | gt_conf = y["confidence"].float()
290 | gt_dims = y["dimensions"].float()
291 |
292 | # predict y_hat
293 | predictions = self(x)
294 | [pred_orient, pred_conf, pred_dims] = predictions
295 |
296 | # compute loss
297 | loss_orient = self.orient_loss_func(pred_orient, gt_orient, gt_conf)
298 | loss_dims = self.dim_loss_func(pred_dims, gt_dims)
299 | gt_conf = torch.max(gt_conf, dim=1)[1]
300 | loss_conf = self.conf_loss_func(pred_conf, gt_conf)
301 | # weighting loss => see paper
302 | loss_theta = loss_conf + (self.hparams.w * loss_orient)
303 | loss = (self.hparams.alpha * loss_dims) + loss_theta
304 |
305 | return [loss, loss_theta, loss_orient, loss_conf, loss_dims], predictions, y
306 |
307 | def training_step(self, batch, batch_idx):
308 | loss, preds, targets = self.step(batch)
309 |
310 | # logging
311 | self.log_dict(
312 | {
313 | "train/loss": loss[0],
314 | "train/theta_loss": loss[1],
315 | "train/orient_loss": loss[2],
316 | "train/conf_loss": loss[3],
317 | "train/dim_loss": loss[4],
318 | },
319 | on_step=False,
320 | on_epoch=True,
321 | prog_bar=False,
322 | )
323 | return {"loss": loss[0], "preds": preds, "targets": targets}
324 |
325 | def training_epoch_end(self, outputs):
326 | # `outputs` is a list of dicts returned from `training_step()`
327 | pass
328 |
329 | def validation_step(self, batch, batch_idx):
330 | loss, preds, targets = self.step(batch)
331 |
332 | # logging
333 | self.log_dict(
334 | {
335 | "val/loss": loss[0],
336 | "val/theta_loss": loss[1],
337 | "val/orient_loss": loss[2],
338 | "val/conf_loss": loss[3],
339 | "val/dim_loss": loss[4],
340 | },
341 | on_step=False,
342 | on_epoch=True,
343 | prog_bar=False,
344 | )
345 | return {"loss": loss[0], "preds": preds, "targets": targets}
346 |
347 | def validation_epoch_end(self, outputs):
348 | avg_val_loss = torch.tensor([x["loss"] for x in outputs]).mean()
349 |
350 | # log to tensorboard
351 | self.log("val/avg_loss", avg_val_loss)
352 | return {"loss": avg_val_loss}
353 |
354 | def on_epoch_end(self):
355 | # reset metrics at the end of every epoch
356 | pass
357 |
358 | def configure_optimizers(self):
359 | if self.hparams.optimizer.lower() == "adam":
360 | optimizer = torch.optim.Adam(params=self.parameters(), lr=self.hparams.lr)
361 | elif self.hparams.optimizer.lower() == "sgd":
362 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr,
363 | momentum=self.hparams.momentum
364 | )
365 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
366 |
367 | return optimizer
368 |
369 | if __name__ == "__main__":
370 |
371 | from src.models.components.base import RegressorNet
372 | from torchvision.models import resnet18
373 |
374 | model1 = RegressorModel(
375 | net=RegressorNet(backbone=resnet18(pretrained=False), bins=2),
376 | )
377 |
378 | print(model1)
379 |
380 | model2 = RegressorModel3(
381 | net=RegressorNet(backbone=resnet18(pretrained=False), bins=2),
382 | )
383 |
384 | print(model2)
385 |
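# A minimal sketch of a single step() call on random data (the label layout mirrors what step()
# reads: "orientation" (B, bins, 2), "confidence" (B, bins) one-hot, "dimensions" (B, 3)); it is
# only a sanity check, not a replacement for the Hydra entrypoints:
import torch
from torchvision.models import resnet18
from src.models.components.base import RegressorNet

model = RegressorModel3(net=RegressorNet(backbone=resnet18(pretrained=False), bins=2))
x = torch.randn(4, 3, 224, 224)
y = {
    "orientation": torch.randn(4, 2, 2),
    "confidence": torch.eye(2)[torch.randint(0, 2, (4,))],
    "dimensions": torch.randn(4, 3),
}
losses, preds, targets = model.step((x, y))
print([float(l) for l in losses])   # [loss, loss_theta, loss_orient, loss_conf, loss_dims]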
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import pyrootutils
2 |
3 | root = pyrootutils.setup_root(
4 | search_from=__file__,
5 | indicator=[".git", "pyproject.toml"],
6 | pythonpath=True,
7 | dotenv=True,
8 | )
9 |
10 | # ------------------------------------------------------------------------------------ #
11 | # `pyrootutils.setup_root(...)` is recommended at the top of each start file
12 | # to make the environment more robust and consistent
13 | #
14 | # the line above searches for ".git" or "pyproject.toml" in present and parent dirs
15 | # to determine the project root dir
16 | #
17 | # adds root dir to the PYTHONPATH (if `pythonpath=True`)
18 | # so this file can be run from any place without installing project as a package
19 | #
20 | # sets PROJECT_ROOT environment variable which is used in "configs/paths/default.yaml"
21 | # this makes all paths relative to the project root
22 | #
23 | # additionally loads environment variables from ".env" file (if `dotenv=True`)
24 | #
25 | # you can get away without using `pyrootutils.setup_root(...)` if you:
26 | # 1. move this file to the project root dir or install project as a package
27 | # 2. modify paths in "configs/paths/default.yaml" to not use PROJECT_ROOT
28 | # 3. always run this file from the project root dir
29 | #
30 | # https://github.com/ashleve/pyrootutils
31 | # ------------------------------------------------------------------------------------ #
32 |
33 | from typing import List, Optional, Tuple
34 |
35 | import hydra
36 | import pytorch_lightning as pl
37 | from omegaconf import DictConfig
38 | from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
39 | from pytorch_lightning.loggers import LightningLoggerBase
40 |
41 | from src import utils
42 |
43 | log = utils.get_pylogger(__name__)
44 |
45 |
46 | @utils.task_wrapper
47 | def train(cfg: DictConfig) -> Tuple[dict, dict]:
48 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
49 | training.
50 |
51 | This method is wrapped in optional @task_wrapper decorator which applies extra utilities
52 | before and after the call.
53 |
54 | Args:
55 | cfg (DictConfig): Configuration composed by Hydra.
56 |
57 | Returns:
58 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
59 | """
60 |
61 | # set seed for random number generators in pytorch, numpy and python.random
62 | if cfg.get("seed"):
63 | pl.seed_everything(cfg.seed, workers=True)
64 |
65 | log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
66 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)
67 |
68 | log.info(f"Instantiating model <{cfg.model._target_}>")
69 | model: LightningModule = hydra.utils.instantiate(cfg.model)
70 |
71 | log.info("Instantiating callbacks...")
72 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
73 |
74 | log.info("Instantiating loggers...")
75 | logger: List[LightningLoggerBase] = utils.instantiate_loggers(cfg.get("logger"))
76 |
77 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
78 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
79 |
80 | object_dict = {
81 | "cfg": cfg,
82 | "datamodule": datamodule,
83 | "model": model,
84 | "callbacks": callbacks,
85 | "logger": logger,
86 | "trainer": trainer,
87 | }
88 |
89 | if logger:
90 | log.info("Logging hyperparameters!")
91 | utils.log_hyperparameters(object_dict)
92 |
93 | # train
94 | if cfg.get("train"):
95 | log.info("Starting training!")
96 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
97 |
98 | train_metrics = trainer.callback_metrics
99 |
100 | if cfg.get("test"):
101 | log.info("Starting testing!")
102 | ckpt_path = trainer.checkpoint_callback.best_model_path
103 | if ckpt_path == "":
104 | log.warning("Best ckpt not found! Using current weights for testing...")
105 | ckpt_path = None
106 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
107 | log.info(f"Best ckpt path: {ckpt_path}")
108 |
109 | test_metrics = trainer.callback_metrics
110 |
111 | # merge train and test metrics
112 | metric_dict = {**train_metrics, **test_metrics}
113 |
114 | return metric_dict, object_dict
115 |
116 |
117 | @hydra.main(version_base="1.2", config_path=root / "configs", config_name="train.yaml")
118 | def main(cfg: DictConfig) -> Optional[float]:
119 |
120 | # train the model
121 | metric_dict, _ = train(cfg)
122 |
123 | # safely retrieve metric value for hydra-based hyperparameter optimization
124 | metric_value = utils.get_metric_value(
125 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
126 | )
127 |
128 | # return optimized metric
129 | return metric_value
130 |
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
/src/utils/Calib.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for handling calibration file
3 | """
4 |
5 | import numpy as np
6 |
7 | def get_P(calib_file):
8 | """
9 | Get matrix P_rect_02 (camera 2 RGB)
10 | and transform to 3 x 4 matrix
11 | """
12 | for line in open(calib_file, 'r'):
13 | if 'P_rect_02' in line:
14 | cam_P = line.strip().split(' ')
15 | cam_P = np.asarray([float(value) for value in cam_P[1:]])
16 | # reshape the 12 values into a 3 x 4 projection matrix
17 | matrix = cam_P.reshape((3, 4))
18 | return matrix
19 |
20 | # TODO: understand this
21 |
22 | def get_calibration_cam_to_image(cab_f):
23 | for line in open(cab_f):
24 | if 'P2:' in line:
25 | cam_to_img = line.strip().split(' ')
26 | cam_to_img = np.asarray([float(number) for number in cam_to_img[1:]])
27 | cam_to_img = np.reshape(cam_to_img, (3, 4))
28 | return cam_to_img
29 |
30 | file_not_found(cab_f)
31 |
32 | def get_R0(cab_f):
33 | for line in open(cab_f):
34 | if 'R0_rect:' in line:
35 | R0 = line.strip().split(' ')
36 | R0 = np.asarray([float(number) for number in R0[1:]])
37 | R0 = np.reshape(R0, (3, 3))
38 |
39 | R0_rect = np.zeros([4,4])
40 | R0_rect[3,3] = 1
41 | R0_rect[:3,:3] = R0
42 |
43 | return R0_rect
44 |
45 | def get_tr_to_velo(cab_f):
46 | for line in open(cab_f):
47 | if 'Tr_velo_to_cam:' in line:
48 | Tr = line.strip().split(' ')
49 | Tr = np.asarray([float(number) for number in Tr[1:]])
50 | Tr = np.reshape(Tr, (3, 4))
51 |
52 | Tr_to_velo = np.zeros([4,4])
53 | Tr_to_velo[3,3] = 1
54 | Tr_to_velo[:3,:4] = Tr
55 |
56 | return Tr_to_velo
57 |
58 | def file_not_found(filename):
59 | print("\nError! Can't read calibration file, does %s exist?"%filename)
60 | exit()
61 |
62 |
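# A small sketch of get_P on a synthetic calibration file (the numbers are made up; only the
# "P_rect_02" line format matters):
import tempfile
from src.utils.Calib import get_P

line = "P_rect_02: 721.5 0.0 609.5 44.8 0.0 721.5 172.8 0.2 0.0 0.0 1.0 0.003\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(line)
print(get_P(f.name))   # 3 x 4 projection matrix for camera 2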
--------------------------------------------------------------------------------
/src/utils/Math.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # using this math: https://en.wikipedia.org/wiki/Rotation_matrix
4 | def rotation_matrix(yaw, pitch=0, roll=0):
5 | tx = roll
6 | ty = yaw
7 | tz = pitch
8 |
9 | Rx = np.array([[1,0,0], [0, np.cos(tx), -np.sin(tx)], [0, np.sin(tx), np.cos(tx)]])
10 | Ry = np.array([[np.cos(ty), 0, np.sin(ty)], [0, 1, 0], [-np.sin(ty), 0, np.cos(ty)]])
11 | Rz = np.array([[np.cos(tz), -np.sin(tz), 0], [np.sin(tz), np.cos(tz), 0], [0,0,1]])
12 |
13 | # only the yaw rotation (Ry) is used for KITTI objects; the full Rz*Ry*Rx product stays commented out below
14 | return Ry.reshape([3,3])
15 | # return np.dot(np.dot(Rz,Ry), Rx)
16 |
17 | # option to rotate and shift (for label info)
18 | def create_corners(dimension, location=None, R=None):
19 | dx = dimension[2] / 2
20 | dy = dimension[0] / 2
21 | dz = dimension[1] / 2
22 |
23 | x_corners = []
24 | y_corners = []
25 | z_corners = []
26 |
27 | for i in [1, -1]:
28 | for j in [1,-1]:
29 | for k in [1,-1]:
30 | x_corners.append(dx*i)
31 | y_corners.append(dy*j)
32 | z_corners.append(dz*k)
33 |
34 | corners = [x_corners, y_corners, z_corners]
35 |
36 | # rotate if R is passed in
37 | if R is not None:
38 | corners = np.dot(R, corners)
39 |
40 | # shift if location is passed in
41 | if location is not None:
42 | for i,loc in enumerate(location):
43 | corners[i,:] = corners[i,:] + loc
44 |
45 | final_corners = []
46 | for i in range(8):
47 | final_corners.append([corners[0][i], corners[1][i], corners[2][i]])
48 |
49 |
50 | return final_corners
51 |
52 | # this is based on the paper. Math!
53 | # proj_matrix is a 3x4 camera projection matrix, box_2d is [(xmin, ymin), (xmax, ymax)]
54 | # Math help: http://ywpkwon.github.io/pdf/bbox3d-study.pdf
55 | def calc_location(dimension, proj_matrix, box_2d, alpha, theta_ray):
56 | #global orientation
57 | orient = alpha + theta_ray
58 | R = rotation_matrix(orient)
59 |
60 | # format 2d corners
61 | try:
62 | xmin = box_2d[0][0]
63 | ymin = box_2d[0][1]
64 | xmax = box_2d[1][0]
65 | ymax = box_2d[1][1]
66 | except:
67 | xmin = box_2d[0]
68 | ymin = box_2d[1]
69 | xmax = box_2d[2]
70 | ymax = box_2d[3]
71 |
72 | # left top right bottom
73 | box_corners = [xmin, ymin, xmax, ymax]
74 |
75 | # get the point constraints
76 | constraints = []
77 |
78 | left_constraints = []
79 | right_constraints = []
80 | top_constraints = []
81 | bottom_constraints = []
82 |
83 | # using a different coord system
84 | dx = dimension[2] / 2
85 | dy = dimension[0] / 2
86 | dz = dimension[1] / 2
87 |
88 | # below is very much based on trial and error
89 |
90 | # based on the relative angle, a different configuration occurs
91 | # negative is back of car, positive is front
92 | left_mult = 1
93 | right_mult = -1
94 |
95 | # about straight on but opposite way
96 | if alpha < np.deg2rad(92) and alpha > np.deg2rad(88):
97 | left_mult = 1
98 | right_mult = 1
99 | # about straight on and same way
100 | elif alpha < np.deg2rad(-88) and alpha > np.deg2rad(-92):
101 | left_mult = -1
102 | right_mult = -1
103 | # this works but doesn't make much sense
104 | elif alpha < np.deg2rad(90) and alpha > -np.deg2rad(90):
105 | left_mult = -1
106 | right_mult = 1
107 |
108 | # if the car is facing the opposite way, switch left and right
109 | switch_mult = -1
110 | if alpha > 0:
111 | switch_mult = 1
112 |
113 | # left and right could either be the front of the car or the back of the car
114 | # be careful to use left and right based on the image, not the actual car's left and right
115 | for i in (-1,1):
116 | left_constraints.append([left_mult * dx, i*dy, -switch_mult * dz])
117 | for i in (-1,1):
118 | right_constraints.append([right_mult * dx, i*dy, switch_mult * dz])
119 |
120 | # top and bottom are easy, just the top and bottom of car
121 | for i in (-1,1):
122 | for j in (-1,1):
123 | top_constraints.append([i*dx, -dy, j*dz])
124 | for i in (-1,1):
125 | for j in (-1,1):
126 | bottom_constraints.append([i*dx, dy, j*dz])
127 |
128 | # now, 64 combinations
129 | for left in left_constraints:
130 | for top in top_constraints:
131 | for right in right_constraints:
132 | for bottom in bottom_constraints:
133 | constraints.append([left, top, right, bottom])
134 |
135 | # filter out the ones with repeats
136 | constraints = filter(lambda x: len(x) == len(set(tuple(i) for i in x)), constraints)
137 |
138 | # create pre M (the term with I and the R*X)
139 | pre_M = np.zeros([4,4])
140 | # 1's down diagonal
141 | for i in range(0,4):
142 | pre_M[i][i] = 1
143 |
144 | best_loc = None
145 | best_error = [1e09]
146 | best_X = None
147 |
148 | # loop through each possible constraint, hold on to the best guess
149 | # constraint will be 64 sets of 4 corners
150 | count = 0
151 | for constraint in constraints:
152 | # each corner
153 | Xa = constraint[0]
154 | Xb = constraint[1]
155 | Xc = constraint[2]
156 | Xd = constraint[3]
157 |
158 | X_array = [Xa, Xb, Xc, Xd]
159 |
160 | # M: all 1's down diagonal, and upper 3x1 is Rotation_matrix * [x, y, z]
161 | Ma = np.copy(pre_M)
162 | Mb = np.copy(pre_M)
163 | Mc = np.copy(pre_M)
164 | Md = np.copy(pre_M)
165 |
166 | M_array = [Ma, Mb, Mc, Md]
167 |
168 | # create A, b
169 | A = np.zeros([4,3], dtype=np.float64)
170 | b = np.zeros([4,1])
171 |
172 | indices = [0,1,0,1]
173 | for row, index in enumerate(indices):
174 | X = X_array[row]
175 | M = M_array[row]
176 |
177 | # create M for corner Xx
178 | RX = np.dot(R, X)
179 | M[:3,3] = RX.reshape(3)
180 |
181 | M = np.dot(proj_matrix, M)
182 |
183 | A[row, :] = M[index,:3] - box_corners[row] * M[2,:3]
184 | b[row] = box_corners[row] * M[2,3] - M[index,3]
185 |
186 | # solve with least squares; the system is overdetermined, so there will be some residual error
187 | loc, error, rank, s = np.linalg.lstsq(A, b, rcond=None)
188 |
189 | # found a better estimation
190 | if error < best_error:
191 | count += 1 # for debugging
192 | best_loc = loc
193 | best_error = error
194 | best_X = X_array
195 |
196 | # return best_loc, [left_constraints, right_constraints] # for debugging
197 | best_loc = [best_loc[0][0], best_loc[1][0], best_loc[2][0]]
198 | return best_loc, best_X
199 |
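# A short sketch of rotation_matrix + create_corners (illustrative values; dimension is
# [h, w, l] as in the KITTI labels, location is [x, y, z] in camera coordinates):
import numpy as np
from src.utils.Math import create_corners, rotation_matrix

R = rotation_matrix(np.deg2rad(30))                      # yaw-only rotation, see above
corners = create_corners([1.52, 1.63, 3.88], location=[2.0, 1.5, 10.0], R=R)
print(len(corners), corners[0])                          # 8 corners, each [x, y, z]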
200 | """
201 | Code for generating new plot with bev and 3dbbox
202 | source: https://github.com/lzccccc/3d-bounding-box-estimation-for-autonomous-driving
203 | """
204 |
205 | def get_new_alpha(alpha):
206 | """
207 | change the range of orientation from [-pi, pi] to [0, 2pi]
208 | :param alpha: original orientation in KITTI
209 | :return: new alpha
210 | """
211 | new_alpha = float(alpha) + np.pi / 2.
212 | if new_alpha < 0:
213 | new_alpha = new_alpha + 2. * np.pi
214 | # make sure angle lies in [0, 2pi]
215 | new_alpha = new_alpha - int(new_alpha / (2. * np.pi)) * (2. * np.pi)
216 |
217 | return new_alpha
218 |
219 | def recover_angle(bin_anchor, bin_confidence, bin_num):
220 | # select anchor from bins
221 | max_anc = np.argmax(bin_confidence)
222 | anchors = bin_anchor[max_anc]
223 | # compute the angle offset
224 | if anchors[1] > 0:
225 | angle_offset = np.arccos(anchors[0])
226 | else:
227 | angle_offset = -np.arccos(anchors[0])
228 |
229 | # add the angle offset to the center ray of each bin to obtain the local orientation
230 | wedge = 2 * np.pi / bin_num
231 | angle = angle_offset + max_anc * wedge
232 |
233 | # wrap the angle into [0, 2pi)
234 | angle_l = angle % (2 * np.pi)
235 |
236 | # shift the angle back to [-pi, pi]
237 | angle = angle_l + wedge / 2 - np.pi
238 | if angle > np.pi:
239 | angle -= 2 * np.pi
240 | angle = round(angle, 2)
241 | return angle
242 |
243 |
244 | def compute_orientaion(P2, obj):
245 | x = (obj.xmax + obj.xmin) / 2
246 | # compute camera orientation
247 | u_distance = x - P2[0, 2]
248 | focal_length = P2[0, 0]
249 | rot_ray = np.arctan(u_distance / focal_length)
250 | # global = alpha + ray
251 | rot_global = obj.alpha + rot_ray
252 |
253 | # local orientation, [0, 2 * pi]
254 | # rot_local = obj.alpha + np.pi / 2
255 | rot_local = get_new_alpha(obj.alpha)
256 |
257 | rot_global = round(rot_global, 2)
258 | return rot_global, rot_local
259 |
260 |
261 | def translation_constraints(P2, obj, rot_local):
262 | bbox = [obj.xmin, obj.ymin, obj.xmax, obj.ymax]
263 | # rotation matrix
264 | R = np.array([[ np.cos(obj.rot_global), 0, np.sin(obj.rot_global)],
265 | [ 0, 1, 0 ],
266 | [-np.sin(obj.rot_global), 0, np.cos(obj.rot_global)]])
267 | A = np.zeros((4, 3))
268 | b = np.zeros((4, 1))
269 | I = np.identity(3)
270 |
271 | # object coordinate T, simply divided into x, y, z
272 | # bug1: h div 2
273 | xmin_candi, xmax_candi, ymin_candi, ymax_candi = obj.box3d_candidate(rot_local, soft_range=8)
274 |
275 | X = np.bmat([xmin_candi, xmax_candi,
276 | ymin_candi, ymax_candi])
277 | # X: [x, y, z] in object coordinate
278 | X = X.reshape(4,3).T
279 |
280 | # construct equation (3, 4)
281 | # object four point in bev
282 | for i in range(4):
283 | # X[:,i] is the same as Ti
284 | # matrice = [R T] * Xo
285 | matrice = np.bmat([[I, np.matmul(R, X[:,i])], [np.zeros((1,3)), np.ones((1,1))]])
286 | # M = K * [R T] * Xo
287 | M = np.matmul(P2, matrice)
288 |
289 | if i % 2 == 0:
290 | A[i, :] = M[0, 0:3] - bbox[i] * M[2, 0:3]
291 | b[i, :] = M[2, 3] * bbox[i] - M[0, 3]
292 |
293 | else:
294 | A[i, :] = M[1, 0:3] - bbox[i] * M[2, 0:3]
295 | b[i, :] = M[2, 3] * bbox[i] - M[1, 3]
296 | # solve x, y, z, using method of least square
297 | Tran = np.matmul(np.linalg.pinv(A), b)
298 |
299 | tx, ty, tz = [float(np.around(tran, 2)) for tran in Tran]
300 | return tx, ty, tz
301 |
302 |
303 | class detectionInfo(object):
304 | def __init__(self, line):
305 | self.name = line[0]
306 |
307 | self.truncation = float(line[1])
308 | self.occlusion = int(line[2])
309 |
310 | # local orientation = alpha + pi/2
311 | self.alpha = float(line[3])
312 |
313 | # in pixel coordinate
314 | self.xmin = float(line[4])
315 | self.ymin = float(line[5])
316 | self.xmax = float(line[6])
317 | self.ymax = float(line[7])
318 |
319 | # height, width, length in object coordinate, meters
320 | self.h = float(line[8])
321 | self.w = float(line[9])
322 | self.l = float(line[10])
323 |
324 | # x, y, z in camera coordinate, meter
325 | self.tx = float(line[11])
326 | self.ty = float(line[12])
327 | self.tz = float(line[13])
328 |
329 | # global orientation [-pi, pi]
330 | self.rot_global = float(line[14])
331 |
332 | def member_to_list(self):
333 | output_line = []
334 | for name, value in vars(self).items():
335 | output_line.append(value)
336 | return output_line
337 |
338 | def box3d_candidate(self, rot_local, soft_range):
339 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0]
340 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0]
341 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0]
342 |
343 | x_corners = [i - self.l / 2 for i in x_corners]
344 | y_corners = [i - self.h / 2 for i in y_corners]
345 | z_corners = [i - self.w / 2 for i in z_corners]
346 |
347 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners]))
348 | point1 = corners_3d[0, :]
349 | point2 = corners_3d[1, :]
350 | point3 = corners_3d[2, :]
351 | point4 = corners_3d[3, :]
352 | point5 = corners_3d[6, :]
353 | point6 = corners_3d[7, :]
354 | point7 = corners_3d[4, :]
355 | point8 = corners_3d[5, :]
356 |
357 | # set up projection relation based on local orientation
358 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0
359 |
360 | if 0 < rot_local < np.pi / 2:
361 | xmin_candi = point8
362 | xmax_candi = point2
363 | ymin_candi = point2
364 | ymax_candi = point5
365 |
366 | if np.pi / 2 <= rot_local <= np.pi:
367 | xmin_candi = point6
368 | xmax_candi = point4
369 | ymin_candi = point4
370 | ymax_candi = point1
371 |
372 | if np.pi < rot_local <= 3 / 2 * np.pi:
373 | xmin_candi = point2
374 | xmax_candi = point8
375 | ymin_candi = point8
376 | ymax_candi = point1
377 |
378 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi:
379 | xmin_candi = point4
380 | xmax_candi = point6
381 | ymin_candi = point6
382 | ymax_candi = point5
383 |
384 | # soft constraint
385 | div = soft_range * np.pi / 180
386 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi:
387 | xmin_candi = point8
388 | xmax_candi = point6
389 | ymin_candi = point6
390 | ymax_candi = point5
391 |
392 | if np.pi - div < rot_local < np.pi + div:
393 | xmin_candi = point2
394 | xmax_candi = point4
395 | ymin_candi = point8
396 | ymax_candi = point1
397 |
398 | return xmin_candi, xmax_candi, ymin_candi, ymax_candi
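# A worked sketch of recover_angle with two bins (values are illustrative; bin_confidence picks
# the second bin, whose (cos, sin) anchor encodes the offset from that bin's center ray):
import numpy as np
from src.utils.Math import recover_angle

bin_anchor = np.array([[1.0, 0.0],
                       [0.9362, 0.3515]])
bin_confidence = np.array([0.2, 0.8])
print(recover_angle(bin_anchor, bin_confidence, bin_num=2))   # local orientation in [-pi, pi]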
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from src.utils.pylogger import get_pylogger
2 | from src.utils.rich_utils import enforce_tags, print_config_tree
3 | from src.utils.utils import (
4 | close_loggers,
5 | extras,
6 | get_metric_value,
7 | instantiate_callbacks,
8 | instantiate_loggers,
9 | log_hyperparameters,
10 | save_file,
11 | task_wrapper,
12 | )
13 |
--------------------------------------------------------------------------------
/src/utils/averages.py:
--------------------------------------------------------------------------------
1 | """Average dimension class"""
2 |
3 | from typing import List
4 | import numpy as np
5 | import os
6 | import json
7 |
8 | class DimensionAverages:
9 | """
10 | Class to calculate the average dimensions of the objects in the dataset.
11 | """
12 | def __init__(
13 | self,
14 | categories: List[str] = ['car', 'pedestrian', 'cyclist'],
15 | save_file: str = 'dimension_averages.txt'
16 | ):
17 | self.dimension_map = {}
18 | self.filename = os.path.abspath(os.path.dirname(__file__)) + '/' + save_file
19 | self.categories = categories
20 |
21 | if len(self.categories) == 0:
22 | self.load_items_from_file()
23 |
24 | for det in self.categories:
25 | cat_ = det.lower()
26 | if cat_ in self.dimension_map.keys():
27 | continue
28 | self.dimension_map[cat_] = {}
29 | self.dimension_map[cat_]['count'] = 0
30 | self.dimension_map[cat_]['total'] = np.zeros(3, dtype=np.float32)
31 |
32 | def add_items(self, items_path):
33 | for path in items_path:
34 | with open(path, "r") as f:
35 | for line in f:
36 | line = line.split(" ")
37 | if line[0].lower() in self.categories:
38 | self.add_item(
39 | line[0],
40 | np.array([float(line[8]), float(line[9]), float(line[10])])
41 | )
42 |
43 | def add_item(self, cat, dim):
44 | cat = cat.lower()
45 | self.dimension_map[cat]['count'] += 1
46 | self.dimension_map[cat]['total'] += dim
47 |
48 | def get_item(self, cat):
49 | cat = cat.lower()
50 | return self.dimension_map[cat]['total'] / self.dimension_map[cat]['count']
51 |
52 | def load_items_from_file(self):
53 | f = open(self.filename, 'r')
54 | dimension_map = json.load(f)
55 |
56 | for cat in dimension_map:
57 | dimension_map[cat]['total'] = np.asarray(dimension_map[cat]['total'])
58 |
59 | self.dimension_map = dimension_map
60 |
61 | def dump_to_file(self):
62 | f = open(self.filename, "w")
63 | f.write(json.dumps(self.dimension_map, cls=NumpyEncoder))
64 | f.close()
65 |
66 | def recognized_class(self, cat):
67 | return cat.lower() in self.dimension_map
68 |
69 | class ClassAverages:
70 | def __init__(self, classes=[]):
71 | self.dimension_map = {}
72 | self.filename = os.path.abspath(os.path.dirname(__file__)) + '/class_averages.txt'
73 |
74 | if len(classes) == 0: # eval mode
75 | self.load_items_from_file()
76 |
77 | for detection_class in classes:
78 | class_ = detection_class.lower()
79 | if class_ in self.dimension_map.keys():
80 | continue
81 | self.dimension_map[class_] = {}
82 | self.dimension_map[class_]['count'] = 0
83 | self.dimension_map[class_]['total'] = np.zeros(3, dtype=np.double)
84 |
85 |
86 | def add_item(self, class_, dimension):
87 | class_ = class_.lower()
88 | self.dimension_map[class_]['count'] += 1
89 | self.dimension_map[class_]['total'] += dimension
90 | # self.dimension_map[class_]['total'] /= self.dimension_map[class_]['count']
91 |
92 | def get_item(self, class_):
93 | class_ = class_.lower()
94 | return self.dimension_map[class_]['total'] / self.dimension_map[class_]['count']
95 |
96 | def dump_to_file(self):
97 | f = open(self.filename, "w")
98 | f.write(json.dumps(self.dimension_map, cls=NumpyEncoder))
99 | f.close()
100 |
101 | def load_items_from_file(self):
102 | f = open(self.filename, 'r')
103 | dimension_map = json.load(f)
104 |
105 | for class_ in dimension_map:
106 | dimension_map[class_]['total'] = np.asarray(dimension_map[class_]['total'])
107 |
108 | self.dimension_map = dimension_map
109 |
110 | def recognized_class(self, class_):
111 | return class_.lower() in self.dimension_map
112 |
113 | class NumpyEncoder(json.JSONEncoder):
114 | def default(self, obj):
115 | if isinstance(obj, np.ndarray):
116 | return obj.tolist()
117 | return json.JSONEncoder.default(self,obj)
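# A tiny sketch of ClassAverages in "training" mode (classes passed in, so nothing is read from
# class_averages.txt); the dimensions are [h, w, l] in meters, matching the KITTI label fields:
import numpy as np
from src.utils.averages import ClassAverages

averages = ClassAverages(classes=["car"])
averages.add_item("Car", np.array([1.52, 1.63, 3.88]))
averages.add_item("Car", np.array([1.46, 1.59, 3.69]))
print(averages.get_item("car"))   # element-wise mean of the accumulated dimensions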
--------------------------------------------------------------------------------
/src/utils/class_averages-L4.txt:
--------------------------------------------------------------------------------
1 | {"car": {"count": 15939, "total": [25041.166112000075, 27306.989660000134, 68537.86737500015]}, "pedestrian": {"count": 1793, "total": [3018.555634000002, 974.4651910000001, 847.5994529999997]}, "cyclist": {"count": 116, "total": [169.2123550000001, 54.55659699999999, 187.05904800000002]}, "truck": {"count": 741, "total": [2219.4890700000014, 1835.661420999999, 6906.846059999999]}, "van": {"count": 632, "total": [1366.3955200000005, 1232.9152299999998, 3155.905800000001]}, "trafficcone": {"count": 0, "total": [0.0, 0.0, 0.0]}, "unknown": {"count": 4294, "total": [4924.141288999996, 3907.031903999998, 6185.184788000007]}}
--------------------------------------------------------------------------------
/src/utils/class_averages-kitti6.txt:
--------------------------------------------------------------------------------
1 | {"car": {"count": 14385, "total": [21898.43999999967, 23568.289999999495, 55754.239999999765]}, "cyclist": {"count": 893, "total": [1561.5099999999982, 552.850000000001, 1569.5100000000007]}, "truck": {"count": 606, "total": [1916.6199999999872, 1554.710000000011, 6567.400000000018]}, "van": {"count": 1617, "total": [3593.439999999989, 3061.370000000014, 8122.769999999951]}, "pedestrian": {"count": 2280, "total": [3998.8900000000003, 1576.6400000000049, 1974.090000000009]}, "tram": {"count": 287, "total": [1012.7000000000005, 771.13, 4739.249999999991]}}
--------------------------------------------------------------------------------
/src/utils/class_averages.txt:
--------------------------------------------------------------------------------
1 | {"car": {"count": 14385, "total": [21898.43999999967, 23568.289999999495, 55754.239999999765]}, "cyclist": {"count": 893, "total": [1561.5099999999982, 552.850000000001, 1569.5100000000007]}, "truck": {"count": 606, "total": [1916.6199999999872, 1554.710000000011, 6567.400000000018]}, "van": {"count": 1617, "total": [3593.439999999989, 3061.370000000014, 8122.769999999951]}, "pedestrian": {"count": 2280, "total": [3998.8900000000003, 1576.6400000000049, 1974.090000000009]}, "tram": {"count": 287, "total": [1012.7000000000005, 771.13, 4739.249999999991]}}
--------------------------------------------------------------------------------
/src/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from pytorch_lightning.utilities import rank_zero_only
4 |
5 |
6 | def get_pylogger(name=__name__) -> logging.Logger:
7 | """Initializes multi-GPU-friendly python command line logger."""
8 |
9 | logger = logging.getLogger(name)
10 |
11 | # this ensures all logging levels get marked with the rank zero decorator
12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
14 | for level in logging_levels:
15 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
16 |
17 | return logger
18 |
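# Typical usage (the same pattern train.py and eval.py follow); every logging level is wrapped
# with rank_zero_only, so in a multi-GPU run the message is emitted once instead of per process:
from src.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
log.info("only the rank-zero process emits this in a multi-GPU run")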
--------------------------------------------------------------------------------
/src/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence
3 |
4 | import rich
5 | import rich.syntax
6 | import rich.tree
7 | from hydra.core.hydra_config import HydraConfig
8 | from omegaconf import DictConfig, OmegaConf, open_dict
9 | from pytorch_lightning.utilities import rank_zero_only
10 | from rich.prompt import Prompt
11 |
12 | from src.utils import pylogger
13 |
14 | log = pylogger.get_pylogger(__name__)
15 |
16 |
17 | @rank_zero_only
18 | def print_config_tree(
19 | cfg: DictConfig,
20 | print_order: Sequence[str] = (
21 | "datamodule",
22 | "model",
23 | "callbacks",
24 | "logger",
25 | "trainer",
26 | "paths",
27 | "extras",
28 | ),
29 | resolve: bool = False,
30 | save_to_file: bool = False,
31 | ) -> None:
32 | """Prints content of DictConfig using Rich library and its tree structure.
33 |
34 | Args:
35 | cfg (DictConfig): Configuration composed by Hydra.
36 | print_order (Sequence[str], optional): Determines in what order config components are printed.
37 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
38 | save_to_file (bool, optional): Whether to export config to the hydra output folder.
39 | """
40 |
41 | style = "dim"
42 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
43 |
44 | queue = []
45 |
46 | # add fields from `print_order` to queue
47 | for field in print_order:
48 | queue.append(field) if field in cfg else log.warning(
49 | f"Field '{field}' not found in config. Skipping '{field}' config printing..."
50 | )
51 |
52 | # add all the other fields to queue (not specified in `print_order`)
53 | for field in cfg:
54 | if field not in queue:
55 | queue.append(field)
56 |
57 | # generate config tree from queue
58 | for field in queue:
59 | branch = tree.add(field, style=style, guide_style=style)
60 |
61 | config_group = cfg[field]
62 | if isinstance(config_group, DictConfig):
63 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
64 | else:
65 | branch_content = str(config_group)
66 |
67 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
68 |
69 | # print config tree
70 | rich.print(tree)
71 |
72 | # save config tree to file
73 | if save_to_file:
74 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
75 | rich.print(tree, file=file)
76 |
77 |
78 | @rank_zero_only
79 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
80 | """Prompts user to input tags from command line if no tags are provided in config."""
81 |
82 | if not cfg.get("tags"):
83 | if "id" in HydraConfig().cfg.hydra.job:
84 | raise ValueError("Specify tags before launching a multirun!")
85 |
86 | log.warning("No tags provided in config. Prompting user to input tags...")
87 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
88 | tags = [t.strip() for t in tags.split(",") if t != ""]
89 |
90 | with open_dict(cfg):
91 | cfg.tags = tags
92 |
93 | log.info(f"Tags: {cfg.tags}")
94 |
95 | if save_to_file:
96 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
97 | rich.print(cfg.tags, file=file)
98 |
99 |
100 | if __name__ == "__main__":
101 | from hydra import compose, initialize
102 |
103 | with initialize(version_base="1.2", config_path="../../configs"):
104 | cfg = compose(config_name="train.yaml", return_hydra_config=False, overrides=[])
105 | print_config_tree(cfg, resolve=False, save_to_file=False)
106 |
--------------------------------------------------------------------------------
/src/utils/rotate_iou.py:
--------------------------------------------------------------------------------
1 | #####################
2 | # Based on https://github.com/hongzhenwang/RRPN-revise
3 | # Licensed under The MIT License
4 | # Author: yanyan, scrin@foxmail.com
5 | #####################
6 | import math
7 |
8 | import numba
9 | import numpy as np
10 | from numba import cuda
11 |
12 | @numba.jit(nopython=True)
13 | def div_up(m, n):
14 | return m // n + (m % n > 0)
15 |
16 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
17 | def trangle_area(a, b, c):
18 | return ((a[0] - c[0]) * (b[1] - c[1]) - (a[1] - c[1]) *
19 | (b[0] - c[0])) / 2.0
20 |
21 |
22 | @cuda.jit('(float32[:], int32)', device=True, inline=True)
23 | def area(int_pts, num_of_inter):
24 | area_val = 0.0
25 | for i in range(num_of_inter - 2):
26 | area_val += abs(
27 | trangle_area(int_pts[:2], int_pts[2 * i + 2:2 * i + 4],
28 | int_pts[2 * i + 4:2 * i + 6]))
29 | return area_val
30 |
31 |
32 | @cuda.jit('(float32[:], int32)', device=True, inline=True)
33 | def sort_vertex_in_convex_polygon(int_pts, num_of_inter):
34 | if num_of_inter > 0:
35 | center = cuda.local.array((2, ), dtype=numba.float32)
36 | center[:] = 0.0
37 | for i in range(num_of_inter):
38 | center[0] += int_pts[2 * i]
39 | center[1] += int_pts[2 * i + 1]
40 | center[0] /= num_of_inter
41 | center[1] /= num_of_inter
42 | v = cuda.local.array((2, ), dtype=numba.float32)
43 | vs = cuda.local.array((16, ), dtype=numba.float32)
44 | for i in range(num_of_inter):
45 | v[0] = int_pts[2 * i] - center[0]
46 | v[1] = int_pts[2 * i + 1] - center[1]
47 | d = math.sqrt(v[0] * v[0] + v[1] * v[1])
48 | v[0] = v[0] / d
49 | v[1] = v[1] / d
50 | if v[1] < 0:
51 | v[0] = -2 - v[0]
52 | vs[i] = v[0]
53 | j = 0
54 | temp = 0
55 | for i in range(1, num_of_inter):
56 | if vs[i - 1] > vs[i]:
57 | temp = vs[i]
58 | tx = int_pts[2 * i]
59 | ty = int_pts[2 * i + 1]
60 | j = i
61 | while j > 0 and vs[j - 1] > temp:
62 | vs[j] = vs[j - 1]
63 | int_pts[j * 2] = int_pts[j * 2 - 2]
64 | int_pts[j * 2 + 1] = int_pts[j * 2 - 1]
65 | j -= 1
66 |
67 | vs[j] = temp
68 | int_pts[j * 2] = tx
69 | int_pts[j * 2 + 1] = ty
70 |
71 |
72 | @cuda.jit(
73 | '(float32[:], float32[:], int32, int32, float32[:])',
74 | device=True,
75 | inline=True)
76 | def line_segment_intersection(pts1, pts2, i, j, temp_pts):
77 | A = cuda.local.array((2, ), dtype=numba.float32)
78 | B = cuda.local.array((2, ), dtype=numba.float32)
79 | C = cuda.local.array((2, ), dtype=numba.float32)
80 | D = cuda.local.array((2, ), dtype=numba.float32)
81 |
82 | A[0] = pts1[2 * i]
83 | A[1] = pts1[2 * i + 1]
84 |
85 | B[0] = pts1[2 * ((i + 1) % 4)]
86 | B[1] = pts1[2 * ((i + 1) % 4) + 1]
87 |
88 | C[0] = pts2[2 * j]
89 | C[1] = pts2[2 * j + 1]
90 |
91 | D[0] = pts2[2 * ((j + 1) % 4)]
92 | D[1] = pts2[2 * ((j + 1) % 4) + 1]
93 | BA0 = B[0] - A[0]
94 | BA1 = B[1] - A[1]
95 | DA0 = D[0] - A[0]
96 | CA0 = C[0] - A[0]
97 | DA1 = D[1] - A[1]
98 | CA1 = C[1] - A[1]
99 | acd = DA1 * CA0 > CA1 * DA0
100 | bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
101 | if acd != bcd:
102 | abc = CA1 * BA0 > BA1 * CA0
103 | abd = DA1 * BA0 > BA1 * DA0
104 | if abc != abd:
105 | DC0 = D[0] - C[0]
106 | DC1 = D[1] - C[1]
107 | ABBA = A[0] * B[1] - B[0] * A[1]
108 | CDDC = C[0] * D[1] - D[0] * C[1]
109 | DH = BA1 * DC0 - BA0 * DC1
110 | Dx = ABBA * DC0 - BA0 * CDDC
111 | Dy = ABBA * DC1 - BA1 * CDDC
112 | temp_pts[0] = Dx / DH
113 | temp_pts[1] = Dy / DH
114 | return True
115 | return False
116 |
117 |
118 | @cuda.jit(
119 | '(float32[:], float32[:], int32, int32, float32[:])',
120 | device=True,
121 | inline=True)
122 | def line_segment_intersection_v1(pts1, pts2, i, j, temp_pts):
123 | a = cuda.local.array((2, ), dtype=numba.float32)
124 | b = cuda.local.array((2, ), dtype=numba.float32)
125 | c = cuda.local.array((2, ), dtype=numba.float32)
126 | d = cuda.local.array((2, ), dtype=numba.float32)
127 |
128 | a[0] = pts1[2 * i]
129 | a[1] = pts1[2 * i + 1]
130 |
131 | b[0] = pts1[2 * ((i + 1) % 4)]
132 | b[1] = pts1[2 * ((i + 1) % 4) + 1]
133 |
134 | c[0] = pts2[2 * j]
135 | c[1] = pts2[2 * j + 1]
136 |
137 | d[0] = pts2[2 * ((j + 1) % 4)]
138 | d[1] = pts2[2 * ((j + 1) % 4) + 1]
139 |
140 | area_abc = trangle_area(a, b, c)
141 | area_abd = trangle_area(a, b, d)
142 |
143 | if area_abc * area_abd >= 0:
144 | return False
145 |
146 | area_cda = trangle_area(c, d, a)
147 | area_cdb = area_cda + area_abc - area_abd
148 |
149 | if area_cda * area_cdb >= 0:
150 | return False
151 | t = area_cda / (area_abd - area_abc)
152 |
153 | dx = t * (b[0] - a[0])
154 | dy = t * (b[1] - a[1])
155 | temp_pts[0] = a[0] + dx
156 | temp_pts[1] = a[1] + dy
157 | return True
158 |
159 |
160 | @cuda.jit('(float32, float32, float32[:])', device=True, inline=True)
161 | def point_in_quadrilateral(pt_x, pt_y, corners):
162 | ab0 = corners[2] - corners[0]
163 | ab1 = corners[3] - corners[1]
164 |
165 | ad0 = corners[6] - corners[0]
166 | ad1 = corners[7] - corners[1]
167 |
168 | ap0 = pt_x - corners[0]
169 | ap1 = pt_y - corners[1]
170 |
171 | abab = ab0 * ab0 + ab1 * ab1
172 | abap = ab0 * ap0 + ab1 * ap1
173 | adad = ad0 * ad0 + ad1 * ad1
174 | adap = ad0 * ap0 + ad1 * ap1
175 |
176 | return abab >= abap and abap >= 0 and adad >= adap and adap >= 0
177 |
178 |
179 | @cuda.jit('(float32[:], float32[:], float32[:])', device=True, inline=True)
180 | def quadrilateral_intersection(pts1, pts2, int_pts):
181 | num_of_inter = 0
182 | for i in range(4):
183 | if point_in_quadrilateral(pts1[2 * i], pts1[2 * i + 1], pts2):
184 | int_pts[num_of_inter * 2] = pts1[2 * i]
185 | int_pts[num_of_inter * 2 + 1] = pts1[2 * i + 1]
186 | num_of_inter += 1
187 | if point_in_quadrilateral(pts2[2 * i], pts2[2 * i + 1], pts1):
188 | int_pts[num_of_inter * 2] = pts2[2 * i]
189 | int_pts[num_of_inter * 2 + 1] = pts2[2 * i + 1]
190 | num_of_inter += 1
191 | temp_pts = cuda.local.array((2, ), dtype=numba.float32)
192 | for i in range(4):
193 | for j in range(4):
194 | has_pts = line_segment_intersection(pts1, pts2, i, j, temp_pts)
195 | if has_pts:
196 | int_pts[num_of_inter * 2] = temp_pts[0]
197 | int_pts[num_of_inter * 2 + 1] = temp_pts[1]
198 | num_of_inter += 1
199 |
200 | return num_of_inter
201 |
202 |
203 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True)
204 | def rbbox_to_corners(corners, rbbox):
205 | # generate clockwise corners and rotate it clockwise
206 | angle = rbbox[4]
207 | a_cos = math.cos(angle)
208 | a_sin = math.sin(angle)
209 | center_x = rbbox[0]
210 | center_y = rbbox[1]
211 | x_d = rbbox[2]
212 | y_d = rbbox[3]
213 | corners_x = cuda.local.array((4, ), dtype=numba.float32)
214 | corners_y = cuda.local.array((4, ), dtype=numba.float32)
215 | corners_x[0] = -x_d / 2
216 | corners_x[1] = -x_d / 2
217 | corners_x[2] = x_d / 2
218 | corners_x[3] = x_d / 2
219 | corners_y[0] = -y_d / 2
220 | corners_y[1] = y_d / 2
221 | corners_y[2] = y_d / 2
222 | corners_y[3] = -y_d / 2
223 | for i in range(4):
224 | corners[2 *
225 | i] = a_cos * corners_x[i] + a_sin * corners_y[i] + center_x
226 | corners[2 * i
227 | + 1] = -a_sin * corners_x[i] + a_cos * corners_y[i] + center_y
228 |
229 |
230 | @cuda.jit('(float32[:], float32[:])', device=True, inline=True)
231 | def inter(rbbox1, rbbox2):
232 | corners1 = cuda.local.array((8, ), dtype=numba.float32)
233 | corners2 = cuda.local.array((8, ), dtype=numba.float32)
234 | intersection_corners = cuda.local.array((16, ), dtype=numba.float32)
235 |
236 | rbbox_to_corners(corners1, rbbox1)
237 | rbbox_to_corners(corners2, rbbox2)
238 |
239 | num_intersection = quadrilateral_intersection(corners1, corners2,
240 | intersection_corners)
241 | sort_vertex_in_convex_polygon(intersection_corners, num_intersection)
242 | # print(intersection_corners.reshape([-1, 2])[:num_intersection])
243 |
244 | return area(intersection_corners, num_intersection)
245 |
246 |
247 | @cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True)
248 | def devRotateIoUEval(rbox1, rbox2, criterion=-1):
249 | area1 = rbox1[2] * rbox1[3]
250 | area2 = rbox2[2] * rbox2[3]
251 | area_inter = inter(rbox1, rbox2)
252 | if criterion == -1:
253 | return area_inter / (area1 + area2 - area_inter)
254 | elif criterion == 0:
255 | return area_inter / area1
256 | elif criterion == 1:
257 | return area_inter / area2
258 | else:
259 | return area_inter
260 |
261 | @cuda.jit('(int64, int64, float32[:], float32[:], float32[:], int32)', fastmath=False)
262 | def rotate_iou_kernel_eval(N, K, dev_boxes, dev_query_boxes, dev_iou, criterion=-1):
263 | threadsPerBlock = 8 * 8
264 | row_start = cuda.blockIdx.x
265 | col_start = cuda.blockIdx.y
266 | tx = cuda.threadIdx.x
267 | row_size = min(N - row_start * threadsPerBlock, threadsPerBlock)
268 | col_size = min(K - col_start * threadsPerBlock, threadsPerBlock)
269 | block_boxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
270 | block_qboxes = cuda.shared.array(shape=(64 * 5, ), dtype=numba.float32)
271 |
272 | dev_query_box_idx = threadsPerBlock * col_start + tx
273 | dev_box_idx = threadsPerBlock * row_start + tx
274 | if (tx < col_size):
275 | block_qboxes[tx * 5 + 0] = dev_query_boxes[dev_query_box_idx * 5 + 0]
276 | block_qboxes[tx * 5 + 1] = dev_query_boxes[dev_query_box_idx * 5 + 1]
277 | block_qboxes[tx * 5 + 2] = dev_query_boxes[dev_query_box_idx * 5 + 2]
278 | block_qboxes[tx * 5 + 3] = dev_query_boxes[dev_query_box_idx * 5 + 3]
279 | block_qboxes[tx * 5 + 4] = dev_query_boxes[dev_query_box_idx * 5 + 4]
280 | if (tx < row_size):
281 | block_boxes[tx * 5 + 0] = dev_boxes[dev_box_idx * 5 + 0]
282 | block_boxes[tx * 5 + 1] = dev_boxes[dev_box_idx * 5 + 1]
283 | block_boxes[tx * 5 + 2] = dev_boxes[dev_box_idx * 5 + 2]
284 | block_boxes[tx * 5 + 3] = dev_boxes[dev_box_idx * 5 + 3]
285 | block_boxes[tx * 5 + 4] = dev_boxes[dev_box_idx * 5 + 4]
286 | cuda.syncthreads()
287 | if tx < row_size:
288 | for i in range(col_size):
289 | offset = row_start * threadsPerBlock * K + col_start * threadsPerBlock + tx * K + i
290 | dev_iou[offset] = devRotateIoUEval(block_qboxes[i * 5:i * 5 + 5],
291 | block_boxes[tx * 5:tx * 5 + 5], criterion)
292 |
293 |
294 | def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0):
295 | """rotated box iou running in gpu. 500x faster than cpu version
296 | (take 5ms in one example with numba.cuda code).
297 | convert from [this project](
298 | https://github.com/hongzhenwang/RRPN-revise/tree/master/lib/rotation).
299 |
300 | Args:
301 | boxes (float tensor: [N, 5]): rbboxes. format: centers, dims,
302 | angles(clockwise when positive)
303 | query_boxes (float tensor: [K, 5]): [description]
304 | device_id (int, optional): Defaults to 0. [description]
305 |
306 | Returns:
307 | [type]: [description]
308 | """
309 | box_dtype = boxes.dtype
310 | boxes = boxes.astype(np.float32)
311 | query_boxes = query_boxes.astype(np.float32)
312 | N = boxes.shape[0]
313 | K = query_boxes.shape[0]
314 | iou = np.zeros((N, K), dtype=np.float32)
315 | if N == 0 or K == 0:
316 | return iou
317 | threadsPerBlock = 8 * 8
318 | cuda.select_device(device_id)
319 | blockspergrid = (div_up(N, threadsPerBlock), div_up(K, threadsPerBlock))
320 |
321 | stream = cuda.stream()
322 | with stream.auto_synchronize():
323 | boxes_dev = cuda.to_device(boxes.reshape([-1]), stream)
324 | query_boxes_dev = cuda.to_device(query_boxes.reshape([-1]), stream)
325 | iou_dev = cuda.to_device(iou.reshape([-1]), stream)
326 | rotate_iou_kernel_eval[blockspergrid, threadsPerBlock, stream](
327 | N, K, boxes_dev, query_boxes_dev, iou_dev, criterion)
328 | iou_dev.copy_to_host(iou.reshape([-1]), stream=stream)
329 |     return iou.astype(box_dtype)  # cast back to the caller's original dtype
330 |
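
A minimal usage sketch for `rotate_iou_gpu_eval` (illustrative only: it assumes a CUDA-capable GPU with numba.cuda available and that the module is importable as `src.utils.rotate_iou`):

    import numpy as np

    from src.utils.rotate_iou import rotate_iou_gpu_eval

    # Each rotated box is [center_x, center_y, width, height, angle] with the angle in radians
    # (clockwise when positive), matching the docstring above.
    boxes = np.array([[10.0, 10.0, 4.0, 2.0, 0.0]], dtype=np.float32)        # N = 1
    query_boxes = np.array([[10.0, 10.0, 4.0, 2.0, 0.3],
                            [30.0, 30.0, 4.0, 2.0, 0.0]], dtype=np.float32)  # K = 2

    # criterion=-1 gives standard IoU; 0 or 1 normalize the intersection by a single box area instead.
    iou = rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1)  # shape (N, K) == (1, 2)
    print(iou)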
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import warnings
3 | from importlib.util import find_spec
4 | from pathlib import Path
5 | from typing import Any, Callable, Dict, List, Optional
6 | import numpy as np
7 |
8 | import hydra
9 | from omegaconf import DictConfig
10 | from pytorch_lightning import Callback
11 | from pytorch_lightning.loggers import LightningLoggerBase
12 | from pytorch_lightning.utilities import rank_zero_only
13 |
14 | from src.utils import pylogger, rich_utils
15 |
16 | log = pylogger.get_pylogger(__name__)
17 |
18 |
19 | def task_wrapper(task_func: Callable) -> Callable:
20 | """Optional decorator that wraps the task function in extra utilities.
21 |
22 | Makes multirun more resistant to failure.
23 |
24 | Utilities:
25 |     - Calling `utils.extras()` before the task is started
26 |     - Calling `utils.close_loggers()` after the task is finished
27 |     - Logging the exception if one occurs
28 |     - Logging the total task execution time
29 | - Logging the output dir
30 | """
31 |
32 | def wrap(cfg: DictConfig):
33 |
34 | # apply extra utilities
35 | extras(cfg)
36 |
37 | # execute the task
38 | try:
39 | start_time = time.time()
40 | metric_dict, object_dict = task_func(cfg=cfg)
41 | except Exception as ex:
42 | log.exception("") # save exception to `.log` file
43 | raise ex
44 | finally:
45 | path = Path(cfg.paths.output_dir, "exec_time.log")
46 | content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
47 | save_file(path, content) # save task execution time (even if exception occurs)
48 | close_loggers() # close loggers (even if exception occurs so multirun won't fail)
49 |
50 | log.info(f"Output dir: {cfg.paths.output_dir}")
51 |
52 | return metric_dict, object_dict
53 |
54 | return wrap
55 |
56 |
57 | def extras(cfg: DictConfig) -> None:
58 | """Applies optional utilities before the task is started.
59 |
60 | Utilities:
61 | - Ignoring python warnings
62 | - Setting tags from command line
63 | - Rich config printing
64 | """
65 |
66 | # return if no `extras` config
67 | if not cfg.get("extras"):
68 | log.warning("Extras config not found! ")
69 | return
70 |
71 | # disable python warnings
72 | if cfg.extras.get("ignore_warnings"):
73 | log.info("Disabling python warnings! ")
74 | warnings.filterwarnings("ignore")
75 |
76 | # prompt user to input tags from command line if none are provided in the config
77 | if cfg.extras.get("enforce_tags"):
78 | log.info("Enforcing tags! ")
79 | rich_utils.enforce_tags(cfg, save_to_file=True)
80 |
81 | # pretty print config tree using Rich library
82 | if cfg.extras.get("print_config"):
83 | log.info("Printing config tree with Rich! ")
84 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
85 |
86 |
87 | @rank_zero_only
88 | def save_file(path: str, content: str) -> None:
89 | """Save file in rank zero mode (only on one process in multi-GPU setup)."""
90 | with open(path, "w+") as file:
91 | file.write(content)
92 |
93 |
94 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
95 | """Instantiates callbacks from config."""
96 | callbacks: List[Callback] = []
97 |
98 | if not callbacks_cfg:
99 | log.warning("Callbacks config is empty.")
100 | return callbacks
101 |
102 | if not isinstance(callbacks_cfg, DictConfig):
103 | raise TypeError("Callbacks config must be a DictConfig!")
104 |
105 | for _, cb_conf in callbacks_cfg.items():
106 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
107 | log.info(f"Instantiating callback <{cb_conf._target_}>")
108 | callbacks.append(hydra.utils.instantiate(cb_conf))
109 |
110 | return callbacks
111 |
112 |
113 | def instantiate_loggers(logger_cfg: DictConfig) -> List[LightningLoggerBase]:
114 | """Instantiates loggers from config."""
115 | logger: List[LightningLoggerBase] = []
116 |
117 | if not logger_cfg:
118 | log.warning("Logger config is empty.")
119 | return logger
120 |
121 | if not isinstance(logger_cfg, DictConfig):
122 | raise TypeError("Logger config must be a DictConfig!")
123 |
124 | for _, lg_conf in logger_cfg.items():
125 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
126 | log.info(f"Instantiating logger <{lg_conf._target_}>")
127 | logger.append(hydra.utils.instantiate(lg_conf))
128 |
129 | return logger
130 |
131 |
132 | @rank_zero_only
133 | def log_hyperparameters(object_dict: dict) -> None:
134 | """Controls which config parts are saved by lightning loggers.
135 |
136 | Additionally saves:
137 | - Number of model parameters
138 | """
139 |
140 | hparams = {}
141 |
142 | cfg = object_dict["cfg"]
143 | model = object_dict["model"]
144 | trainer = object_dict["trainer"]
145 |
146 | if not trainer.logger:
147 | log.warning("Logger not found! Skipping hyperparameter logging...")
148 | return
149 |
150 | hparams["model"] = cfg["model"]
151 |
152 | # save number of model parameters
153 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
154 | hparams["model/params/trainable"] = sum(
155 | p.numel() for p in model.parameters() if p.requires_grad
156 | )
157 | hparams["model/params/non_trainable"] = sum(
158 | p.numel() for p in model.parameters() if not p.requires_grad
159 | )
160 |
161 | hparams["datamodule"] = cfg["datamodule"]
162 | hparams["trainer"] = cfg["trainer"]
163 |
164 | hparams["callbacks"] = cfg.get("callbacks")
165 | hparams["extras"] = cfg.get("extras")
166 |
167 | hparams["task_name"] = cfg.get("task_name")
168 | hparams["tags"] = cfg.get("tags")
169 | hparams["ckpt_path"] = cfg.get("ckpt_path")
170 | hparams["seed"] = cfg.get("seed")
171 |
172 | # send hparams to all loggers
173 | trainer.logger.log_hyperparams(hparams)
174 |
175 |
176 | def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
177 | """Safely retrieves value of the metric logged in LightningModule."""
178 |
179 | if not metric_name:
180 | log.info("Metric name is None! Skipping metric value retrieval...")
181 | return None
182 |
183 | if metric_name not in metric_dict:
184 | raise Exception(
185 | f"Metric value not found! \n"
186 | "Make sure metric name logged in LightningModule is correct!\n"
187 | "Make sure `optimized_metric` name in `hparams_search` config is correct!"
188 | )
189 |
190 | metric_value = metric_dict[metric_name].item()
191 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
192 |
193 | return metric_value
194 |
195 |
196 | def close_loggers() -> None:
197 | """Makes sure all loggers closed properly (prevents logging failure during multirun)."""
198 |
199 | log.info("Closing loggers...")
200 |
201 | if find_spec("wandb"): # if wandb is installed
202 | import wandb
203 |
204 | if wandb.run:
205 | log.info("Closing wandb!")
206 | wandb.finish()
207 |
208 | class detectionInfo(object):
209 | """
210 | utils for YOLO3D
211 |     detectionInfo holds a single object parsed from a KITTI-format label/detection line
212 | """
213 | def __init__(self, line):
214 | self.name = line[0]
215 |
216 | self.truncation = float(line[1])
217 | self.occlusion = int(line[2])
218 |
219 | # local orientation = alpha + pi/2
220 | self.alpha = float(line[3])
221 |
222 | # in pixel coordinate
223 | self.xmin = float(line[4])
224 | self.ymin = float(line[5])
225 | self.xmax = float(line[6])
226 | self.ymax = float(line[7])
227 |
228 |         # height, width, length in object coordinate, meter
229 | self.h = float(line[8])
230 | self.w = float(line[9])
231 | self.l = float(line[10])
232 |
233 | # x, y, z in camera coordinate, meter
234 | self.tx = float(line[11])
235 | self.ty = float(line[12])
236 | self.tz = float(line[13])
237 |
238 | # global orientation [-pi, pi]
239 | self.rot_global = float(line[14])
240 |
241 | def member_to_list(self):
242 | output_line = []
243 | for name, value in vars(self).items():
244 | output_line.append(value)
245 | return output_line
246 |
247 | def box3d_candidate(self, rot_local, soft_range):
248 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0]
249 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0]
250 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0]
251 |
252 | x_corners = [i - self.l / 2 for i in x_corners]
253 | y_corners = [i - self.h / 2 for i in y_corners]
254 | z_corners = [i - self.w / 2 for i in z_corners]
255 |
256 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners]))
257 | point1 = corners_3d[0, :]
258 | point2 = corners_3d[1, :]
259 | point3 = corners_3d[2, :]
260 | point4 = corners_3d[3, :]
261 | point5 = corners_3d[6, :]
262 | point6 = corners_3d[7, :]
263 | point7 = corners_3d[4, :]
264 | point8 = corners_3d[5, :]
265 |
266 | # set up projection relation based on local orientation
267 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0
268 |
269 | if 0 < rot_local < np.pi / 2:
270 | xmin_candi = point8
271 | xmax_candi = point2
272 | ymin_candi = point2
273 | ymax_candi = point5
274 |
275 | if np.pi / 2 <= rot_local <= np.pi:
276 | xmin_candi = point6
277 | xmax_candi = point4
278 | ymin_candi = point4
279 | ymax_candi = point1
280 |
281 | if np.pi < rot_local <= 3 / 2 * np.pi:
282 | xmin_candi = point2
283 | xmax_candi = point8
284 | ymin_candi = point8
285 | ymax_candi = point1
286 |
287 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi:
288 | xmin_candi = point4
289 | xmax_candi = point6
290 | ymin_candi = point6
291 | ymax_candi = point5
292 |
293 | # soft constraint
294 | div = soft_range * np.pi / 180
295 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi:
296 | xmin_candi = point8
297 | xmax_candi = point6
298 | ymin_candi = point6
299 | ymax_candi = point5
300 |
301 | if np.pi - div < rot_local < np.pi + div:
302 | xmin_candi = point2
303 | xmax_candi = point4
304 | ymin_candi = point8
305 | ymax_candi = point1
306 |
307 | return xmin_candi, xmax_candi, ymin_candi, ymax_candi
308 |
309 |
310 | class KITTIObject():
311 | """
312 | utils for YOLO3D
313 |     KITTIObject holds a single KITTI object/detection, including a confidence score
314 | """
315 | def __init__(self, line = np.zeros(16)):
316 | self.name = line[0]
317 |
318 | self.truncation = float(line[1])
319 | self.occlusion = int(line[2])
320 |
321 | # local orientation = alpha + pi/2
322 | self.alpha = float(line[3])
323 |
324 | # in pixel coordinate
325 | self.xmin = float(line[4])
326 | self.ymin = float(line[5])
327 | self.xmax = float(line[6])
328 | self.ymax = float(line[7])
329 |
330 |         # height, width, length in object coordinate, meter
331 | self.h = float(line[8])
332 | self.w = float(line[9])
333 | self.l = float(line[10])
334 |
335 | # x, y, z in camera coordinate, meter
336 | self.tx = float(line[11])
337 | self.ty = float(line[12])
338 | self.tz = float(line[13])
339 |
340 | # global orientation [-pi, pi]
341 | self.rot_global = float(line[14])
342 |
343 | # score
344 | self.score = float(line[15])
345 |
346 | def member_to_list(self):
347 | output_line = []
348 | for name, value in vars(self).items():
349 | output_line.append(value)
350 | return output_line
351 |
352 | def box3d_candidate(self, rot_local, soft_range):
353 | x_corners = [self.l, self.l, self.l, self.l, 0, 0, 0, 0]
354 | y_corners = [self.h, 0, self.h, 0, self.h, 0, self.h, 0]
355 | z_corners = [0, 0, self.w, self.w, self.w, self.w, 0, 0]
356 |
357 | x_corners = [i - self.l / 2 for i in x_corners]
358 | y_corners = [i - self.h / 2 for i in y_corners]
359 | z_corners = [i - self.w / 2 for i in z_corners]
360 |
361 | corners_3d = np.transpose(np.array([x_corners, y_corners, z_corners]))
362 | point1 = corners_3d[0, :]
363 | point2 = corners_3d[1, :]
364 | point3 = corners_3d[2, :]
365 | point4 = corners_3d[3, :]
366 | point5 = corners_3d[6, :]
367 | point6 = corners_3d[7, :]
368 | point7 = corners_3d[4, :]
369 | point8 = corners_3d[5, :]
370 |
371 | # set up projection relation based on local orientation
372 | xmin_candi = xmax_candi = ymin_candi = ymax_candi = 0
373 |
374 | if 0 < rot_local < np.pi / 2:
375 | xmin_candi = point8
376 | xmax_candi = point2
377 | ymin_candi = point2
378 | ymax_candi = point5
379 |
380 | if np.pi / 2 <= rot_local <= np.pi:
381 | xmin_candi = point6
382 | xmax_candi = point4
383 | ymin_candi = point4
384 | ymax_candi = point1
385 |
386 | if np.pi < rot_local <= 3 / 2 * np.pi:
387 | xmin_candi = point2
388 | xmax_candi = point8
389 | ymin_candi = point8
390 | ymax_candi = point1
391 |
392 | if 3 * np.pi / 2 <= rot_local <= 2 * np.pi:
393 | xmin_candi = point4
394 | xmax_candi = point6
395 | ymin_candi = point6
396 | ymax_candi = point5
397 |
398 | # soft constraint
399 | div = soft_range * np.pi / 180
400 | if 0 < rot_local < div or 2*np.pi-div < rot_local < 2*np.pi:
401 | xmin_candi = point8
402 | xmax_candi = point6
403 | ymin_candi = point6
404 | ymax_candi = point5
405 |
406 | if np.pi - div < rot_local < np.pi + div:
407 | xmin_candi = point2
408 | xmax_candi = point4
409 | ymin_candi = point8
410 | ymax_candi = point1
411 |
412 | return xmin_candi, xmax_candi, ymin_candi, ymax_candi
413 |
414 | if __name__ == "__main__":
415 | pass
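
As a small illustrative sketch (not from the repo), `detectionInfo` can parse one whitespace-separated KITTI-format label/detection line; the numeric values below are made up, and the import path assumes `src/utils/utils.py` is importable as a module:

    from src.utils.utils import detectionInfo

    # KITTI label fields: type, truncated, occluded, alpha, 2D bbox (4), dimensions h/w/l (3),
    # location x/y/z in camera coordinates (3), rotation_y -- 15 fields in total.
    label_line = "Car 0.00 0 -1.58 587.0 173.3 614.1 200.1 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59"

    det = detectionInfo(label_line.split())
    print(det.name, det.h, det.w, det.l)   # class name and 3D dimensions in meters
    print(det.tx, det.ty, det.tz)          # location in the camera frame
    print(det.member_to_list())            # all parsed fields as a flat list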
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pyrootutils
2 | import pytest
3 | from hydra import compose, initialize
4 | from hydra.core.global_hydra import GlobalHydra
5 | from omegaconf import DictConfig, open_dict
6 |
7 |
8 | @pytest.fixture(scope="package")
9 | def cfg_train_global() -> DictConfig:
10 | with initialize(version_base="1.2", config_path="../configs"):
11 | cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])
12 |
13 | # set defaults for all tests
14 | with open_dict(cfg):
15 | cfg.paths.root_dir = str(pyrootutils.find_root())
16 | cfg.trainer.max_epochs = 1
17 | cfg.trainer.limit_train_batches = 0.01
18 | cfg.trainer.limit_val_batches = 0.1
19 | cfg.trainer.limit_test_batches = 0.1
20 | cfg.trainer.accelerator = "cpu"
21 | cfg.trainer.devices = 1
22 | cfg.datamodule.num_workers = 0
23 | cfg.datamodule.pin_memory = False
24 | cfg.extras.print_config = False
25 | cfg.extras.enforce_tags = False
26 | cfg.logger = None
27 |
28 | return cfg
29 |
30 |
31 | @pytest.fixture(scope="package")
32 | def cfg_eval_global() -> DictConfig:
33 | with initialize(version_base="1.2", config_path="../configs"):
34 | cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."])
35 |
36 | # set defaults for all tests
37 | with open_dict(cfg):
38 | cfg.paths.root_dir = str(pyrootutils.find_root())
39 | cfg.trainer.max_epochs = 1
40 | cfg.trainer.limit_test_batches = 0.1
41 | cfg.trainer.accelerator = "cpu"
42 | cfg.trainer.devices = 1
43 | cfg.datamodule.num_workers = 0
44 | cfg.datamodule.pin_memory = False
45 | cfg.extras.print_config = False
46 | cfg.extras.enforce_tags = False
47 | cfg.logger = None
48 |
49 | return cfg
50 |
51 |
52 | # this is called by each test which uses `cfg_train` arg
53 | # each test generates its own temporary logging path
54 | @pytest.fixture(scope="function")
55 | def cfg_train(cfg_train_global, tmp_path) -> DictConfig:
56 | cfg = cfg_train_global.copy()
57 |
58 | with open_dict(cfg):
59 | cfg.paths.output_dir = str(tmp_path)
60 | cfg.paths.log_dir = str(tmp_path)
61 |
62 | yield cfg
63 |
64 | GlobalHydra.instance().clear()
65 |
66 |
67 | # this is called by each test which uses `cfg_eval` arg
68 | # each test generates its own temporary logging path
69 | @pytest.fixture(scope="function")
70 | def cfg_eval(cfg_eval_global, tmp_path) -> DictConfig:
71 | cfg = cfg_eval_global.copy()
72 |
73 | with open_dict(cfg):
74 | cfg.paths.output_dir = str(tmp_path)
75 | cfg.paths.log_dir = str(tmp_path)
76 |
77 | yield cfg
78 |
79 | GlobalHydra.instance().clear()
80 |
--------------------------------------------------------------------------------
/tests/helpers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tests/helpers/__init__.py
--------------------------------------------------------------------------------
/tests/helpers/package_available.py:
--------------------------------------------------------------------------------
1 | import platform
2 |
3 | import pkg_resources
4 | from pytorch_lightning.utilities.xla_device import XLADeviceUtils
5 |
6 |
7 | def _package_available(package_name: str) -> bool:
8 | """Check if a package is available in your environment."""
9 | try:
10 | return pkg_resources.require(package_name) is not None
11 | except pkg_resources.DistributionNotFound:
12 | return False
13 |
14 |
15 | _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
16 |
17 | _IS_WINDOWS = platform.system() == "Windows"
18 |
19 | _SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh")
20 |
21 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed")
22 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale")
23 |
24 | _WANDB_AVAILABLE = _package_available("wandb")
25 | _NEPTUNE_AVAILABLE = _package_available("neptune")
26 | _COMET_AVAILABLE = _package_available("comet_ml")
27 | _MLFLOW_AVAILABLE = _package_available("mlflow")
28 |
--------------------------------------------------------------------------------
/tests/helpers/run_if.py:
--------------------------------------------------------------------------------
1 | """Adapted from:
2 |
3 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
4 | """
5 |
6 | import sys
7 | from typing import Optional
8 |
9 | import pytest
10 | import torch
11 | from packaging.version import Version
12 | from pkg_resources import get_distribution
13 |
14 | from tests.helpers.package_available import (
15 | _COMET_AVAILABLE,
16 | _DEEPSPEED_AVAILABLE,
17 | _FAIRSCALE_AVAILABLE,
18 | _IS_WINDOWS,
19 | _MLFLOW_AVAILABLE,
20 | _NEPTUNE_AVAILABLE,
21 | _SH_AVAILABLE,
22 | _TPU_AVAILABLE,
23 | _WANDB_AVAILABLE,
24 | )
25 |
26 |
27 | class RunIf:
28 | """RunIf wrapper for conditional skipping of tests.
29 |
30 | Fully compatible with `@pytest.mark`.
31 |
32 | Example:
33 |
34 | @RunIf(min_torch="1.8")
35 | @pytest.mark.parametrize("arg1", [1.0, 2.0])
36 | def test_wrapper(arg1):
37 | assert arg1 > 0
38 | """
39 |
40 | def __new__(
41 | self,
42 | min_gpus: int = 0,
43 | min_torch: Optional[str] = None,
44 | max_torch: Optional[str] = None,
45 | min_python: Optional[str] = None,
46 | skip_windows: bool = False,
47 | sh: bool = False,
48 | tpu: bool = False,
49 | fairscale: bool = False,
50 | deepspeed: bool = False,
51 | wandb: bool = False,
52 | neptune: bool = False,
53 | comet: bool = False,
54 | mlflow: bool = False,
55 | **kwargs,
56 | ):
57 | """
58 | Args:
59 | min_gpus: min number of GPUs required to run test
60 | min_torch: minimum pytorch version to run test
61 | max_torch: maximum pytorch version to run test
62 | min_python: minimum python version required to run test
63 | skip_windows: skip test for Windows platform
64 | tpu: if TPU is available
65 | sh: if `sh` module is required to run the test
66 | fairscale: if `fairscale` module is required to run the test
67 | deepspeed: if `deepspeed` module is required to run the test
68 | wandb: if `wandb` module is required to run the test
69 | neptune: if `neptune` module is required to run the test
70 | comet: if `comet` module is required to run the test
71 | mlflow: if `mlflow` module is required to run the test
72 | kwargs: native pytest.mark.skipif keyword arguments
73 | """
74 | conditions = []
75 | reasons = []
76 |
77 | if min_gpus:
78 | conditions.append(torch.cuda.device_count() < min_gpus)
79 | reasons.append(f"GPUs>={min_gpus}")
80 |
81 | if min_torch:
82 | torch_version = get_distribution("torch").version
83 | conditions.append(Version(torch_version) < Version(min_torch))
84 | reasons.append(f"torch>={min_torch}")
85 |
86 | if max_torch:
87 | torch_version = get_distribution("torch").version
88 | conditions.append(Version(torch_version) >= Version(max_torch))
89 | reasons.append(f"torch<{max_torch}")
90 |
91 | if min_python:
92 | py_version = (
93 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
94 | )
95 | conditions.append(Version(py_version) < Version(min_python))
96 | reasons.append(f"python>={min_python}")
97 |
98 | if skip_windows:
99 | conditions.append(_IS_WINDOWS)
100 | reasons.append("does not run on Windows")
101 |
102 | if tpu:
103 | conditions.append(not _TPU_AVAILABLE)
104 | reasons.append("TPU")
105 |
106 | if sh:
107 | conditions.append(not _SH_AVAILABLE)
108 | reasons.append("sh")
109 |
110 | if fairscale:
111 | conditions.append(not _FAIRSCALE_AVAILABLE)
112 | reasons.append("fairscale")
113 |
114 | if deepspeed:
115 | conditions.append(not _DEEPSPEED_AVAILABLE)
116 | reasons.append("deepspeed")
117 |
118 | if wandb:
119 | conditions.append(not _WANDB_AVAILABLE)
120 | reasons.append("wandb")
121 |
122 | if neptune:
123 | conditions.append(not _NEPTUNE_AVAILABLE)
124 | reasons.append("neptune")
125 |
126 | if comet:
127 | conditions.append(not _COMET_AVAILABLE)
128 | reasons.append("comet")
129 |
130 | if mlflow:
131 | conditions.append(not _MLFLOW_AVAILABLE)
132 | reasons.append("mlflow")
133 |
134 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
135 | return pytest.mark.skipif(
136 | condition=any(conditions),
137 | reason=f"Requires: [{' + '.join(reasons)}]",
138 | **kwargs,
139 | )
140 |
--------------------------------------------------------------------------------
/tests/helpers/run_sh_command.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pytest
4 |
5 | from tests.helpers.package_available import _SH_AVAILABLE
6 |
7 | if _SH_AVAILABLE:
8 | import sh
9 |
10 |
11 | def run_sh_command(command: List[str]):
12 | """Default method for executing shell commands with pytest and sh package."""
13 | msg = None
14 | try:
15 | sh.python(command)
16 | except sh.ErrorReturnCode as e:
17 | msg = e.stderr.decode()
18 | if msg:
19 | pytest.fail(msg=msg)
20 |
--------------------------------------------------------------------------------
/tests/test_configs.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from hydra.core.hydra_config import HydraConfig
3 | from omegaconf import DictConfig
4 |
5 |
6 | def test_train_config(cfg_train: DictConfig):
7 | assert cfg_train
8 | assert cfg_train.datamodule
9 | assert cfg_train.model
10 | assert cfg_train.trainer
11 |
12 | HydraConfig().set_config(cfg_train)
13 |
14 | hydra.utils.instantiate(cfg_train.datamodule)
15 | hydra.utils.instantiate(cfg_train.model)
16 | hydra.utils.instantiate(cfg_train.trainer)
17 |
18 |
19 | def test_eval_config(cfg_eval: DictConfig):
20 | assert cfg_eval
21 | assert cfg_eval.datamodule
22 | assert cfg_eval.model
23 | assert cfg_eval.trainer
24 |
25 | HydraConfig().set_config(cfg_eval)
26 |
27 | hydra.utils.instantiate(cfg_eval.datamodule)
28 | hydra.utils.instantiate(cfg_eval.model)
29 | hydra.utils.instantiate(cfg_eval.trainer)
30 |
--------------------------------------------------------------------------------
/tests/test_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from hydra.core.hydra_config import HydraConfig
5 | from omegaconf import open_dict
6 |
7 | from src.eval import evaluate
8 | from src.train import train
9 |
10 |
11 | @pytest.mark.slow
12 | def test_train_eval(tmp_path, cfg_train, cfg_eval):
13 | """Train for 1 epoch with `train.py` and evaluate with `eval.py`"""
14 | assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir
15 |
16 | with open_dict(cfg_train):
17 | cfg_train.trainer.max_epochs = 1
18 | cfg_train.test = True
19 |
20 | HydraConfig().set_config(cfg_train)
21 | train_metric_dict, _ = train(cfg_train)
22 |
23 | assert "last.ckpt" in os.listdir(tmp_path / "checkpoints")
24 |
25 | with open_dict(cfg_eval):
26 | cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
27 |
28 | HydraConfig().set_config(cfg_eval)
29 | test_metric_dict, _ = evaluate(cfg_eval)
30 |
31 | assert test_metric_dict["test/acc"] > 0.0
32 | assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001
33 |
--------------------------------------------------------------------------------
/tests/test_mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 | import torch
5 |
6 | from src.datamodules.mnist_datamodule import MNISTDataModule
7 |
8 |
9 | @pytest.mark.parametrize("batch_size", [32, 128])
10 | def test_mnist_datamodule(batch_size):
11 | data_dir = "data/"
12 |
13 | dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size)
14 | dm.prepare_data()
15 |
16 | assert not dm.data_train and not dm.data_val and not dm.data_test
17 | assert Path(data_dir, "MNIST").exists()
18 | assert Path(data_dir, "MNIST", "raw").exists()
19 |
20 | dm.setup()
21 | assert dm.data_train and dm.data_val and dm.data_test
22 | assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader()
23 |
24 | num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test)
25 | assert num_datapoints == 70_000
26 |
27 | batch = next(iter(dm.train_dataloader()))
28 | x, y = batch
29 | assert len(x) == batch_size
30 | assert len(y) == batch_size
31 | assert x.dtype == torch.float32
32 | assert y.dtype == torch.int64
33 |
--------------------------------------------------------------------------------
/tests/test_sweeps.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from tests.helpers.run_if import RunIf
4 | from tests.helpers.run_sh_command import run_sh_command
5 |
6 | startfile = "src/train.py"
7 | overrides = ["logger=[]"]
8 |
9 |
10 | @RunIf(sh=True)
11 | @pytest.mark.slow
12 | def test_experiments(tmp_path):
13 | """Test running all available experiment configs with fast_dev_run=True."""
14 | command = [
15 | startfile,
16 | "-m",
17 | "experiment=glob(*)",
18 | "hydra.sweep.dir=" + str(tmp_path),
19 | "++trainer.fast_dev_run=true",
20 | ] + overrides
21 | run_sh_command(command)
22 |
23 |
24 | @RunIf(sh=True)
25 | @pytest.mark.slow
26 | def test_hydra_sweep(tmp_path):
27 | """Test default hydra sweep."""
28 | command = [
29 | startfile,
30 | "-m",
31 | "hydra.sweep.dir=" + str(tmp_path),
32 | "model.optimizer.lr=0.005,0.01",
33 | "++trainer.fast_dev_run=true",
34 | ] + overrides
35 |
36 | run_sh_command(command)
37 |
38 |
39 | @RunIf(sh=True)
40 | @pytest.mark.slow
41 | def test_hydra_sweep_ddp_sim(tmp_path):
42 | """Test default hydra sweep with ddp sim."""
43 | command = [
44 | startfile,
45 | "-m",
46 | "hydra.sweep.dir=" + str(tmp_path),
47 | "trainer=ddp_sim",
48 | "trainer.max_epochs=3",
49 | "+trainer.limit_train_batches=0.01",
50 | "+trainer.limit_val_batches=0.1",
51 | "+trainer.limit_test_batches=0.1",
52 | "model.optimizer.lr=0.005,0.01,0.02",
53 | ] + overrides
54 | run_sh_command(command)
55 |
56 |
57 | @RunIf(sh=True)
58 | @pytest.mark.slow
59 | def test_optuna_sweep(tmp_path):
60 | """Test optuna sweep."""
61 | command = [
62 | startfile,
63 | "-m",
64 | "hparams_search=mnist_optuna",
65 | "hydra.sweep.dir=" + str(tmp_path),
66 | "hydra.sweeper.n_trials=10",
67 | "hydra.sweeper.sampler.n_startup_trials=5",
68 | "++trainer.fast_dev_run=true",
69 | ] + overrides
70 | run_sh_command(command)
71 |
72 |
73 | @RunIf(wandb=True, sh=True)
74 | @pytest.mark.slow
75 | def test_optuna_sweep_ddp_sim_wandb(tmp_path):
76 | """Test optuna sweep with wandb and ddp sim."""
77 | command = [
78 | startfile,
79 | "-m",
80 | "hparams_search=mnist_optuna",
81 | "hydra.sweep.dir=" + str(tmp_path),
82 | "hydra.sweeper.n_trials=5",
83 | "trainer=ddp_sim",
84 | "trainer.max_epochs=3",
85 | "+trainer.limit_train_batches=0.01",
86 | "+trainer.limit_val_batches=0.1",
87 | "+trainer.limit_test_batches=0.1",
88 | "logger=wandb",
89 | ]
90 | run_sh_command(command)
91 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | from hydra.core.hydra_config import HydraConfig
5 | from omegaconf import open_dict
6 |
7 | from src.train import train
8 | from tests.helpers.run_if import RunIf
9 |
10 |
11 | def test_train_fast_dev_run(cfg_train):
12 | """Run for 1 train, val and test step."""
13 | HydraConfig().set_config(cfg_train)
14 | with open_dict(cfg_train):
15 | cfg_train.trainer.fast_dev_run = True
16 | cfg_train.trainer.accelerator = "cpu"
17 | train(cfg_train)
18 |
19 |
20 | @RunIf(min_gpus=1)
21 | def test_train_fast_dev_run_gpu(cfg_train):
22 | """Run for 1 train, val and test step on GPU."""
23 | HydraConfig().set_config(cfg_train)
24 | with open_dict(cfg_train):
25 | cfg_train.trainer.fast_dev_run = True
26 | cfg_train.trainer.accelerator = "gpu"
27 | train(cfg_train)
28 |
29 |
30 | @RunIf(min_gpus=1)
31 | @pytest.mark.slow
32 | def test_train_epoch_gpu_amp(cfg_train):
33 | """Train 1 epoch on GPU with mixed-precision."""
34 | HydraConfig().set_config(cfg_train)
35 | with open_dict(cfg_train):
36 | cfg_train.trainer.max_epochs = 1
37 |         cfg_train.trainer.accelerator = "gpu"
38 | cfg_train.trainer.precision = 16
39 | train(cfg_train)
40 |
41 |
42 | @pytest.mark.slow
43 | def test_train_epoch_double_val_loop(cfg_train):
44 | """Train 1 epoch with validation loop twice per epoch."""
45 | HydraConfig().set_config(cfg_train)
46 | with open_dict(cfg_train):
47 | cfg_train.trainer.max_epochs = 1
48 | cfg_train.trainer.val_check_interval = 0.5
49 | train(cfg_train)
50 |
51 |
52 | @pytest.mark.slow
53 | def test_train_ddp_sim(cfg_train):
54 | """Simulate DDP (Distributed Data Parallel) on 2 CPU processes."""
55 | HydraConfig().set_config(cfg_train)
56 | with open_dict(cfg_train):
57 | cfg_train.trainer.max_epochs = 2
58 | cfg_train.trainer.accelerator = "cpu"
59 | cfg_train.trainer.devices = 2
60 | cfg_train.trainer.strategy = "ddp_spawn"
61 | train(cfg_train)
62 |
63 |
64 | @pytest.mark.slow
65 | def test_train_resume(tmp_path, cfg_train):
66 | """Run 1 epoch, finish, and resume for another epoch."""
67 | with open_dict(cfg_train):
68 | cfg_train.trainer.max_epochs = 1
69 |
70 | HydraConfig().set_config(cfg_train)
71 | metric_dict_1, _ = train(cfg_train)
72 |
73 | files = os.listdir(tmp_path / "checkpoints")
74 | assert "last.ckpt" in files
75 | assert "epoch_000.ckpt" in files
76 |
77 | with open_dict(cfg_train):
78 | cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt")
79 | cfg_train.trainer.max_epochs = 2
80 |
81 | metric_dict_2, _ = train(cfg_train)
82 |
83 | files = os.listdir(tmp_path / "checkpoints")
84 | assert "epoch_001.ckpt" in files
85 | assert "epoch_002.ckpt" not in files
86 |
87 | assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"]
88 | assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]
89 |
--------------------------------------------------------------------------------
/tmp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ApolloAuto/apollo-model-yolo3d/504017b5aa6f68318716a099b6923db062d82ba7/tmp/.gitkeep
--------------------------------------------------------------------------------
/weights/get_regressor_weights.py:
--------------------------------------------------------------------------------
1 | """Get checkpoint from W&B"""
2 |
3 | import wandb
4 |
5 | run = wandb.init()
6 | artifact = run.use_artifact('3ddetection/yolo3d-regressor/experiment-ckpts:v11', type='checkpoints')
7 | artifact_dir = artifact.download()
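
A hedged follow-up sketch for using the downloaded artifact; the `Regressor` class name and the `last.ckpt` filename inside the artifact directory are assumptions based on this repo's layout, not guaranteed:

    from pathlib import Path

    from src.models.regressor import Regressor  # assumed module/class name

    ckpt_path = Path(artifact_dir) / "last.ckpt"            # assumed checkpoint filename
    model = Regressor.load_from_checkpoint(str(ckpt_path))  # standard LightningModule API
    model.eval()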
--------------------------------------------------------------------------------