├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENSE
├── README.md
├── figs
│   └── pipeline.png
├── projects
│   ├── __init__.py
│   ├── bevformer
│   │   ├── __init__.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── custom_base_transformer_layer.py
│   │       ├── decoder.py
│   │       ├── encoder.py
│   │       ├── multi_scale_deformable_attn_function.py
│   │       ├── spatial_cross_attention.py
│   │       └── temporal_self_attention.py
│   ├── configs
│   │   ├── toponet_r50_8x1_24e_olv2_subset_A.py
│   │   └── toponet_r50_8x1_24e_olv2_subset_B.py
│   └── toponet
│       ├── __init__.py
│       ├── core
│       │   ├── __init__.py
│       │   ├── lane
│       │   │   ├── __init__.py
│       │   │   ├── assigners
│       │   │   │   ├── __init__.py
│       │   │   │   └── lane_hungarian_assigner.py
│       │   │   ├── coders
│       │   │   │   ├── __init__.py
│       │   │   │   └── lane_coder.py
│       │   │   ├── match_costs
│       │   │   │   ├── __init__.py
│       │   │   │   └── match_cost.py
│       │   │   └── util.py
│       │   └── visualizer
│       │       ├── __init__.py
│       │       └── lane.py
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── openlanev2_subset_A_dataset.py
│       │   ├── openlanev2_subset_B_dataset.py
│       │   └── pipelines
│       │       ├── __init__.py
│       │       ├── formating.py
│       │       ├── loading.py
│       │       ├── transform_3d.py
│       │       └── transform_3d_lane.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── dense_heads
│       │   │   ├── __init__.py
│       │   │   ├── deformable_detr_head.py
│       │   │   ├── relationship_head.py
│       │   │   └── toponet_head.py
│       │   ├── detectors
│       │   │   ├── __init__.py
│       │   │   └── toponet.py
│       │   └── modules
│       │       ├── __init__.py
│       │       ├── bevformer_constructer.py
│       │       ├── sgnn_decoder.py
│       │       └── transformer_decoder_only.py
│       └── utils
│           ├── __init__.py
│           └── builder.py
├── requirements.txt
└── tools
    ├── dist_test.sh
    ├── dist_train.sh
    ├── test.py
    └── train.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 |
132 | # Custom
133 | ckpts
134 | data
135 | work_dirs
136 | .vscode
137 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # TopoNet: A New Baseline for Scene Topology Reasoning
4 |
5 | ## Graph-based Topology Reasoning for Driving Scenes
6 |
7 | [arXiv paper](https://arxiv.org/abs/2304.05277)
8 | [OpenLane-V2 benchmark](https://github.com/OpenDriveLab/OpenLane-V2)
9 | [License: Apache 2.0](./LICENSE)
10 |
11 | ![TopoNet pipeline](figs/pipeline.png)
12 |
13 |
14 |
15 |
16 | > - Production from [OpenDriveLab](https://opendrivelab.com) at Shanghai AI Lab. Jointly with collaborators at Huawei.
17 | > - Primary contact: [Tianyu Li](https://scholar.google.com/citations?user=X6vTmEMAAAAJ) ( litianyu@opendrivelab.com )
18 |
19 | ---
20 |
21 | This repository contains the source code of **TopoNet**, [Graph-based Topology Reasoning for Driving Scenes](https://arxiv.org/abs/2304.05277).
22 |
23 | TopoNet is the first end-to-end framework capable of abstracting traffic knowledge beyond conventional perception tasks, _i.e._, **reasoning connections between centerlines and traffic elements** from sensor inputs. It unifies heterogeneous feature
24 | learning and enhances feature interactions via the graph neural network architecture and the knowledge graph design.
25 |
26 | Instead of merely recognizing lanes, we maintain that modeling the lane topology is the `appropriate` way to construct road components within the perception framework, and that doing so facilitates the ultimate driving comfort.
27 | This is in accordance with the [UniAD philosophy](https://github.com/OpenDriveLab/UniAD).
28 |
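In essence, the topology reasoning task asks, for every (centerline, traffic element) pair and likewise for every pair of centerlines, whether a relationship exists. As a minimal sketch of this idea (assuming PyTorch; the class name, MLP layout, and shapes below are illustrative and do not reproduce the actual TopoNet relationship head or SGNN decoder), pairwise scores over the two query sets can be produced like this:

```python
# Illustrative sketch: score connections between centerline queries and
# traffic-element queries as a dense adjacency matrix. Names and shapes are
# hypothetical, not the actual TopoNet modules.
import torch
import torch.nn as nn


class PairwiseTopologyHead(nn.Module):
    def __init__(self, embed_dims=256):
        super().__init__()
        self.src_mlp = nn.Linear(embed_dims, embed_dims)  # project centerline queries
        self.dst_mlp = nn.Linear(embed_dims, embed_dims)  # project traffic-element queries
        self.cls = nn.Sequential(                         # classify each pair
            nn.Linear(2 * embed_dims, embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, 1))

    def forward(self, lane_queries, te_queries):
        # lane_queries: (num_lanes, C), te_queries: (num_tes, C)
        src = self.src_mlp(lane_queries).unsqueeze(1)      # (num_lanes, 1, C)
        dst = self.dst_mlp(te_queries).unsqueeze(0)        # (1, num_tes, C)
        pair = torch.cat([src.expand(-1, dst.size(1), -1),
                          dst.expand(src.size(0), -1, -1)], dim=-1)
        # (num_lanes, num_tes) connection probabilities
        return self.cls(pair).squeeze(-1).sigmoid()


scores = PairwiseTopologyHead()(torch.randn(8, 256), torch.randn(5, 256))
print(scores.shape)  # torch.Size([8, 5])
```
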
29 | ## Table of Contents
30 | - [News](#news)
31 | - [Main Results](#main-results)
32 | - [Prerequisites](#prerequisites)
33 | - [Installation](#installation)
34 | - [Prepare Dataset](#prepare-dataset)
35 | - [Train and Evaluate](#train-and-evaluate)
36 | - [License](#license)
37 | - [Citation](#citation)
38 |
39 | ## News
40 |
41 | - **`Pinned:`** The [leaderboard](https://opendrivelab.com/AD23Challenge.html#openlane_topology) for the Lane Topology Challenge is open for regular submissions all year round. This Challenge **`will`** be back in the 2024 edition.
42 | - **`[2023/11]`** :fire:The code and model of OpenLane-V2 subset-B is released!
43 | - **`[2023/08]`** The code and model of TopoNet is released!
44 | - **`[2023/04]`** TopoNet [paper](https://arxiv.org/abs/2304.05277) is available on arXiv.
45 | - **`[2023/01]`** Introducing [Autonomous Driving Challenge](https://opendrivelab.com/AD23Challenge.html) for Lane Topology at CVPR 2023.
46 |
47 |
48 | ## Main Results
49 |
50 | ### Results on OpenLane-V2 subset-A val
51 |
52 | We provide results on the **[OpenLane-V2](https://github.com/OpenDriveLab/OpenLane-V2) subset-A val** set.
53 |
54 | | Method | Backbone | Epoch | DETl | TOPll | DETt | TOPlt | OLS |
55 | | :----------: | :-------: | :---: | :-------------: | :--------------: | :-------------: | :--------------: | :------: |
56 | | STSU | ResNet-50 | 24 | 12.7 | 0.5 | 43.0 | 15.1 | 25.4 |
57 | | VectorMapNet | ResNet-50 | 24 | 11.1 | 0.4 | 41.7 | 6.2 | 20.8 |
58 | | MapTR | ResNet-50 | 24 | 8.3 | 0.2 | 43.5 | 5.8 | 20.0 |
59 | | MapTR* | ResNet-50 | 24 | 17.7 | 1.1 | 43.5 | 10.4 | 26.0 |
60 | | **TopoNet** | ResNet-50 | 24 | **28.6** | **4.1** | **48.6** | **20.3** | **35.6** |
61 |
62 | :fire:: Based on the updated `v1.1` OpenLane-V2 devkit and metrics, we have reassessed the performance of TopoNet and other SOTA models. For more details, please see issue [#76](https://github.com/OpenDriveLab/OpenLane-V2/issues/76) of OpenLane-V2.
63 |
64 | | Method | Backbone | Epoch | DETl | TOPll | DETt | TOPlt | OLS |
65 | | :----------: | :-------: | :---: | :-------------: | :--------------: | :-------------: | :--------------: | :------: |
66 | | STSU | ResNet-50 | 24 | 12.7 | 2.9 | 43.0 | 19.8 | 29.3 |
67 | | VectorMapNet | ResNet-50 | 24 | 11.1 | 2.7 | 41.7 | 9.2 | 24.9 |
68 | | MapTR | ResNet-50 | 24 | 8.3 | 2.3 | 43.5 | 8.9 | 24.2 |
69 | | MapTR* | ResNet-50 | 24 | 17.7 | 5.9 | 43.5 | 15.1 | 31.0 |
70 | | **TopoNet** | ResNet-50 | 24 | **28.6** | **10.9** | **48.6** | **23.8** | **39.8** |
71 |
72 | > *: evaluation based on matching results using Chamfer distance.
73 | > The result of TopoNet is from this repo.
74 |
75 |
76 | ### Results on OpenLane-V2 subset-B val
77 |
78 | | Method | Backbone | Epoch | DETl | TOPll | DETt | TOPlt | OLS |
79 | | :----------: | :-------: | :---: | :-------------: | :--------------: | :-------------: | :--------------: | :------: |
80 | | **TopoNet** | ResNet-50 | 24 | **24.4** | **6.7** | **52.6** | **16.7** | **36.0** |
81 |
82 | > The result is based on the updated `v1.1` OpenLane-V2 devkit and metrics.
83 | > The result of TopoNet is from this repo.
84 |
85 | ## Model Zoo
86 |
87 | | Model | Dataset | Backbone | Epoch | OLS | Memory | Config | Download |
88 | | :---: | :-----: | :------: | :---: | :---: | :----: | :----: | :------: |
89 | | TopoNet-R50 | subset-A | ResNet-50 | 24 | 39.8 | 12.3G | [config](projects/configs/toponet_r50_8x1_24e_olv2_subset_A.py) | [ckpt](https://huggingface.co/OpenDriveLab/toponet_r50_8x1_24e_olv2_subset_A/resolve/main/toponet_r50_8x1_24e_olv2_subset_A.pth) / [log](https://huggingface.co/OpenDriveLab/toponet_r50_8x1_24e_olv2_subset_A/resolve/main/20231017_113808.log) |
90 | | TopoNet-R50 | subset-B | ResNet-50 | 24 | 36.0 | 8.2G | [config](projects/configs/toponet_r50_8x1_24e_olv2_subset_B.py) | [ckpt](https://huggingface.co/OpenDriveLab/toponet_r50_8x1_24e_olv2_subset_B/resolve/main/toponet_r50_8x1_24e_olv2_subset_B.pth) / [log](https://huggingface.co/OpenDriveLab/toponet_r50_8x1_24e_olv2_subset_B/resolve/main/20231127_121131.log) |
91 |
92 |
93 | ## Prerequisites
94 |
95 | - Linux
96 | - Python 3.8.x
97 | - NVIDIA GPU + CUDA 11.1
98 | - PyTorch 1.9.1
99 |
100 | ## Installation
101 |
102 | We recommend using [conda](https://docs.conda.io/en/latest/miniconda.html) to run the code.
103 | ```bash
104 | conda create -n toponet python=3.8 -y
105 | conda activate toponet
106 |
107 | # (optional) If you have CUDA installed on your computer, skip this step.
108 | conda install cudatoolkit=11.1.1 -c conda-forge
109 |
110 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
111 | ```
112 |
113 | Install mm-series packages.
114 | ```bash
115 | pip install mmcv-full==1.5.2 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
116 | pip install mmdet==2.26.0
117 | pip install mmsegmentation==0.29.1
118 | pip install mmdet3d==1.0.0rc6
119 | ```
120 |
121 | Install other required packages.
122 | ```bash
123 | pip install -r requirements.txt
124 | ```
125 |
126 | ## Prepare Dataset
127 |
128 | Follow the [OpenLane-V2 repo](https://github.com/OpenDriveLab/OpenLane-V2/blob/v1.0.0/data) to download the data and run the [preprocessing](https://github.com/OpenDriveLab/OpenLane-V2/tree/v1.0.0/data#preprocess) code.
129 |
130 | ```bash
131 | cd TopoNet
132 | mkdir data && cd data
133 |
134 | ln -s {PATH to OpenLane-V2 repo}/data/OpenLane-V2
135 | ```
136 |
137 | After setup, the hierarchy of the `data` folder is as follows:
138 | ```
139 | data/OpenLane-V2
140 | ├── train
141 | | └── ...
142 | ├── val
143 | | └── ...
144 | ├── test
145 | | └── ...
146 | ├── data_dict_subset_A_train.pkl
147 | ├── data_dict_subset_A_val.pkl
148 | ├── data_dict_subset_B_train.pkl
149 | ├── data_dict_subset_B_val.pkl
150 | ├── ...
151 | ```
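As a quick, optional sanity check (an illustrative Python snippet, not part of the official preprocessing), you can verify that the annotation files listed above are visible through the symlink:

```python
# Illustrative check that the expected OpenLane-V2 annotation pickles are in place.
from pathlib import Path

for name in ['data_dict_subset_A_train.pkl', 'data_dict_subset_A_val.pkl',
             'data_dict_subset_B_train.pkl', 'data_dict_subset_B_val.pkl']:
    path = Path('data/OpenLane-V2') / name
    print(f'{path}: {"found" if path.exists() else "MISSING"}')
```
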
152 |
153 | ## Train and Evaluate
154 |
155 | ### Train
156 |
157 | We recommend using 8 GPUs for training. If you train with a different number of GPUs, you can enable the `--autoscale-lr` option, which scales the learning rate linearly with the number of GPUs to keep performance comparable. The training logs will be saved to `work_dirs/toponet`.
158 |
159 | ```bash
160 | cd TopoNet
161 | mkdir -p work_dirs/toponet
162 |
163 | ./tools/dist_train.sh 8 [--autoscale-lr]
164 | ```
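For reference, `--autoscale-lr` follows the common linear scaling rule used by mmdet-style training scripts (a hedged sketch; the base learning rate below is illustrative, not necessarily the value in the TopoNet configs):

```python
# Linear LR scaling sketch: the config's base LR is assumed to be tuned for 8 GPUs.
base_lr = 2e-4        # illustrative base learning rate
base_num_gpus = 8     # the configs are named *_8x1_*, i.e. 8 GPUs x 1 sample per GPU
num_gpus = 4          # e.g. training on 4 GPUs instead

scaled_lr = base_lr * num_gpus / base_num_gpus
print(f'scaled lr: {scaled_lr:.1e}')  # scaled lr: 1.0e-04
```
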
165 |
166 | ### Evaluate
167 | You can set `--show` to visualize the results.
168 |
169 | ```bash
170 | ./tools/dist_test.sh 8 [--show]
171 | ```
172 |
173 | ## License
174 |
175 | All assets and code are under the [Apache 2.0 license](./LICENSE) unless specified otherwise.
176 |
177 | ## Citation
178 | If this work is helpful for your research, please consider citing the following BibTeX entries.
179 |
180 | ``` bibtex
181 | @article{li2023toponet,
182 | title={Graph-based Topology Reasoning for Driving Scenes},
183 | author={Li, Tianyu and Chen, Li and Wang, Huijie and Li, Yang and Yang, Jiazhi and Geng, Xiangwei and Jiang, Shengyin and Wang, Yuting and Xu, Hang and Xu, Chunjing and Yan, Junchi and Luo, Ping and Li, Hongyang},
184 | journal={arXiv preprint arXiv:2304.05277},
185 | year={2023}
186 | }
187 |
188 | @inproceedings{wang2023openlanev2,
189 | title={OpenLane-V2: A Topology Reasoning Benchmark for Unified 3D HD Mapping},
190 | author={Wang, Huijie and Li, Tianyu and Li, Yang and Chen, Li and Sima, Chonghao and Liu, Zhenbo and Wang, Bangjun and Jia, Peijin and Wang, Yuting and Jiang, Shengyin and Wen, Feng and Xu, Hang and Luo, Ping and Yan, Junchi and Zhang, Wei and Li, Hongyang},
191 | booktitle={NeurIPS},
192 | year={2023}
193 | }
194 | ```
195 |
196 | ## Related resources
197 |
198 | We acknowledge all the open-source contributors of the following projects, which make this work possible:
199 |
200 | - [OpenLane-V2](https://github.com/OpenDriveLab/OpenLane-V2)
201 | - [BEVFormer](https://github.com/fundamentalvision/BEVFormer)
202 |
--------------------------------------------------------------------------------
/figs/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/TopoNet/284442bd36e8c0af5dbbf4285ebb79897c1ccd82/figs/pipeline.png
--------------------------------------------------------------------------------
/projects/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/TopoNet/284442bd36e8c0af5dbbf4285ebb79897c1ccd82/projects/__init__.py
--------------------------------------------------------------------------------
/projects/bevformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 |
--------------------------------------------------------------------------------
/projects/bevformer/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .spatial_cross_attention import SpatialCrossAttention, MSDeformableAttention3D
2 | from .temporal_self_attention import TemporalSelfAttention
3 | from .encoder import BEVFormerEncoder, BEVFormerLayer
4 | from .decoder import DetectionTransformerDecoder
5 |
--------------------------------------------------------------------------------
/projects/bevformer/modules/custom_base_transformer_layer.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 |
7 | import copy
8 | import warnings
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from mmcv import ConfigDict, deprecated_api_warning
14 | from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
15 | from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
16 |
17 | from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
18 | TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
19 |
20 | # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
21 | try:
22 | from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
23 | warnings.warn(
24 | ImportWarning(
25 | '``MultiScaleDeformableAttention`` has been moved to '
26 | '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
27 | '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
28 | 'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
29 | ))
30 | except ImportError:
31 | warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
32 | '``mmcv.ops.multi_scale_deform_attn``, '
33 | 'You should install ``mmcv-full`` if you need this module. ')
34 | from mmcv.cnn.bricks.transformer import build_feedforward_network, build_attention
35 |
36 |
37 | @TRANSFORMER_LAYER.register_module()
38 | class MyCustomBaseTransformerLayer(BaseModule):
39 | """Base `TransformerLayer` for vision transformer.
40 | It can be built from `mmcv.ConfigDict` and support more flexible
41 | customization, for example, using any number of `FFN or LN ` and
42 | use different kinds of `attention` by specifying a list of `ConfigDict`
43 | named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
44 |     when you specify `norm` as the first element of `operation_order`.
45 |     More details about `prenorm`: `On Layer Normalization in the
46 |     Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
47 | Args:
48 | attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
49 | Configs for `self_attention` or `cross_attention` modules,
50 | The order of the configs in the list should be consistent with
51 | corresponding attentions in operation_order.
52 | If it is a dict, all of the attention modules in operation_order
53 | will be built with this config. Default: None.
54 | ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
55 | Configs for FFN, The order of the configs in the list should be
56 | consistent with corresponding ffn in operation_order.
57 | If it is a dict, all of the attention modules in operation_order
58 | will be built with this config.
59 | operation_order (tuple[str]): The execution order of operation
60 | in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
61 |             Supports `prenorm` when you specify the first element as `norm`.
62 | Default:None.
63 | norm_cfg (dict): Config dict for normalization layer.
64 | Default: dict(type='LN').
65 | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
66 | Default: None.
67 |         batch_first (bool): Whether Key, Query and Value are of shape
68 |             (batch, n, embed_dim) rather than
69 |             (n, batch, embed_dim). Defaults to True.
70 | """
71 |
72 | def __init__(self,
73 | attn_cfgs=None,
74 | ffn_cfgs=dict(
75 | type='FFN',
76 | embed_dims=256,
77 | feedforward_channels=1024,
78 | num_fcs=2,
79 | ffn_drop=0.,
80 | act_cfg=dict(type='ReLU', inplace=True),
81 | ),
82 | operation_order=None,
83 | norm_cfg=dict(type='LN'),
84 | init_cfg=None,
85 | batch_first=True,
86 | **kwargs):
87 |
88 | deprecated_args = dict(
89 | feedforward_channels='feedforward_channels',
90 | ffn_dropout='ffn_drop',
91 | ffn_num_fcs='num_fcs')
92 | for ori_name, new_name in deprecated_args.items():
93 | if ori_name in kwargs:
94 | warnings.warn(
95 |                     f'The argument `{ori_name}` in BaseTransformerLayer '
96 |                     f'has been deprecated; now you should set `{new_name}` '
97 | f'and other FFN related arguments '
98 | f'to a dict named `ffn_cfgs`. ')
99 | ffn_cfgs[new_name] = kwargs[ori_name]
100 |
101 | super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)
102 |
103 | self.batch_first = batch_first
104 |
105 | assert set(operation_order) & set(
106 | ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
107 | set(operation_order), f'The operation_order of' \
108 |             f' {self.__class__.__name__} should only ' \
109 |             f'contain operations among ' \
110 | f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
111 |
112 | num_attn = operation_order.count('self_attn') + operation_order.count(
113 | 'cross_attn')
114 | if isinstance(attn_cfgs, dict):
115 | attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
116 | else:
117 | assert num_attn == len(attn_cfgs), f'The length ' \
118 |                 f'of attn_cfgs ({len(attn_cfgs)}) is ' \
119 |                 f'not consistent with the number of attentions ' \
120 |                 f'({num_attn}) in operation_order {operation_order}.'
121 |
122 | self.num_attn = num_attn
123 | self.operation_order = operation_order
124 | self.norm_cfg = norm_cfg
125 | self.pre_norm = operation_order[0] == 'norm'
126 | self.attentions = ModuleList()
127 |
128 | index = 0
129 | for operation_name in operation_order:
130 | if operation_name in ['self_attn', 'cross_attn']:
131 | if 'batch_first' in attn_cfgs[index]:
132 | assert self.batch_first == attn_cfgs[index]['batch_first']
133 | else:
134 | attn_cfgs[index]['batch_first'] = self.batch_first
135 | attention = build_attention(attn_cfgs[index])
136 | # Some custom attentions used as `self_attn`
137 | # or `cross_attn` can have different behavior.
138 | attention.operation_name = operation_name
139 | self.attentions.append(attention)
140 | index += 1
141 |
142 | self.embed_dims = self.attentions[0].embed_dims
143 |
144 | self.ffns = ModuleList()
145 | num_ffns = operation_order.count('ffn')
146 | if isinstance(ffn_cfgs, dict):
147 | ffn_cfgs = ConfigDict(ffn_cfgs)
148 | if isinstance(ffn_cfgs, dict):
149 | ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
150 | assert len(ffn_cfgs) == num_ffns
151 | for ffn_index in range(num_ffns):
152 |             if 'embed_dims' not in ffn_cfgs[ffn_index]:
153 |                 ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
154 | else:
155 | assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
156 |
157 | self.ffns.append(
158 | build_feedforward_network(ffn_cfgs[ffn_index]))
159 |
160 | self.norms = ModuleList()
161 | num_norms = operation_order.count('norm')
162 | for _ in range(num_norms):
163 | self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
164 |
165 | def forward(self,
166 | query,
167 | key=None,
168 | value=None,
169 | query_pos=None,
170 | key_pos=None,
171 | attn_masks=None,
172 | query_key_padding_mask=None,
173 | key_padding_mask=None,
174 | **kwargs):
175 | """Forward function for `TransformerDecoderLayer`.
176 | **kwargs contains some specific arguments of attentions.
177 | Args:
178 | query (Tensor): The input query with shape
179 | [num_queries, bs, embed_dims] if
180 | self.batch_first is False, else
181 | [bs, num_queries embed_dims].
182 | key (Tensor): The key tensor with shape [num_keys, bs,
183 | embed_dims] if self.batch_first is False, else
184 | [bs, num_keys, embed_dims] .
185 | value (Tensor): The value tensor with same shape as `key`.
186 | query_pos (Tensor): The positional encoding for `query`.
187 | Default: None.
188 | key_pos (Tensor): The positional encoding for `key`.
189 | Default: None.
190 | attn_masks (List[Tensor] | None): 2D Tensor used in
191 | calculation of corresponding attention. The length of
192 | it should equal to the number of `attention` in
193 | `operation_order`. Default: None.
194 | query_key_padding_mask (Tensor): ByteTensor for `query`, with
195 | shape [bs, num_queries]. Only used in `self_attn` layer.
196 | Defaults to None.
197 |             key_padding_mask (Tensor): ByteTensor for `key`, with
198 |                 shape [bs, num_keys]. Default: None.
199 | Returns:
200 | Tensor: forwarded results with shape [num_queries, bs, embed_dims].
201 | """
202 |
203 | norm_index = 0
204 | attn_index = 0
205 | ffn_index = 0
206 | identity = query
207 | if attn_masks is None:
208 | attn_masks = [None for _ in range(self.num_attn)]
209 | elif isinstance(attn_masks, torch.Tensor):
210 | attn_masks = [
211 | copy.deepcopy(attn_masks) for _ in range(self.num_attn)
212 | ]
213 | warnings.warn(f'Use same attn_mask in all attentions in '
214 | f'{self.__class__.__name__} ')
215 | else:
216 | assert len(attn_masks) == self.num_attn, f'The length of ' \
217 | f'attn_masks {len(attn_masks)} must be equal ' \
218 | f'to the number of attention in ' \
219 | f'operation_order {self.num_attn}'
220 |
221 | for layer in self.operation_order:
222 | if layer == 'self_attn':
223 | temp_key = temp_value = query
224 | query = self.attentions[attn_index](
225 | query,
226 | temp_key,
227 | temp_value,
228 | identity if self.pre_norm else None,
229 | query_pos=query_pos,
230 | key_pos=query_pos,
231 | attn_mask=attn_masks[attn_index],
232 | key_padding_mask=query_key_padding_mask,
233 | **kwargs)
234 | attn_index += 1
235 | identity = query
236 |
237 | elif layer == 'norm':
238 | query = self.norms[norm_index](query)
239 | norm_index += 1
240 |
241 | elif layer == 'cross_attn':
242 | query = self.attentions[attn_index](
243 | query,
244 | key,
245 | value,
246 | identity if self.pre_norm else None,
247 | query_pos=query_pos,
248 | key_pos=key_pos,
249 | attn_mask=attn_masks[attn_index],
250 | key_padding_mask=key_padding_mask,
251 | **kwargs)
252 | attn_index += 1
253 | identity = query
254 |
255 | elif layer == 'ffn':
256 | query = self.ffns[ffn_index](
257 | query, identity if self.pre_norm else None)
258 | ffn_index += 1
259 |
260 | return query
261 |
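# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A minimal example, assuming mmcv-full is installed, of building a layer like
# MyCustomBaseTransformerLayer from a config dict via the TRANSFORMER_LAYER registry.
# The attention/FFN settings below are hypothetical, not the TopoNet configuration.
if __name__ == '__main__':
    from mmcv.cnn.bricks.transformer import build_transformer_layer

    example_layer_cfg = dict(
        type='MyCustomBaseTransformerLayer',
        attn_cfgs=[
            dict(type='MultiheadAttention', embed_dims=256, num_heads=8),  # self_attn
            dict(type='MultiheadAttention', embed_dims=256, num_heads=8),  # cross_attn
        ],
        ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'))
    example_layer = build_transformer_layer(example_layer_cfg)
    print(example_layer)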
--------------------------------------------------------------------------------
/projects/bevformer/modules/decoder.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 |
7 | from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
8 | import mmcv
9 | import cv2 as cv
10 | import copy
11 | import warnings
12 | from matplotlib import pyplot as plt
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from mmcv.cnn import xavier_init, constant_init
18 | from mmcv.cnn.bricks.registry import (ATTENTION,
19 | TRANSFORMER_LAYER_SEQUENCE)
20 | from mmcv.cnn.bricks.transformer import TransformerLayerSequence
21 | import math
22 | from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
23 | from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
24 | to_2tuple)
25 |
26 | from mmcv.utils import ext_loader
27 | from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32, \
28 | MultiScaleDeformableAttnFunction_fp16
29 |
30 | ext_module = ext_loader.load_ext(
31 | '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
32 |
33 |
34 | def inverse_sigmoid(x, eps=1e-5):
35 | """Inverse function of sigmoid.
36 | Args:
37 | x (Tensor): The tensor to do the
38 | inverse.
39 |         eps (float): EPS to avoid numerical
40 |             overflow. Defaults to 1e-5.
41 |     Returns:
42 |         Tensor: The inverse sigmoid of x,
43 |             with the same shape as the
44 |             input.
45 | """
46 | x = x.clamp(min=0, max=1)
47 | x1 = x.clamp(min=eps)
48 | x2 = (1 - x).clamp(min=eps)
49 | return torch.log(x1 / x2)
50 |
51 |
52 | @TRANSFORMER_LAYER_SEQUENCE.register_module()
53 | class DetectionTransformerDecoder(TransformerLayerSequence):
54 | """Implements the decoder in DETR3D transformer.
55 | Args:
56 | return_intermediate (bool): Whether to return intermediate outputs.
57 | coder_norm_cfg (dict): Config of last normalization layer. Default:
58 | `LN`.
59 | """
60 |
61 | def __init__(self, *args, return_intermediate=False, **kwargs):
62 | super(DetectionTransformerDecoder, self).__init__(*args, **kwargs)
63 | self.return_intermediate = return_intermediate
64 | self.fp16_enabled = False
65 |
66 | def forward(self,
67 | query,
68 | *args,
69 | reference_points=None,
70 | reg_branches=None,
71 | key_padding_mask=None,
72 | **kwargs):
73 | """Forward function for `Detr3DTransformerDecoder`.
74 | Args:
75 | query (Tensor): Input query with shape
76 | `(num_query, bs, embed_dims)`.
77 | reference_points (Tensor): The reference
78 | points of offset. has shape
79 | (bs, num_query, 4) when as_two_stage,
80 | otherwise has shape ((bs, num_query, 2).
81 | reg_branch: (obj:`nn.ModuleList`): Used for
82 | refining the regression results. Only would
83 | be passed when with_box_refine is True,
84 | otherwise would be passed a `None`.
85 | Returns:
86 | Tensor: Results with shape [1, num_query, bs, embed_dims] when
87 | return_intermediate is `False`, otherwise it has shape
88 | [num_layers, num_query, bs, embed_dims].
89 | """
90 | output = query
91 | intermediate = []
92 | intermediate_reference_points = []
93 | for lid, layer in enumerate(self.layers):
94 |
95 | reference_points_input = reference_points[..., :2].unsqueeze(
96 | 2) # BS NUM_QUERY NUM_LEVEL 2
97 | output = layer(
98 | output,
99 | *args,
100 | reference_points=reference_points_input,
101 | key_padding_mask=key_padding_mask,
102 | **kwargs)
103 | output = output.permute(1, 0, 2)
104 |
105 | if reg_branches is not None:
106 | tmp = reg_branches[lid](output)
107 |
108 | assert reference_points.shape[-1] == 3
109 |
110 | new_reference_points = torch.zeros_like(reference_points)
111 | new_reference_points[..., :2] = tmp[
112 | ..., :2] + inverse_sigmoid(reference_points[..., :2])
113 | new_reference_points[..., 2:3] = tmp[
114 | ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
115 |
116 | new_reference_points = new_reference_points.sigmoid()
117 |
118 | reference_points = new_reference_points.detach()
119 |
120 | output = output.permute(1, 0, 2)
121 | if self.return_intermediate:
122 | intermediate.append(output)
123 | intermediate_reference_points.append(reference_points)
124 |
125 | if self.return_intermediate:
126 | return torch.stack(intermediate), torch.stack(
127 | intermediate_reference_points)
128 |
129 | return output, reference_points
130 |
131 |
132 | @ATTENTION.register_module()
133 | class CustomMSDeformableAttention(BaseModule):
134 | """An attention module used in Deformable-Detr.
135 |
136 |     `Deformable DETR: Deformable Transformers for End-to-End Object Detection
137 |     <https://arxiv.org/abs/2010.04159>`_.
138 |
139 | Args:
140 | embed_dims (int): The embedding dimension of Attention.
141 | Default: 256.
142 |         num_heads (int): Parallel attention heads. Default: 8.
143 | num_levels (int): The number of feature map used in
144 | Attention. Default: 4.
145 | num_points (int): The number of sampling points for
146 | each query in each head. Default: 4.
147 | im2col_step (int): The step used in image_to_column.
148 | Default: 64.
149 | dropout (float): A Dropout layer on `inp_identity`.
150 | Default: 0.1.
151 | batch_first (bool): Key, Query and Value are shape of
152 | (batch, n, embed_dim)
153 | or (n, batch, embed_dim). Default to False.
154 | norm_cfg (dict): Config dict for normalization layer.
155 | Default: None.
156 | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
157 | Default: None.
158 | """
159 |
160 | def __init__(self,
161 | embed_dims=256,
162 | num_heads=8,
163 | num_levels=4,
164 | num_points=4,
165 | im2col_step=64,
166 | dropout=0.1,
167 | batch_first=False,
168 | norm_cfg=None,
169 | init_cfg=None):
170 | super().__init__(init_cfg)
171 | if embed_dims % num_heads != 0:
172 | raise ValueError(f'embed_dims must be divisible by num_heads, '
173 | f'but got {embed_dims} and {num_heads}')
174 | dim_per_head = embed_dims // num_heads
175 | self.norm_cfg = norm_cfg
176 | self.dropout = nn.Dropout(dropout)
177 | self.batch_first = batch_first
178 | self.fp16_enabled = False
179 |
180 | # you'd better set dim_per_head to a power of 2
181 | # which is more efficient in the CUDA implementation
182 | def _is_power_of_2(n):
183 | if (not isinstance(n, int)) or (n < 0):
184 | raise ValueError(
185 | 'invalid input for _is_power_of_2: {} (type: {})'.format(
186 | n, type(n)))
187 | return (n & (n - 1) == 0) and n != 0
188 |
189 | if not _is_power_of_2(dim_per_head):
190 | warnings.warn(
191 | "You'd better set embed_dims in "
192 | 'MultiScaleDeformAttention to make '
193 | 'the dimension of each attention head a power of 2 '
194 | 'which is more efficient in our CUDA implementation.')
195 |
196 | self.im2col_step = im2col_step
197 | self.embed_dims = embed_dims
198 | self.num_levels = num_levels
199 | self.num_heads = num_heads
200 | self.num_points = num_points
201 | self.sampling_offsets = nn.Linear(
202 | embed_dims, num_heads * num_levels * num_points * 2)
203 | self.attention_weights = nn.Linear(embed_dims,
204 | num_heads * num_levels * num_points)
205 | self.value_proj = nn.Linear(embed_dims, embed_dims)
206 | self.output_proj = nn.Linear(embed_dims, embed_dims)
207 | self.init_weights()
208 |
209 | def init_weights(self):
210 | """Default initialization for Parameters of Module."""
211 | constant_init(self.sampling_offsets, 0.)
212 | thetas = torch.arange(
213 | self.num_heads,
214 | dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
215 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
216 | grid_init = (grid_init /
217 | grid_init.abs().max(-1, keepdim=True)[0]).view(
218 | self.num_heads, 1, 1,
219 | 2).repeat(1, self.num_levels, self.num_points, 1)
220 | for i in range(self.num_points):
221 | grid_init[:, :, i, :] *= i + 1
222 |
223 | self.sampling_offsets.bias.data = grid_init.view(-1)
224 | constant_init(self.attention_weights, val=0., bias=0.)
225 | xavier_init(self.value_proj, distribution='uniform', bias=0.)
226 | xavier_init(self.output_proj, distribution='uniform', bias=0.)
227 | self._is_init = True
228 |
229 | @deprecated_api_warning({'residual': 'identity'},
230 | cls_name='MultiScaleDeformableAttention')
231 | def forward(self,
232 | query,
233 | key=None,
234 | value=None,
235 | identity=None,
236 | query_pos=None,
237 | key_padding_mask=None,
238 | reference_points=None,
239 | spatial_shapes=None,
240 | level_start_index=None,
241 | flag='decoder',
242 | **kwargs):
243 | """Forward Function of MultiScaleDeformAttention.
244 |
245 | Args:
246 | query (Tensor): Query of Transformer with shape
247 | (num_query, bs, embed_dims).
248 | key (Tensor): The key tensor with shape
249 | `(num_key, bs, embed_dims)`.
250 | value (Tensor): The value tensor with shape
251 | `(num_key, bs, embed_dims)`.
252 | identity (Tensor): The tensor used for addition, with the
253 | same shape as `query`. Default None. If None,
254 | `query` will be used.
255 | query_pos (Tensor): The positional encoding for `query`.
256 | Default: None.
257 | key_pos (Tensor): The positional encoding for `key`. Default
258 | None.
259 |             reference_points (Tensor): The normalized reference
260 |                 points with shape (bs, num_query, num_levels, 2),
261 |                 all elements in range [0, 1], top-left (0, 0),
262 |                 bottom-right (1, 1), including padding area;
263 |                 or (N, Length_{query}, num_levels, 4), where the
264 |                 additional two dimensions (w, h) form
265 |                 reference boxes.
266 | key_padding_mask (Tensor): ByteTensor for `query`, with
267 | shape [bs, num_key].
268 | spatial_shapes (Tensor): Spatial shape of features in
269 | different levels. With shape (num_levels, 2),
270 | last dimension represents (h, w).
271 | level_start_index (Tensor): The start index of each level.
272 | A tensor has shape ``(num_levels, )`` and can be represented
273 | as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
274 |
275 | Returns:
276 | Tensor: forwarded results with shape [num_query, bs, embed_dims].
277 | """
278 |
279 | if value is None:
280 | value = query
281 |
282 | if identity is None:
283 | identity = query
284 | if query_pos is not None:
285 | query = query + query_pos
286 | if not self.batch_first:
287 | # change to (bs, num_query ,embed_dims)
288 | query = query.permute(1, 0, 2)
289 | value = value.permute(1, 0, 2)
290 |
291 | bs, num_query, _ = query.shape
292 | bs, num_value, _ = value.shape
293 | assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
294 |
295 | value = self.value_proj(value)
296 | if key_padding_mask is not None:
297 | value = value.masked_fill(key_padding_mask[..., None], 0.0)
298 | value = value.view(bs, num_value, self.num_heads, -1)
299 |
300 | sampling_offsets = self.sampling_offsets(query).view(
301 | bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
302 | attention_weights = self.attention_weights(query).view(
303 | bs, num_query, self.num_heads, self.num_levels * self.num_points)
304 | attention_weights = attention_weights.softmax(-1)
305 |
306 | attention_weights = attention_weights.view(bs, num_query,
307 | self.num_heads,
308 | self.num_levels,
309 | self.num_points)
310 | if reference_points.shape[-1] == 2:
311 | offset_normalizer = torch.stack(
312 | [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
313 | sampling_locations = reference_points[:, :, None, :, None, :] \
314 | + sampling_offsets \
315 | / offset_normalizer[None, None, None, :, None, :]
316 | elif reference_points.shape[-1] == 4:
317 | sampling_locations = reference_points[:, :, None, :, None, :2] \
318 | + sampling_offsets / self.num_points \
319 | * reference_points[:, :, None, :, None, 2:] \
320 | * 0.5
321 | else:
322 | raise ValueError(
323 | f'Last dim of reference_points must be'
324 | f' 2 or 4, but get {reference_points.shape[-1]} instead.')
325 | if torch.cuda.is_available() and value.is_cuda:
326 |
327 | # using fp16 deformable attention is unstable because it performs many sum operations
328 | if value.dtype == torch.float16:
329 | MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
330 | else:
331 | MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
332 | output = MultiScaleDeformableAttnFunction.apply(
333 | value, spatial_shapes, level_start_index, sampling_locations,
334 | attention_weights, self.im2col_step)
335 | else:
336 | output = multi_scale_deformable_attn_pytorch(
337 | value, spatial_shapes, sampling_locations, attention_weights)
338 |
339 | output = self.output_proj(output)
340 |
341 | if not self.batch_first:
342 | # (num_query, bs ,embed_dims)
343 | output = output.permute(1, 0, 2)
344 |
345 | return self.dropout(output) + identity
346 |
--------------------------------------------------------------------------------
/projects/bevformer/modules/encoder.py:
--------------------------------------------------------------------------------
1 |
2 | # ---------------------------------------------
3 | # Copyright (c) OpenMMLab. All rights reserved.
4 | # ---------------------------------------------
5 | # Modified by Zhiqi Li
6 | # ---------------------------------------------
7 |
8 | from .custom_base_transformer_layer import MyCustomBaseTransformerLayer
9 | import copy
10 | import warnings
11 | from mmcv.cnn.bricks.registry import (ATTENTION,
12 | TRANSFORMER_LAYER,
13 | TRANSFORMER_LAYER_SEQUENCE)
14 | from mmcv.cnn.bricks.transformer import TransformerLayerSequence
15 | from mmcv.runner import force_fp32, auto_fp16
16 | import numpy as np
17 | import torch
18 | import cv2 as cv
19 | import mmcv
20 | from mmcv.utils import TORCH_VERSION, digit_version
21 | from mmcv.utils import ext_loader
22 | ext_module = ext_loader.load_ext(
23 | '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
24 |
25 |
26 | @TRANSFORMER_LAYER_SEQUENCE.register_module()
27 | class BEVFormerEncoder(TransformerLayerSequence):
28 |
29 | """
30 |     Attention with both self and cross attention.
31 |     Implements the encoder in BEVFormer.
32 | Args:
33 | return_intermediate (bool): Whether to return intermediate outputs.
34 | coder_norm_cfg (dict): Config of last normalization layer. Default:
35 | `LN`.
36 | """
37 |
38 | def __init__(self, *args, pc_range=None, num_points_in_pillar=4, return_intermediate=False, dataset_type='nuscenes',
39 | **kwargs):
40 |
41 | super(BEVFormerEncoder, self).__init__(*args, **kwargs)
42 | self.return_intermediate = return_intermediate
43 |
44 | self.num_points_in_pillar = num_points_in_pillar
45 | self.pc_range = pc_range
46 | self.fp16_enabled = False
47 |
48 | @staticmethod
49 | def get_reference_points(H, W, Z=8, num_points_in_pillar=4, dim='3d', bs=1, device='cuda', dtype=torch.float):
50 | """Get the reference points used in SCA and TSA.
51 | Args:
52 | H, W: spatial shape of bev.
53 |             Z: height of pillar.
54 | D: sample D points uniformly from each pillar.
55 | device (obj:`device`): The device where
56 | reference_points should be.
57 | Returns:
58 | Tensor: reference points used in decoder, has \
59 | shape (bs, num_keys, num_levels, 2).
60 | """
61 |
62 | # reference points in 3D space, used in spatial cross-attention (SCA)
63 | if dim == '3d':
64 | zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
65 | device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
66 | xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
67 | device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
68 | ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
69 | device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
70 | ref_3d = torch.stack((xs, ys, zs), -1)
71 | ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
72 | ref_3d = ref_3d[None].repeat(bs, 1, 1, 1)
73 | return ref_3d
74 |
75 | # reference points on 2D bev plane, used in temporal self-attention (TSA).
76 | elif dim == '2d':
77 | ref_y, ref_x = torch.meshgrid(
78 | torch.linspace(
79 | 0.5, H - 0.5, H, dtype=dtype, device=device),
80 | torch.linspace(
81 | 0.5, W - 0.5, W, dtype=dtype, device=device)
82 | )
83 | ref_y = ref_y.reshape(-1)[None] / H
84 | ref_x = ref_x.reshape(-1)[None] / W
85 | ref_2d = torch.stack((ref_x, ref_y), -1)
86 | ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
87 | return ref_2d
88 |
89 | # This function must use fp32!!!
90 | @force_fp32(apply_to=('reference_points', 'img_metas'))
91 | def point_sampling(self, reference_points, pc_range, img_metas):
92 |
93 | lidar2img = []
94 | for img_meta in img_metas:
95 | lidar2img.append(img_meta['lidar2img'])
96 | lidar2img = np.asarray(lidar2img)
97 | lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
98 | reference_points = reference_points.clone()
99 |
100 | reference_points[..., 0:1] = reference_points[..., 0:1] * \
101 | (pc_range[3] - pc_range[0]) + pc_range[0]
102 | reference_points[..., 1:2] = reference_points[..., 1:2] * \
103 | (pc_range[4] - pc_range[1]) + pc_range[1]
104 | reference_points[..., 2:3] = reference_points[..., 2:3] * \
105 | (pc_range[5] - pc_range[2]) + pc_range[2]
106 |
107 | reference_points = torch.cat(
108 | (reference_points, torch.ones_like(reference_points[..., :1])), -1)
109 |
110 | reference_points = reference_points.permute(1, 0, 2, 3)
111 | D, B, num_query = reference_points.size()[:3]
112 | num_cam = lidar2img.size(1)
113 |
114 | reference_points = reference_points.view(
115 | D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
116 |
117 | lidar2img = lidar2img.view(
118 | 1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
119 |
120 | reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
121 | reference_points.to(torch.float32)).squeeze(-1)
122 | eps = 1e-5
123 |
124 | bev_mask = (reference_points_cam[..., 2:3] > eps)
125 | reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
126 | reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
127 |
128 | reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
129 | reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
130 |
131 | bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
132 | & (reference_points_cam[..., 1:2] < 1.0)
133 | & (reference_points_cam[..., 0:1] < 1.0)
134 | & (reference_points_cam[..., 0:1] > 0.0))
135 | if digit_version(TORCH_VERSION) >= digit_version('1.8'):
136 | bev_mask = torch.nan_to_num(bev_mask)
137 | else:
138 | bev_mask = bev_mask.new_tensor(
139 | np.nan_to_num(bev_mask.cpu().numpy()))
140 |
141 | reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
142 | bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
143 |
144 | return reference_points_cam, bev_mask
145 |
146 | @auto_fp16()
147 | def forward(self,
148 | bev_query,
149 | key,
150 | value,
151 | *args,
152 | bev_h=None,
153 | bev_w=None,
154 | bev_pos=None,
155 | spatial_shapes=None,
156 | level_start_index=None,
157 | valid_ratios=None,
158 | prev_bev=None,
159 | shift=0.,
160 | **kwargs):
161 |         """Forward function for `BEVFormerEncoder`.
162 | Args:
163 | bev_query (Tensor): Input BEV query with shape
164 | `(num_query, bs, embed_dims)`.
165 | key & value (Tensor): Input multi-cameta features with shape
166 | (num_cam, num_value, bs, embed_dims)
167 | reference_points (Tensor): The reference
168 | points of offset. has shape
169 | (bs, num_query, 4) when as_two_stage,
170 |                 otherwise has shape (bs, num_query, 2).
171 |             valid_ratios (Tensor): The ratios of valid
172 | points on the feature map, has shape
173 | (bs, num_levels, 2)
174 | Returns:
175 | Tensor: Results with shape [1, num_query, bs, embed_dims] when
176 | return_intermediate is `False`, otherwise it has shape
177 | [num_layers, num_query, bs, embed_dims].
178 | """
179 |
180 | output = bev_query
181 | intermediate = []
182 |
183 | ref_3d = self.get_reference_points(
184 | bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
185 | ref_2d = self.get_reference_points(
186 | bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
187 |
188 | reference_points_cam, bev_mask = self.point_sampling(
189 | ref_3d, self.pc_range, kwargs['img_metas'])
190 |
191 |         # clone before applying the shift so the original ref_2d stays unmodified
192 | shift_ref_2d = ref_2d.clone()
193 | shift_ref_2d += shift[:, None, None, :]
194 |
195 | # (num_query, bs, embed_dims) -> (bs, num_query, embed_dims)
196 | bev_query = bev_query.permute(1, 0, 2)
197 | bev_pos = bev_pos.permute(1, 0, 2)
198 | bs, len_bev, num_bev_level, _ = ref_2d.shape
199 | if prev_bev is not None:
200 | prev_bev = prev_bev.permute(1, 0, 2)
201 | prev_bev = torch.stack(
202 | [prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
203 |             hybrid_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
204 | bs*2, len_bev, num_bev_level, 2)
205 | else:
206 |             hybrid_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
207 | bs*2, len_bev, num_bev_level, 2)
208 |
209 | for lid, layer in enumerate(self.layers):
210 | output = layer(
211 | bev_query,
212 | key,
213 | value,
214 | *args,
215 | bev_pos=bev_pos,
216 |                 ref_2d=hybrid_ref_2d,
217 | ref_3d=ref_3d,
218 | bev_h=bev_h,
219 | bev_w=bev_w,
220 | spatial_shapes=spatial_shapes,
221 | level_start_index=level_start_index,
222 | reference_points_cam=reference_points_cam,
223 | bev_mask=bev_mask,
224 | prev_bev=prev_bev,
225 | **kwargs)
226 |
227 | bev_query = output
228 | if self.return_intermediate:
229 | intermediate.append(output)
230 |
231 | if self.return_intermediate:
232 | return torch.stack(intermediate)
233 |
234 | return output
235 |
236 |
237 | @TRANSFORMER_LAYER.register_module()
238 | class BEVFormerLayer(MyCustomBaseTransformerLayer):
239 |     """Implements one encoder layer in BEVFormer.
240 | Args:
241 | attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
242 | Configs for self_attention or cross_attention, the order
243 | should be consistent with it in `operation_order`. If it is
244 |             a dict, it would be expanded to the number of attentions in
245 | `operation_order`.
246 | feedforward_channels (int): The hidden dimension for FFNs.
247 | ffn_dropout (float): Probability of an element to be zeroed
248 | in ffn. Default 0.0.
249 | operation_order (tuple[str]): The execution order of operation
250 | in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
251 | Default:None
252 |         act_cfg (dict): The activation config for FFNs. Default: dict(type='ReLU', inplace=True).
253 | norm_cfg (dict): Config dict for normalization layer.
254 | Default: `LN`.
255 | ffn_num_fcs (int): The number of fully-connected layers in FFNs.
256 | Default:2.
257 | """
258 |
259 | def __init__(self,
260 | attn_cfgs,
261 | ffn_cfgs,
262 | operation_order=None,
263 | act_cfg=dict(type='ReLU', inplace=True),
264 | norm_cfg=dict(type='LN'),
265 | **kwargs):
266 | super(BEVFormerLayer, self).__init__(
267 | attn_cfgs=attn_cfgs,
268 | ffn_cfgs=ffn_cfgs,
269 | operation_order=operation_order,
270 | act_cfg=act_cfg,
271 | norm_cfg=norm_cfg,
272 | **kwargs)
273 | self.fp16_enabled = False
274 | assert len(operation_order) == 6
275 | assert set(operation_order) == set(
276 | ['self_attn', 'norm', 'cross_attn', 'ffn'])
277 |
278 | def forward(self,
279 | query,
280 | key=None,
281 | value=None,
282 | bev_pos=None,
283 | query_pos=None,
284 | key_pos=None,
285 | attn_masks=None,
286 | query_key_padding_mask=None,
287 | key_padding_mask=None,
288 | ref_2d=None,
289 | ref_3d=None,
290 | bev_h=None,
291 | bev_w=None,
292 | reference_points_cam=None,
293 | mask=None,
294 | spatial_shapes=None,
295 | level_start_index=None,
296 | prev_bev=None,
297 | **kwargs):
298 |         """Forward function for `BEVFormerLayer`.
299 |
300 | **kwargs contains some specific arguments of attentions.
301 |
302 | Args:
303 | query (Tensor): The input query with shape
304 | [num_queries, bs, embed_dims] if
305 | self.batch_first is False, else
306 |                 [bs, num_queries, embed_dims].
307 | key (Tensor): The key tensor with shape [num_keys, bs,
308 | embed_dims] if self.batch_first is False, else
309 | [bs, num_keys, embed_dims] .
310 | value (Tensor): The value tensor with same shape as `key`.
311 | query_pos (Tensor): The positional encoding for `query`.
312 | Default: None.
313 | key_pos (Tensor): The positional encoding for `key`.
314 | Default: None.
315 | attn_masks (List[Tensor] | None): 2D Tensor used in
316 | calculation of corresponding attention. The length of
317 | it should equal to the number of `attention` in
318 | `operation_order`. Default: None.
319 | query_key_padding_mask (Tensor): ByteTensor for `query`, with
320 | shape [bs, num_queries]. Only used in `self_attn` layer.
321 | Defaults to None.
322 |             key_padding_mask (Tensor): ByteTensor for `key`, with
323 | shape [bs, num_keys]. Default: None.
324 |
325 | Returns:
326 | Tensor: forwarded results with shape [num_queries, bs, embed_dims].
327 | """
328 |
329 | norm_index = 0
330 | attn_index = 0
331 | ffn_index = 0
332 | identity = query
333 | if attn_masks is None:
334 | attn_masks = [None for _ in range(self.num_attn)]
335 | elif isinstance(attn_masks, torch.Tensor):
336 | attn_masks = [
337 | copy.deepcopy(attn_masks) for _ in range(self.num_attn)
338 | ]
339 | warnings.warn(f'Use same attn_mask in all attentions in '
340 | f'{self.__class__.__name__} ')
341 | else:
342 | assert len(attn_masks) == self.num_attn, f'The length of ' \
343 | f'attn_masks {len(attn_masks)} must be equal ' \
344 | f'to the number of attention in ' \
345 | f'operation_order {self.num_attn}'
346 |
347 | for layer in self.operation_order:
348 | # temporal self attention
349 | if layer == 'self_attn':
350 |
351 | query = self.attentions[attn_index](
352 | query,
353 | prev_bev,
354 | prev_bev,
355 | identity if self.pre_norm else None,
356 | query_pos=bev_pos,
357 | key_pos=bev_pos,
358 | attn_mask=attn_masks[attn_index],
359 | key_padding_mask=query_key_padding_mask,
360 | reference_points=ref_2d,
361 | spatial_shapes=torch.tensor(
362 | [[bev_h, bev_w]], device=query.device),
363 | level_start_index=torch.tensor([0], device=query.device),
364 | **kwargs)
365 | attn_index += 1
366 | identity = query
367 |
368 | elif layer == 'norm':
369 | query = self.norms[norm_index](query)
370 | norm_index += 1
371 |
372 |             # spatial cross attention
373 | elif layer == 'cross_attn':
374 | query = self.attentions[attn_index](
375 | query,
376 | key,
377 | value,
378 | identity if self.pre_norm else None,
379 | query_pos=query_pos,
380 | key_pos=key_pos,
381 | reference_points=ref_3d,
382 | reference_points_cam=reference_points_cam,
383 | mask=mask,
384 | attn_mask=attn_masks[attn_index],
385 | key_padding_mask=key_padding_mask,
386 | spatial_shapes=spatial_shapes,
387 | level_start_index=level_start_index,
388 | **kwargs)
389 | attn_index += 1
390 | identity = query
391 |
392 | elif layer == 'ffn':
393 | query = self.ffns[ffn_index](
394 | query, identity if self.pre_norm else None)
395 | ffn_index += 1
396 |
397 | return query
398 |
--------------------------------------------------------------------------------
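The encoder above keeps two sets of reference points: a 2D grid on the BEV plane for temporal self-attention, and 3D pillar points that `point_sampling` denormalizes with `pc_range` and projects into each camera with the `lidar2img` matrices. The following standalone sketch (illustrative only, not repo code; the identity projection matrix and the sample point are placeholders) walks through that denormalize-and-project step for a single reference point.

import torch

# pc_range taken from the subset_A config; the reference point and the
# projection matrix below are placeholders for illustration only.
pc_range = [-51.2, -25.6, -2.3, 51.2, 25.6, 1.7]
ref = torch.tensor([0.5, 0.5, 1.0])           # normalized (x, y, z) in [0, 1]

# denormalize into metric LiDAR coordinates, as point_sampling does
xyz = torch.stack([
    ref[0] * (pc_range[3] - pc_range[0]) + pc_range[0],
    ref[1] * (pc_range[4] - pc_range[1]) + pc_range[1],
    ref[2] * (pc_range[5] - pc_range[2]) + pc_range[2],
])

lidar2img = torch.eye(4)                      # placeholder 4x4 projection matrix
xyz1 = torch.cat([xyz, xyz.new_ones(1)])      # homogeneous coordinates
cam = lidar2img @ xyz1

eps = 1e-5
in_front = cam[2] > eps                       # bev_mask criterion: point in front of the camera
uv = cam[:2] / cam[2].clamp(min=eps)          # perspective divide
# point_sampling then normalizes uv by image width/height and keeps only 0 < u, v < 1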
/projects/bevformer/modules/multi_scale_deformable_attn_function.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 |
7 | import torch
8 | from torch.cuda.amp import custom_bwd, custom_fwd
9 | from torch.autograd.function import Function, once_differentiable
10 | from mmcv.utils import ext_loader
11 | ext_module = ext_loader.load_ext(
12 | '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
13 |
14 |
15 | class MultiScaleDeformableAttnFunction_fp16(Function):
16 |
17 | @staticmethod
18 | @custom_fwd(cast_inputs=torch.float16)
19 | def forward(ctx, value, value_spatial_shapes, value_level_start_index,
20 | sampling_locations, attention_weights, im2col_step):
21 | """GPU version of multi-scale deformable attention.
22 |
23 | Args:
24 | value (Tensor): The value has shape
25 |                 (bs, num_keys, num_heads, embed_dims//num_heads)
26 | value_spatial_shapes (Tensor): Spatial shape of
27 | each feature map, has shape (num_levels, 2),
28 | last dimension 2 represent (h, w)
29 | sampling_locations (Tensor): The location of sampling points,
30 | has shape
31 | (bs ,num_queries, num_heads, num_levels, num_points, 2),
32 | the last dimension 2 represent (x, y).
33 | attention_weights (Tensor): The weight of sampling points used
34 |                 when calculating the attention, has shape
35 | (bs ,num_queries, num_heads, num_levels, num_points),
36 |             im2col_step (int): The step used in image to column.
37 |
38 | Returns:
39 | Tensor: has shape (bs, num_queries, embed_dims)
40 | """
41 | ctx.im2col_step = im2col_step
42 | output = ext_module.ms_deform_attn_forward(
43 | value,
44 | value_spatial_shapes,
45 | value_level_start_index,
46 | sampling_locations,
47 | attention_weights,
48 | im2col_step=ctx.im2col_step)
49 | ctx.save_for_backward(value, value_spatial_shapes,
50 | value_level_start_index, sampling_locations,
51 | attention_weights)
52 | return output
53 |
54 | @staticmethod
55 | @once_differentiable
56 | @custom_bwd
57 | def backward(ctx, grad_output):
58 | """GPU version of backward function.
59 |
60 | Args:
61 | grad_output (Tensor): Gradient
62 | of output tensor of forward.
63 |
64 | Returns:
65 | Tuple[Tensor]: Gradient
66 | of input tensors in forward.
67 | """
68 | value, value_spatial_shapes, value_level_start_index, \
69 | sampling_locations, attention_weights = ctx.saved_tensors
70 | grad_value = torch.zeros_like(value)
71 | grad_sampling_loc = torch.zeros_like(sampling_locations)
72 | grad_attn_weight = torch.zeros_like(attention_weights)
73 |
74 | ext_module.ms_deform_attn_backward(
75 | value,
76 | value_spatial_shapes,
77 | value_level_start_index,
78 | sampling_locations,
79 | attention_weights,
80 | grad_output.contiguous(),
81 | grad_value,
82 | grad_sampling_loc,
83 | grad_attn_weight,
84 | im2col_step=ctx.im2col_step)
85 |
86 | return grad_value, None, None, \
87 | grad_sampling_loc, grad_attn_weight, None
88 |
89 |
90 | class MultiScaleDeformableAttnFunction_fp32(Function):
91 |
92 | @staticmethod
93 | @custom_fwd(cast_inputs=torch.float32)
94 | def forward(ctx, value, value_spatial_shapes, value_level_start_index,
95 | sampling_locations, attention_weights, im2col_step):
96 | """GPU version of multi-scale deformable attention.
97 |
98 | Args:
99 | value (Tensor): The value has shape
100 |                 (bs, num_keys, num_heads, embed_dims//num_heads)
101 | value_spatial_shapes (Tensor): Spatial shape of
102 | each feature map, has shape (num_levels, 2),
103 | last dimension 2 represent (h, w)
104 | sampling_locations (Tensor): The location of sampling points,
105 | has shape
106 | (bs ,num_queries, num_heads, num_levels, num_points, 2),
107 | the last dimension 2 represent (x, y).
108 | attention_weights (Tensor): The weight of sampling points used
109 |                 when calculating the attention, has shape
110 | (bs ,num_queries, num_heads, num_levels, num_points),
111 |             im2col_step (int): The step used in image to column.
112 |
113 | Returns:
114 | Tensor: has shape (bs, num_queries, embed_dims)
115 | """
116 |
117 | ctx.im2col_step = im2col_step
118 | output = ext_module.ms_deform_attn_forward(
119 | value,
120 | value_spatial_shapes,
121 | value_level_start_index,
122 | sampling_locations,
123 | attention_weights,
124 | im2col_step=ctx.im2col_step)
125 | ctx.save_for_backward(value, value_spatial_shapes,
126 | value_level_start_index, sampling_locations,
127 | attention_weights)
128 | return output
129 |
130 | @staticmethod
131 | @once_differentiable
132 | @custom_bwd
133 | def backward(ctx, grad_output):
134 | """GPU version of backward function.
135 |
136 | Args:
137 | grad_output (Tensor): Gradient
138 | of output tensor of forward.
139 |
140 | Returns:
141 | Tuple[Tensor]: Gradient
142 | of input tensors in forward.
143 | """
144 | value, value_spatial_shapes, value_level_start_index, \
145 | sampling_locations, attention_weights = ctx.saved_tensors
146 | grad_value = torch.zeros_like(value)
147 | grad_sampling_loc = torch.zeros_like(sampling_locations)
148 | grad_attn_weight = torch.zeros_like(attention_weights)
149 |
150 | ext_module.ms_deform_attn_backward(
151 | value,
152 | value_spatial_shapes,
153 | value_level_start_index,
154 | sampling_locations,
155 | attention_weights,
156 | grad_output.contiguous(),
157 | grad_value,
158 | grad_sampling_loc,
159 | grad_attn_weight,
160 | im2col_step=ctx.im2col_step)
161 |
162 | return grad_value, None, None, \
163 | grad_sampling_loc, grad_attn_weight, None
164 |
--------------------------------------------------------------------------------
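A hedged usage sketch of the fp32 Function defined above: it requires mmcv built with the CUDA `ms_deform_attn` ops and tensors on a GPU, and every size below is made up for illustration.

import torch
from projects.bevformer.modules.multi_scale_deformable_attn_function import \
    MultiScaleDeformableAttnFunction_fp32

if torch.cuda.is_available():
    bs, num_query, num_heads, num_levels, num_points, dims = 1, 4, 8, 1, 4, 256
    spatial_shapes = torch.tensor([[10, 10]], device='cuda')      # (num_levels, 2)
    level_start_index = torch.tensor([0], device='cuda')
    num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

    value = torch.rand(bs, num_value, num_heads, dims // num_heads, device='cuda')
    sampling_locations = torch.rand(
        bs, num_query, num_heads, num_levels, num_points, 2, device='cuda')
    attention_weights = torch.rand(
        bs, num_query, num_heads, num_levels, num_points, device='cuda').softmax(-1)

    output = MultiScaleDeformableAttnFunction_fp32.apply(
        value, spatial_shapes, level_start_index,
        sampling_locations, attention_weights, 64)                # im2col_step=64
    # output: (bs, num_query, embed_dims)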
/projects/bevformer/modules/temporal_self_attention.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 |
7 | from .multi_scale_deformable_attn_function import MultiScaleDeformableAttnFunction_fp32
8 | from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
9 | import warnings
10 | import torch
11 | import torch.nn as nn
12 | from mmcv.cnn import xavier_init, constant_init
13 | from mmcv.cnn.bricks.registry import ATTENTION
14 | import math
15 | from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
16 | from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
17 | to_2tuple)
18 |
19 | from mmcv.utils import ext_loader
20 | ext_module = ext_loader.load_ext(
21 | '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
22 |
23 |
24 | @ATTENTION.register_module()
25 | class TemporalSelfAttention(BaseModule):
26 | """An attention module used in BEVFormer based on Deformable-Detr.
27 |
28 |     `Deformable DETR: Deformable Transformers for End-to-End Object Detection
29 |     <https://arxiv.org/abs/2010.04159>`_.
30 |
31 | Args:
32 | embed_dims (int): The embedding dimension of Attention.
33 | Default: 256.
34 |         num_heads (int): Parallel attention heads. Default: 8.
35 | num_levels (int): The number of feature map used in
36 | Attention. Default: 4.
37 | num_points (int): The number of sampling points for
38 | each query in each head. Default: 4.
39 | im2col_step (int): The step used in image_to_column.
40 | Default: 64.
41 | dropout (float): A Dropout layer on `inp_identity`.
42 | Default: 0.1.
43 | batch_first (bool): Key, Query and Value are shape of
44 | (batch, n, embed_dim)
45 | or (n, batch, embed_dim). Default to True.
46 | norm_cfg (dict): Config dict for normalization layer.
47 | Default: None.
48 | init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
49 | Default: None.
50 |         num_bev_queue (int): In this version, we only use one history BEV and one current BEV,
51 |             so the length of the BEV queue is 2.
52 | """
53 |
54 | def __init__(self,
55 | embed_dims=256,
56 | num_heads=8,
57 | num_levels=4,
58 | num_points=4,
59 | num_bev_queue=2,
60 | im2col_step=64,
61 | dropout=0.1,
62 | batch_first=True,
63 | norm_cfg=None,
64 | init_cfg=None):
65 |
66 | super().__init__(init_cfg)
67 | if embed_dims % num_heads != 0:
68 | raise ValueError(f'embed_dims must be divisible by num_heads, '
69 | f'but got {embed_dims} and {num_heads}')
70 | dim_per_head = embed_dims // num_heads
71 | self.norm_cfg = norm_cfg
72 | self.dropout = nn.Dropout(dropout)
73 | self.batch_first = batch_first
74 | self.fp16_enabled = False
75 |
76 | # you'd better set dim_per_head to a power of 2
77 | # which is more efficient in the CUDA implementation
78 | def _is_power_of_2(n):
79 | if (not isinstance(n, int)) or (n < 0):
80 | raise ValueError(
81 | 'invalid input for _is_power_of_2: {} (type: {})'.format(
82 | n, type(n)))
83 | return (n & (n - 1) == 0) and n != 0
84 |
85 | if not _is_power_of_2(dim_per_head):
86 | warnings.warn(
87 | "You'd better set embed_dims in "
88 | 'MultiScaleDeformAttention to make '
89 | 'the dimension of each attention head a power of 2 '
90 | 'which is more efficient in our CUDA implementation.')
91 |
92 | self.im2col_step = im2col_step
93 | self.embed_dims = embed_dims
94 | self.num_levels = num_levels
95 | self.num_heads = num_heads
96 | self.num_points = num_points
97 | self.num_bev_queue = num_bev_queue
98 | self.sampling_offsets = nn.Linear(
99 | embed_dims*self.num_bev_queue, num_bev_queue*num_heads * num_levels * num_points * 2)
100 | self.attention_weights = nn.Linear(embed_dims*self.num_bev_queue,
101 | num_bev_queue*num_heads * num_levels * num_points)
102 | self.value_proj = nn.Linear(embed_dims, embed_dims)
103 | self.output_proj = nn.Linear(embed_dims, embed_dims)
104 | self.init_weights()
105 |
106 | def init_weights(self):
107 | """Default initialization for Parameters of Module."""
108 | constant_init(self.sampling_offsets, 0.)
109 | thetas = torch.arange(
110 | self.num_heads,
111 | dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
112 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
113 | grid_init = (grid_init /
114 | grid_init.abs().max(-1, keepdim=True)[0]).view(
115 | self.num_heads, 1, 1,
116 | 2).repeat(1, self.num_levels*self.num_bev_queue, self.num_points, 1)
117 |
118 | for i in range(self.num_points):
119 | grid_init[:, :, i, :] *= i + 1
120 |
121 | self.sampling_offsets.bias.data = grid_init.view(-1)
122 | constant_init(self.attention_weights, val=0., bias=0.)
123 | xavier_init(self.value_proj, distribution='uniform', bias=0.)
124 | xavier_init(self.output_proj, distribution='uniform', bias=0.)
125 | self._is_init = True
126 |
127 | def forward(self,
128 | query,
129 | key=None,
130 | value=None,
131 | identity=None,
132 | query_pos=None,
133 | key_padding_mask=None,
134 | reference_points=None,
135 | spatial_shapes=None,
136 | level_start_index=None,
137 | flag='decoder',
138 |
139 | **kwargs):
140 | """Forward Function of MultiScaleDeformAttention.
141 |
142 | Args:
143 | query (Tensor): Query of Transformer with shape
144 | (num_query, bs, embed_dims).
145 | key (Tensor): The key tensor with shape
146 | `(num_key, bs, embed_dims)`.
147 | value (Tensor): The value tensor with shape
148 | `(num_key, bs, embed_dims)`.
149 | identity (Tensor): The tensor used for addition, with the
150 | same shape as `query`. Default None. If None,
151 | `query` will be used.
152 | query_pos (Tensor): The positional encoding for `query`.
153 | Default: None.
154 | key_pos (Tensor): The positional encoding for `key`. Default
155 | None.
156 | reference_points (Tensor): The normalized reference
157 | points with shape (bs, num_query, num_levels, 2),
158 |                 all elements are in range [0, 1], top-left (0,0),
159 | bottom-right (1, 1), including padding area.
160 |                 or (N, Length_{query}, num_levels, 4), where the
161 |                 additional two dimensions (w, h) are added to
162 |                 form reference boxes.
163 |             key_padding_mask (Tensor): ByteTensor for `key`, with
164 | shape [bs, num_key].
165 | spatial_shapes (Tensor): Spatial shape of features in
166 | different levels. With shape (num_levels, 2),
167 | last dimension represents (h, w).
168 | level_start_index (Tensor): The start index of each level.
169 | A tensor has shape ``(num_levels, )`` and can be represented
170 | as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
171 |
172 | Returns:
173 | Tensor: forwarded results with shape [num_query, bs, embed_dims].
174 | """
175 |
176 | if value is None:
177 | assert self.batch_first
178 | bs, len_bev, c = query.shape
179 | value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
180 |
181 | # value = torch.cat([query, query], 0)
182 |
183 | if identity is None:
184 | identity = query
185 | if query_pos is not None:
186 | query = query + query_pos
187 | if not self.batch_first:
188 | # change to (bs, num_query ,embed_dims)
189 | query = query.permute(1, 0, 2)
190 | value = value.permute(1, 0, 2)
191 | bs, num_query, embed_dims = query.shape
192 | _, num_value, _ = value.shape
193 | assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
194 | assert self.num_bev_queue == 2
195 |
196 | query = torch.cat([value[:bs], query], -1)
197 | value = self.value_proj(value)
198 |
199 | if key_padding_mask is not None:
200 | value = value.masked_fill(key_padding_mask[..., None], 0.0)
201 |
202 | value = value.reshape(bs*self.num_bev_queue,
203 | num_value, self.num_heads, -1)
204 |
205 | sampling_offsets = self.sampling_offsets(query)
206 | sampling_offsets = sampling_offsets.view(
207 | bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels, self.num_points, 2)
208 | attention_weights = self.attention_weights(query).view(
209 | bs, num_query, self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
210 | attention_weights = attention_weights.softmax(-1)
211 |
212 | attention_weights = attention_weights.view(bs, num_query,
213 | self.num_heads,
214 | self.num_bev_queue,
215 | self.num_levels,
216 | self.num_points)
217 |
218 | attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\
219 | .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
220 | sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\
221 | .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)
222 |
223 | if reference_points.shape[-1] == 2:
224 | offset_normalizer = torch.stack(
225 | [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
226 | sampling_locations = reference_points[:, :, None, :, None, :] \
227 | + sampling_offsets \
228 | / offset_normalizer[None, None, None, :, None, :]
229 |
230 | elif reference_points.shape[-1] == 4:
231 | sampling_locations = reference_points[:, :, None, :, None, :2] \
232 | + sampling_offsets / self.num_points \
233 | * reference_points[:, :, None, :, None, 2:] \
234 | * 0.5
235 | else:
236 | raise ValueError(
237 | f'Last dim of reference_points must be'
238 | f' 2 or 4, but get {reference_points.shape[-1]} instead.')
239 | if torch.cuda.is_available() and value.is_cuda:
240 |
241 | # using fp16 deformable attention is unstable because it performs many sum operations
242 | if value.dtype == torch.float16:
243 | MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
244 | else:
245 | MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
246 | output = MultiScaleDeformableAttnFunction.apply(
247 | value, spatial_shapes, level_start_index, sampling_locations,
248 | attention_weights, self.im2col_step)
249 | else:
250 |
251 | output = multi_scale_deformable_attn_pytorch(
252 | value, spatial_shapes, sampling_locations, attention_weights)
253 |
254 | # output shape (bs*num_bev_queue, num_query, embed_dims)
255 | # (bs*num_bev_queue, num_query, embed_dims)-> (num_query, embed_dims, bs*num_bev_queue)
256 | output = output.permute(1, 2, 0)
257 |
258 | # fuse history value and current value
259 | # (num_query, embed_dims, bs*num_bev_queue)-> (num_query, embed_dims, bs, num_bev_queue)
260 | output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
261 | output = output.mean(-1)
262 |
263 | # (num_query, embed_dims, bs)-> (bs, num_query, embed_dims)
264 | output = output.permute(2, 0, 1)
265 |
266 | output = self.output_proj(output)
267 |
268 | if not self.batch_first:
269 | output = output.permute(1, 0, 2)
270 |
271 | return self.dropout(output) + identity
272 |
--------------------------------------------------------------------------------
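The sketch below shows, assuming the module above is importable (mmcv with its deformable-attention extension installed), how `TemporalSelfAttention` consumes a BEV queue of length 2: the previous and current BEV are stacked along the batch dimension and one set of normalized 2D reference points is supplied per queue entry. Shapes are illustrative; on CPU the pure-PyTorch fallback path is used.

import torch
from projects.bevformer.modules.temporal_self_attention import TemporalSelfAttention

bev_h, bev_w, dims, bs = 4, 5, 256, 1
len_bev = bev_h * bev_w

attn = TemporalSelfAttention(embed_dims=dims, num_levels=1, num_bev_queue=2)

curr_bev = torch.rand(bs, len_bev, dims)
prev_bev = torch.rand(bs, len_bev, dims)
# stack [previous, current] BEV along the batch dimension, as the encoder does
value = torch.stack([prev_bev, curr_bev], 1).reshape(bs * 2, len_bev, dims)

# one set of normalized 2D reference points per BEV in the queue
ref_2d = torch.rand(bs * 2, len_bev, 1, 2)
spatial_shapes = torch.tensor([[bev_h, bev_w]])
level_start_index = torch.tensor([0])

out = attn(curr_bev, value=value, reference_points=ref_2d,
           spatial_shapes=spatial_shapes, level_start_index=level_start_index)
# out: (bs, len_bev, dims); history and current BEV are fused by averaging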
/projects/configs/toponet_r50_8x1_24e_olv2_subset_A.py:
--------------------------------------------------------------------------------
1 | _base_ = []
2 | custom_imports = dict(imports=['projects.bevformer', 'projects.toponet'])
3 |
4 | # If point cloud range is changed, the models should also change their point
5 | # cloud range accordingly
6 | point_cloud_range = [-51.2, -25.6, -2.3, 51.2, 25.6, 1.7]
7 |
8 | img_norm_cfg = dict(
9 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
10 |
11 | class_names = ['centerline']
12 |
13 | input_modality = dict(
14 | use_lidar=False,
15 | use_camera=True,
16 | use_radar=False,
17 | use_map=False,
18 | use_external=False)
19 | num_cams = 7
20 | pts_dim = 3
21 |
22 | dataset_type = 'OpenLaneV2_subset_A_Dataset'
23 | data_root = 'data/OpenLane-V2/'
24 |
25 | para_method = 'fix_pts_interp'
26 | method_para = dict(n_points=11)
27 | code_size = pts_dim * method_para['n_points']
28 |
29 | _dim_ = 256
30 | _pos_dim_ = _dim_//2
31 | _ffn_dim_ = _dim_*2
32 | _ffn_cfg_ = dict(
33 | type='FFN',
34 | embed_dims=_dim_,
35 | feedforward_channels=_ffn_dim_,
36 | num_fcs=2,
37 | ffn_drop=0.1,
38 | act_cfg=dict(type='ReLU', inplace=True),
39 | )
40 |
41 | _num_levels_ = 4
42 | bev_h_ = 100
43 | bev_w_ = 200
44 |
45 | model = dict(
46 | type='TopoNet',
47 | img_backbone=dict(
48 | type='ResNet',
49 | depth=50,
50 | num_stages=4,
51 | out_indices=(1, 2, 3),
52 | frozen_stages=1,
53 | norm_cfg=dict(type='BN', requires_grad=False),
54 | norm_eval=True,
55 | style='pytorch',
56 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
57 | img_neck=dict(
58 | type='FPN',
59 | in_channels=[512, 1024, 2048],
60 | out_channels=_dim_,
61 | start_level=0,
62 | add_extra_convs='on_output',
63 | num_outs=_num_levels_,
64 | relu_before_extra_convs=True),
65 | bev_constructor=dict(
66 | type='BEVFormerConstructer',
67 | num_feature_levels=_num_levels_,
68 | num_cams=num_cams,
69 | embed_dims=_dim_,
70 | rotate_prev_bev=True,
71 | use_shift=True,
72 | use_can_bus=True,
73 | pc_range=point_cloud_range,
74 | bev_h=bev_h_,
75 | bev_w=bev_w_,
76 | rotate_center=[bev_h_//2, bev_w_//2],
77 | encoder=dict(
78 | type='BEVFormerEncoder',
79 | num_layers=3,
80 | pc_range=point_cloud_range,
81 | num_points_in_pillar=4,
82 | return_intermediate=False,
83 | transformerlayers=dict(
84 | type='BEVFormerLayer',
85 | attn_cfgs=[
86 | dict(
87 | type='TemporalSelfAttention',
88 | embed_dims=_dim_,
89 | num_levels=1),
90 | dict(
91 | type='SpatialCrossAttention',
92 | embed_dims=_dim_,
93 | num_cams=num_cams,
94 | pc_range=point_cloud_range,
95 | deformable_attention=dict(
96 | type='MSDeformableAttention3D',
97 | embed_dims=_dim_,
98 | num_points=8,
99 | num_levels=_num_levels_)
100 | )
101 | ],
102 | ffn_cfgs=_ffn_cfg_,
103 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
104 | 'ffn', 'norm'))),
105 | positional_encoding=dict(
106 | type='LearnedPositionalEncoding',
107 | num_feats=_pos_dim_,
108 | row_num_embed=bev_h_,
109 | col_num_embed=bev_w_),
110 | ),
111 | bbox_head=dict(
112 | type='CustomDeformableDETRHead',
113 | num_query=100,
114 | num_classes=13,
115 | in_channels=_dim_,
116 | sync_cls_avg_factor=True,
117 | with_box_refine=True,
118 | as_two_stage=False,
119 | transformer=dict(
120 | type='DeformableDetrTransformer',
121 | encoder=dict(
122 | type='DetrTransformerEncoder',
123 | num_layers=6,
124 | transformerlayers=dict(
125 | type='BaseTransformerLayer',
126 | attn_cfgs=dict(
127 | type='MultiScaleDeformableAttention', embed_dims=_dim_),
128 | ffn_cfgs=_ffn_cfg_,
129 | operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
130 | decoder=dict(
131 | type='DeformableDetrTransformerDecoder',
132 | num_layers=6,
133 | return_intermediate=True,
134 | transformerlayers=dict(
135 | type='DetrTransformerDecoderLayer',
136 | attn_cfgs=[
137 | dict(
138 | type='MultiheadAttention',
139 | embed_dims=_dim_,
140 | num_heads=8,
141 | dropout=0.1),
142 | dict(
143 | type='MultiScaleDeformableAttention',
144 | embed_dims=_dim_)
145 | ],
146 | feedforward_channels=_ffn_dim_,
147 | ffn_dropout=0.1,
148 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
149 | 'ffn', 'norm')))),
150 | positional_encoding=dict(
151 | type='SinePositionalEncoding',
152 | num_feats=_pos_dim_,
153 | normalize=True,
154 | offset=-0.5),
155 | loss_cls=dict(
156 | type='FocalLoss',
157 | use_sigmoid=True,
158 | gamma=2.0,
159 | alpha=0.25,
160 | loss_weight=1.0),
161 | loss_bbox=dict(type='L1Loss', loss_weight=2.5),
162 | loss_iou=dict(type='GIoULoss', loss_weight=1.0),
163 | test_cfg=dict(max_per_img=100)),
164 | lane_head=dict(
165 | type='TopoNetHead',
166 | num_classes=1,
167 | in_channels=_dim_,
168 | num_query=200,
169 | bev_h=bev_h_,
170 | bev_w=bev_w_,
171 | pc_range=point_cloud_range,
172 | pts_dim=pts_dim,
173 | sync_cls_avg_factor=False,
174 | code_size=code_size,
175 | code_weights= [1.0 for i in range(code_size)],
176 | transformer=dict(
177 | type='TopoNetTransformerDecoderOnly',
178 | embed_dims=_dim_,
179 | pts_dim=pts_dim,
180 | decoder=dict(
181 | type='TopoNetSGNNDecoder',
182 | num_layers=6,
183 | return_intermediate=True,
184 | transformerlayers=dict(
185 | type='SGNNDecoderLayer',
186 | attn_cfgs=[
187 | dict(
188 | type='MultiheadAttention',
189 | embed_dims=_dim_,
190 | num_heads=8,
191 | dropout=0.1),
192 | dict(
193 | type='CustomMSDeformableAttention',
194 | embed_dims=_dim_,
195 | num_levels=1),
196 | ],
197 | ffn_cfgs=dict(
198 | type='FFN_SGNN',
199 | embed_dims=_dim_,
200 | feedforward_channels=_ffn_dim_,
201 | num_te_classes=13,
202 | edge_weight=0.6),
203 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
204 | 'ffn', 'norm')))),
205 | lclc_head=dict(
206 | type='SingleLayerRelationshipHead',
207 | in_channels_o1=_dim_,
208 | in_channels_o2=_dim_,
209 | shared_param=False,
210 | loss_rel=dict(
211 | type='FocalLoss',
212 | use_sigmoid=True,
213 | gamma=2.0,
214 | alpha=0.25,
215 | loss_weight=5)),
216 | lcte_head=dict(
217 | type='SingleLayerRelationshipHead',
218 | in_channels_o1=_dim_,
219 | in_channels_o2=_dim_,
220 | shared_param=False,
221 | loss_rel=dict(
222 | type='FocalLoss',
223 | use_sigmoid=True,
224 | gamma=2.0,
225 | alpha=0.25,
226 | loss_weight=5)),
227 | bbox_coder=dict(type='LanePseudoCoder'),
228 | loss_cls=dict(
229 | type='FocalLoss',
230 | use_sigmoid=True,
231 | gamma=2.0,
232 | alpha=0.25,
233 | loss_weight=1.5),
234 | loss_bbox=dict(type='L1Loss', loss_weight=0.025)),
235 | # model training and testing settings
236 | train_cfg=dict(
237 | bbox=dict(
238 | assigner=dict(
239 | type='HungarianAssigner',
240 | cls_cost=dict(type='FocalLossCost', weight=1.0),
241 | reg_cost=dict(type='BBoxL1Cost', weight=2.5, box_format='xywh'),
242 | iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0))),
243 | lane=dict(
244 | assigner=dict(
245 | type='LaneHungarianAssigner3D',
246 | cls_cost=dict(type='FocalLossCost', weight=1.5),
247 | reg_cost=dict(type='LaneL1Cost', weight=0.025),
248 | pc_range=point_cloud_range))))
249 |
250 | train_pipeline = [
251 | dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
252 | dict(type='LoadAnnotations3DLane',
253 | with_lane_3d=True, with_lane_label_3d=True, with_lane_adj=True,
254 | with_bbox=True, with_label=True, with_lane_lcte_adj=True),
255 | dict(type='PhotoMetricDistortionMultiViewImage'),
256 | dict(type='CropFrontViewImageForAv2'),
257 | dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
258 | dict(type='NormalizeMultiviewImage', **img_norm_cfg),
259 | dict(type='PadMultiViewImageSame2Max', size_divisor=32),
260 | dict(type='GridMaskMultiViewImage'),
261 | dict(type='LaneParameterize3D', method=para_method, method_para=method_para),
262 | dict(type='CustomFormatBundle3DLane', class_names=class_names),
263 | dict(type='CustomCollect3D', keys=[
264 | 'img', 'gt_lanes_3d', 'gt_lane_labels_3d', 'gt_lane_adj',
265 | 'gt_bboxes', 'gt_labels', 'gt_lane_lcte_adj'])
266 | ]
267 |
268 | test_pipeline = [
269 | dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
270 | dict(type='CropFrontViewImageForAv2'),
271 | dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
272 | dict(type='NormalizeMultiviewImage', **img_norm_cfg),
273 | dict(type='PadMultiViewImageSame2Max', size_divisor=32),
274 | dict(type='CustomFormatBundle3DLane', class_names=class_names),
275 | dict(type='CustomCollect3D', keys=['img'])
276 | ]
277 |
278 | data = dict(
279 | samples_per_gpu=1,
280 | workers_per_gpu=8,
281 | train=dict(
282 | type=dataset_type,
283 | data_root=data_root,
284 | ann_file=data_root + 'data_dict_subset_A_train.pkl',
285 | pipeline=train_pipeline,
286 | classes=class_names,
287 | modality=input_modality,
288 | split='train',
289 | filter_map_change=True,
290 | test_mode=False),
291 | val=dict(
292 | type=dataset_type,
293 | data_root=data_root,
294 | ann_file=data_root + 'data_dict_subset_A_val.pkl',
295 | pipeline=test_pipeline,
296 | classes=class_names,
297 | modality=input_modality,
298 | split='val',
299 | test_mode=True),
300 | test=dict(
301 | type=dataset_type,
302 | data_root=data_root,
303 | ann_file=data_root + 'data_dict_subset_A_val.pkl',
304 | pipeline=test_pipeline,
305 | classes=class_names,
306 | modality=input_modality,
307 | split='val',
308 | test_mode=True)
309 | )
310 |
311 | optimizer = dict(
312 | type='AdamW',
313 | lr=2e-4,
314 | paramwise_cfg=dict(
315 | custom_keys={
316 | 'img_backbone': dict(lr_mult=0.1),
317 | }),
318 | weight_decay=0.01)
319 |
320 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
321 | # learning policy
322 | lr_config = dict(
323 | policy='CosineAnnealing',
324 | warmup='linear',
325 | warmup_iters=500,
326 | warmup_ratio=1.0 / 3,
327 | min_lr_ratio=1e-3)
328 | total_epochs = 24
329 | evaluation = dict(interval=24, pipeline=test_pipeline)
330 |
331 | runner = dict(type='EpochBasedRunner', max_epochs=total_epochs)
332 | log_config = dict(
333 | interval=50,
334 | hooks=[
335 | dict(type='TextLoggerHook'),
336 | dict(type='TensorboardLoggerHook')
337 | ])
338 |
339 | checkpoint_config = dict(interval=1, max_keep_ckpts=1)
340 |
341 | dist_params = dict(backend='nccl')
342 | log_level = 'INFO'
343 | work_dir = None
344 | load_from = None
345 | resume_from = None
346 | workflow = [('train', 1)]
347 |
348 | # NOTE: `auto_scale_lr` is for automatically scaling LR;
349 | # base_batch_size = (8 GPUs) x (1 sample per GPU)
350 | auto_scale_lr = dict(base_batch_size=8)
351 |
--------------------------------------------------------------------------------
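As a quick sanity check (illustrative arithmetic only), the lane regression size used throughout this config follows directly from the 'fix_pts_interp' parameterization: every centerline is resampled to n_points 3D points and flattened.

pts_dim = 3                     # (x, y, z) per point in subset A
method_para = dict(n_points=11)
code_size = pts_dim * method_para['n_points']
assert code_size == 33          # 11 points x 3 coordinates

The subset_B config below uses pts_dim = 2, giving code_size = 22.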
/projects/configs/toponet_r50_8x1_24e_olv2_subset_B.py:
--------------------------------------------------------------------------------
1 | _base_ = []
2 | custom_imports = dict(imports=['projects.bevformer', 'projects.toponet'])
3 |
4 | # If point cloud range is changed, the models should also change their point
5 | # cloud range accordingly
6 | point_cloud_range = [-51.2, -25.6, -2.0, 51.2, 25.6, 2.0]
7 |
8 | img_norm_cfg = dict(
9 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
10 |
11 | class_names = ['centerline']
12 |
13 | input_modality = dict(
14 | use_lidar=False,
15 | use_camera=True,
16 | use_radar=False,
17 | use_map=False,
18 | use_external=False)
19 | num_cams = 6
20 | pts_dim = 2
21 |
22 | dataset_type = 'OpenLaneV2_subset_B_Dataset'
23 | data_root = 'data/OpenLane-V2/'
24 |
25 | para_method = 'fix_pts_interp'
26 | method_para = dict(n_points=11)
27 | code_size = pts_dim * method_para['n_points']
28 |
29 | _dim_ = 256
30 | _pos_dim_ = _dim_//2
31 | _ffn_dim_ = _dim_*2
32 | _ffn_cfg_ = dict(
33 | type='FFN',
34 | embed_dims=_dim_,
35 | feedforward_channels=_ffn_dim_,
36 | num_fcs=2,
37 | ffn_drop=0.1,
38 | act_cfg=dict(type='ReLU', inplace=True),
39 | )
40 |
41 | _num_levels_ = 4
42 | bev_h_ = 100
43 | bev_w_ = 200
44 |
45 | model = dict(
46 | type='TopoNet',
47 | img_backbone=dict(
48 | type='ResNet',
49 | depth=50,
50 | num_stages=4,
51 | out_indices=(1, 2, 3),
52 | frozen_stages=1,
53 | norm_cfg=dict(type='BN', requires_grad=False),
54 | norm_eval=True,
55 | style='pytorch',
56 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
57 | img_neck=dict(
58 | type='FPN',
59 | in_channels=[512, 1024, 2048],
60 | out_channels=_dim_,
61 | start_level=0,
62 | add_extra_convs='on_output',
63 | num_outs=_num_levels_,
64 | relu_before_extra_convs=True),
65 | bev_constructor=dict(
66 | type='BEVFormerConstructer',
67 | num_feature_levels=_num_levels_,
68 | num_cams=num_cams,
69 | embed_dims=_dim_,
70 | rotate_prev_bev=True,
71 | use_shift=True,
72 | use_can_bus=True,
73 | pc_range=point_cloud_range,
74 | bev_h=bev_h_,
75 | bev_w=bev_w_,
76 | rotate_center=[bev_h_//2, bev_w_//2],
77 | encoder=dict(
78 | type='BEVFormerEncoder',
79 | num_layers=3,
80 | pc_range=point_cloud_range,
81 | num_points_in_pillar=4,
82 | return_intermediate=False,
83 | transformerlayers=dict(
84 | type='BEVFormerLayer',
85 | attn_cfgs=[
86 | dict(
87 | type='TemporalSelfAttention',
88 | embed_dims=_dim_,
89 | num_levels=1),
90 | dict(
91 | type='SpatialCrossAttention',
92 | embed_dims=_dim_,
93 | num_cams=num_cams,
94 | pc_range=point_cloud_range,
95 | deformable_attention=dict(
96 | type='MSDeformableAttention3D',
97 | embed_dims=_dim_,
98 | num_points=8,
99 | num_levels=_num_levels_)
100 | )
101 | ],
102 | ffn_cfgs=_ffn_cfg_,
103 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
104 | 'ffn', 'norm'))),
105 | positional_encoding=dict(
106 | type='LearnedPositionalEncoding',
107 | num_feats=_pos_dim_,
108 | row_num_embed=bev_h_,
109 | col_num_embed=bev_w_),
110 | ),
111 | bbox_head=dict(
112 | type='CustomDeformableDETRHead',
113 | num_query=100,
114 | num_classes=13,
115 | in_channels=_dim_,
116 | sync_cls_avg_factor=True,
117 | with_box_refine=True,
118 | as_two_stage=False,
119 | transformer=dict(
120 | type='DeformableDetrTransformer',
121 | encoder=dict(
122 | type='DetrTransformerEncoder',
123 | num_layers=6,
124 | transformerlayers=dict(
125 | type='BaseTransformerLayer',
126 | attn_cfgs=dict(
127 | type='MultiScaleDeformableAttention', embed_dims=_dim_),
128 | ffn_cfgs=_ffn_cfg_,
129 | operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
130 | decoder=dict(
131 | type='DeformableDetrTransformerDecoder',
132 | num_layers=6,
133 | return_intermediate=True,
134 | transformerlayers=dict(
135 | type='DetrTransformerDecoderLayer',
136 | attn_cfgs=[
137 | dict(
138 | type='MultiheadAttention',
139 | embed_dims=_dim_,
140 | num_heads=8,
141 | dropout=0.1),
142 | dict(
143 | type='MultiScaleDeformableAttention',
144 | embed_dims=_dim_)
145 | ],
146 | feedforward_channels=_ffn_dim_,
147 | ffn_dropout=0.1,
148 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
149 | 'ffn', 'norm')))),
150 | positional_encoding=dict(
151 | type='SinePositionalEncoding',
152 | num_feats=_pos_dim_,
153 | normalize=True,
154 | offset=-0.5),
155 | loss_cls=dict(
156 | type='FocalLoss',
157 | use_sigmoid=True,
158 | gamma=2.0,
159 | alpha=0.25,
160 | loss_weight=1.0),
161 | loss_bbox=dict(type='L1Loss', loss_weight=2.5),
162 | loss_iou=dict(type='GIoULoss', loss_weight=1.0),
163 | test_cfg=dict(max_per_img=100)),
164 | lane_head=dict(
165 | type='TopoNetHead',
166 | num_classes=1,
167 | in_channels=_dim_,
168 | num_query=200,
169 | bev_h=bev_h_,
170 | bev_w=bev_w_,
171 | pc_range=point_cloud_range,
172 | pts_dim=pts_dim,
173 | sync_cls_avg_factor=False,
174 | code_size=code_size,
175 | code_weights= [1.0 for i in range(code_size)],
176 | transformer=dict(
177 | type='TopoNetTransformerDecoderOnly',
178 | embed_dims=_dim_,
179 | pts_dim=pts_dim,
180 | decoder=dict(
181 | type='TopoNetSGNNDecoder',
182 | num_layers=6,
183 | return_intermediate=True,
184 | transformerlayers=dict(
185 | type='SGNNDecoderLayer',
186 | attn_cfgs=[
187 | dict(
188 | type='MultiheadAttention',
189 | embed_dims=_dim_,
190 | num_heads=8,
191 | dropout=0.1),
192 | dict(
193 | type='CustomMSDeformableAttention',
194 | embed_dims=_dim_,
195 | num_levels=1),
196 | ],
197 | ffn_cfgs=dict(
198 | type='FFN_SGNN',
199 | embed_dims=_dim_,
200 | feedforward_channels=_ffn_dim_,
201 | num_te_classes=13,
202 | edge_weight=0.6),
203 | operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
204 | 'ffn', 'norm')))),
205 | lclc_head=dict(
206 | type='SingleLayerRelationshipHead',
207 | in_channels_o1=_dim_,
208 | in_channels_o2=_dim_,
209 | shared_param=False,
210 | loss_rel=dict(
211 | type='FocalLoss',
212 | use_sigmoid=True,
213 | gamma=2.0,
214 | alpha=0.25,
215 | loss_weight=5)),
216 | lcte_head=dict(
217 | type='SingleLayerRelationshipHead',
218 | in_channels_o1=_dim_,
219 | in_channels_o2=_dim_,
220 | shared_param=False,
221 | loss_rel=dict(
222 | type='FocalLoss',
223 | use_sigmoid=True,
224 | gamma=2.0,
225 | alpha=0.25,
226 | loss_weight=5)),
227 | bbox_coder=dict(type='LanePseudoCoder'),
228 | loss_cls=dict(
229 | type='FocalLoss',
230 | use_sigmoid=True,
231 | gamma=2.0,
232 | alpha=0.25,
233 | loss_weight=1.5),
234 | loss_bbox=dict(type='L1Loss', loss_weight=0.025)),
235 | # model training and testing settings
236 | train_cfg=dict(
237 | bbox=dict(
238 | assigner=dict(
239 | type='HungarianAssigner',
240 | cls_cost=dict(type='FocalLossCost', weight=1.0),
241 | reg_cost=dict(type='BBoxL1Cost', weight=2.5, box_format='xywh'),
242 | iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0))),
243 | lane=dict(
244 | assigner=dict(
245 | type='LaneHungarianAssigner3D',
246 | cls_cost=dict(type='FocalLossCost', weight=1.5),
247 | reg_cost=dict(type='LaneL1Cost', weight=0.025),
248 | pc_range=point_cloud_range))))
249 |
250 | train_pipeline = [
251 | dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
252 | dict(type='LoadAnnotations3DLane',
253 | with_lane_3d=True, with_lane_label_3d=True, with_lane_adj=True,
254 | with_bbox=True, with_label=True, with_lane_lcte_adj=True),
255 | dict(type='PhotoMetricDistortionMultiViewImage'),
256 | dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
257 | dict(type='NormalizeMultiviewImage', **img_norm_cfg),
258 | dict(type='PadMultiViewImageSame2Max', size_divisor=32),
259 | dict(type='GridMaskMultiViewImage'),
260 | dict(type='LaneParameterize3D', method=para_method, method_para=method_para),
261 | dict(type='CustomFormatBundle3DLane', class_names=class_names),
262 | dict(type='CustomCollect3D', keys=[
263 | 'img', 'gt_lanes_3d', 'gt_lane_labels_3d', 'gt_lane_adj',
264 | 'gt_bboxes', 'gt_labels', 'gt_lane_lcte_adj'])
265 | ]
266 |
267 | test_pipeline = [
268 | dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
269 | dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
270 | dict(type='NormalizeMultiviewImage', **img_norm_cfg),
271 | dict(type='PadMultiViewImageSame2Max', size_divisor=32),
272 | dict(type='CustomFormatBundle3DLane', class_names=class_names),
273 | dict(type='CustomCollect3D', keys=['img'])
274 | ]
275 |
276 | data = dict(
277 | samples_per_gpu=1,
278 | workers_per_gpu=8,
279 | train=dict(
280 | type=dataset_type,
281 | data_root=data_root,
282 | ann_file=data_root + 'data_dict_subset_B_train.pkl',
283 | pipeline=train_pipeline,
284 | classes=class_names,
285 | modality=input_modality,
286 | split='train',
287 | filter_map_change=True,
288 | test_mode=False),
289 | val=dict(
290 | type=dataset_type,
291 | data_root=data_root,
292 | ann_file=data_root + 'data_dict_subset_B_val.pkl',
293 | pipeline=test_pipeline,
294 | classes=class_names,
295 | modality=input_modality,
296 | split='val',
297 | test_mode=True),
298 | test=dict(
299 | type=dataset_type,
300 | data_root=data_root,
301 | ann_file=data_root + 'data_dict_subset_B_val.pkl',
302 | pipeline=test_pipeline,
303 | classes=class_names,
304 | modality=input_modality,
305 | split='val',
306 | test_mode=True)
307 | )
308 |
309 | optimizer = dict(
310 | type='AdamW',
311 | lr=2e-4,
312 | paramwise_cfg=dict(
313 | custom_keys={
314 | 'img_backbone': dict(lr_mult=0.1),
315 | }),
316 | weight_decay=0.01)
317 |
318 | optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
319 | # learning policy
320 | lr_config = dict(
321 | policy='CosineAnnealing',
322 | warmup='linear',
323 | warmup_iters=500,
324 | warmup_ratio=1.0 / 3,
325 | min_lr_ratio=1e-3)
326 | total_epochs = 24
327 | evaluation = dict(interval=24, pipeline=test_pipeline)
328 |
329 | runner = dict(type='EpochBasedRunner', max_epochs=total_epochs)
330 | log_config = dict(
331 | interval=50,
332 | hooks=[
333 | dict(type='TextLoggerHook'),
334 | dict(type='TensorboardLoggerHook')
335 | ])
336 |
337 | checkpoint_config = dict(interval=1, max_keep_ckpts=1)
338 |
339 | dist_params = dict(backend='nccl')
340 | log_level = 'INFO'
341 | work_dir = None
342 | load_from = None
343 | resume_from = None
344 | workflow = [('train', 1)]
345 |
346 | # NOTE: `auto_scale_lr` is for automatically scaling LR;
347 | # base_batch_size = (8 GPUs) x (1 sample per GPU)
348 | auto_scale_lr = dict(base_batch_size=8)
349 |
--------------------------------------------------------------------------------
/projects/toponet/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .models import *
3 | from .core import *
4 | from .utils import *
5 |
--------------------------------------------------------------------------------
/projects/toponet/core/__init__.py:
--------------------------------------------------------------------------------
1 | from .lane import *
2 | from .visualizer import *
3 |
--------------------------------------------------------------------------------
/projects/toponet/core/lane/__init__.py:
--------------------------------------------------------------------------------
1 | from .assigners import *
2 | from .coders import *
3 | from .match_costs import *
4 |
--------------------------------------------------------------------------------
/projects/toponet/core/lane/assigners/__init__.py:
--------------------------------------------------------------------------------
1 | from .lane_hungarian_assigner import LaneHungarianAssigner3D
--------------------------------------------------------------------------------
/projects/toponet/core/lane/assigners/lane_hungarian_assigner.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mmdet.core.bbox.builder import BBOX_ASSIGNERS
4 | from mmdet.core.bbox.assigners import AssignResult
5 | from mmdet.core.bbox.assigners import BaseAssigner
6 | from mmdet.core.bbox.match_costs import build_match_cost
7 | from mmdet.models.utils.transformer import inverse_sigmoid
8 | from ..util import normalize_3dlane
9 |
10 | try:
11 | from scipy.optimize import linear_sum_assignment
12 | except ImportError:
13 | linear_sum_assignment = None
14 |
15 |
16 | @BBOX_ASSIGNERS.register_module()
17 | class LaneHungarianAssigner3D(BaseAssigner):
18 |
19 | def __init__(self,
20 | cls_cost=dict(type='ClassificationCost', weight=1.),
21 | reg_cost=dict(type='BBoxL1Cost', weight=1.0),
22 | normalize_gt=False,
23 | pc_range=None):
24 | self.cls_cost = build_match_cost(cls_cost)
25 | self.reg_cost = build_match_cost(reg_cost)
26 | self.normalize_gt = normalize_gt
27 | self.pc_range = pc_range
28 |
29 | def assign(self,
30 | lanes_pred,
31 | cls_pred,
32 | gt_lanes,
33 | gt_labels,
34 | gt_bboxes_ignore=None,
35 | eps=1e-7):
36 | """Computes one-to-one matching based on the weighted costs.
37 |         This method assigns each query prediction to a ground truth or
38 | background. The `assigned_gt_inds` with -1 means don't care,
39 | 0 means negative sample, and positive number is the index (1-based)
40 | of assigned gt.
41 | The assignment is done in the following steps, the order matters.
42 | 1. assign every prediction to -1
43 | 2. compute the weighted costs
44 | 3. do Hungarian matching on CPU based on the costs
45 | 4. assign all to 0 (background) first, then for each matched pair
46 | between predictions and gts, treat this prediction as foreground
47 | and assign the corresponding gt index (plus 1) to it.
48 | Args:
49 |             lanes_pred (Tensor): Predicted lanes, given as flattened
50 |                 point coordinates with shape [num_query, code_size],
51 |                 where code_size = pts_dim * n_points.
52 |             cls_pred (Tensor): Predicted classification logits, shape
53 |                 [num_query, num_class].
54 |             gt_lanes (Tensor): Ground truth lanes, given as flattened
55 |                 point coordinates with shape [num_gt, code_size].
56 |             gt_labels (Tensor): Label of `gt_lanes`, shape (num_gt,).
57 | gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
58 | labelled as `ignored`. Default None.
59 | eps (int | float, optional): A value added to the denominator for
60 | numerical stability. Default 1e-7.
61 | Returns:
62 | :obj:`AssignResult`: The assigned result.
63 | """
64 | assert gt_bboxes_ignore is None, \
65 | 'Only case when gt_bboxes_ignore is None is supported.'
66 | num_gts, num_bboxes = gt_lanes.size(0), lanes_pred.size(0)
67 | # 1. assign -1 by default
68 | assigned_gt_inds = lanes_pred.new_full((num_bboxes, ),
69 | -1,
70 | dtype=torch.long)
71 | assigned_labels = lanes_pred.new_full((num_bboxes, ),
72 | -1,
73 | dtype=torch.long)
74 | if num_gts == 0 or num_bboxes == 0:
75 | # No ground truth or boxes, return empty assignment
76 | if num_gts == 0:
77 | # No ground truth, assign all to background
78 | assigned_gt_inds[:] = 0
79 | return AssignResult(
80 | num_gts, assigned_gt_inds, None, labels=assigned_labels)
81 |
82 | # 2. compute the weighted costs
83 | # classification and bboxcost.
84 | cls_cost = self.cls_cost(cls_pred, gt_labels)
85 |
86 | if self.normalize_gt:
87 | normalized_gt_lanes = normalize_3dlane(gt_lanes, self.pc_range)
88 | else:
89 | normalized_gt_lanes = gt_lanes
90 |
91 | # regression L1 cost
92 | reg_cost = self.reg_cost(lanes_pred, normalized_gt_lanes)
93 |
94 | # weighted sum of above two costs
95 | cost = cls_cost + reg_cost
96 |
97 | # 3. do Hungarian matching on CPU using linear_sum_assignment
98 | cost = cost.detach().cpu()
99 | if linear_sum_assignment is None:
100 | raise ImportError('Please run "pip install scipy" '
101 | 'to install scipy first.')
102 | matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
103 | matched_row_inds = torch.from_numpy(matched_row_inds).to(
104 | lanes_pred.device)
105 | matched_col_inds = torch.from_numpy(matched_col_inds).to(
106 | lanes_pred.device)
107 |
108 | # 4. assign backgrounds and foregrounds
109 | # assign all indices to backgrounds first
110 | assigned_gt_inds[:] = 0
111 |
112 | # assign foregrounds based on matching results
113 | assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
114 | assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
115 |
116 | return AssignResult(
117 | num_gts, assigned_gt_inds, None, labels=assigned_labels)
118 |
--------------------------------------------------------------------------------
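A toy illustration (not repo code) of the matching step inside `LaneHungarianAssigner3D.assign`: a cost matrix is built from a classification cost and the L1 lane-regression cost, and scipy's Hungarian solver produces the one-to-one assignment, with matched ground-truth indices stored 1-based. Random tensors stand in for the real cost terms.

import torch
from scipy.optimize import linear_sum_assignment

num_query, num_gt, code_size = 4, 2, 6
cls_cost = torch.rand(num_query, num_gt)                     # stands in for FocalLossCost
reg_cost = torch.cdist(torch.rand(num_query, code_size),
                       torch.rand(num_gt, code_size), p=1)   # same form as LaneL1Cost
cost = cls_cost + reg_cost

rows, cols = linear_sum_assignment(cost.detach().cpu())
assigned_gt_inds = torch.zeros(num_query, dtype=torch.long)               # 0 = background
assigned_gt_inds[torch.from_numpy(rows)] = torch.from_numpy(cols) + 1     # 1-based gt index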
/projects/toponet/core/lane/coders/__init__.py:
--------------------------------------------------------------------------------
1 | from .lane_coder import LanePseudoCoder
2 |
--------------------------------------------------------------------------------
/projects/toponet/core/lane/coders/lane_coder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mmdet.core.bbox import BaseBBoxCoder
4 | from mmdet.core.bbox.builder import BBOX_CODERS
5 | from ..util import denormalize_3dlane
6 | import numpy as np
7 |
8 |
9 | @BBOX_CODERS.register_module()
10 | class LanePseudoCoder(BaseBBoxCoder):
11 |
12 | def __init__(self, denormalize=False):
13 | self.denormalize = denormalize
14 |
15 | def encode(self):
16 | pass
17 |
18 | def decode_single(self, cls_scores, lane_preds):
19 |         """Decode lanes for a single sample.
20 | Args:
21 | cls_scores (Tensor): Outputs from the classification head, \
22 | shape [num_query, cls_out_channels]. Note \
23 |                 cls_out_channels should include background.
24 |             lane_preds (Tensor): Outputs from the regression \
25 |                 head, given as flattened lane point coordinates. \
26 |                 Shape [num_query, code_size].
27 | Returns:
28 |             dict: Decoded lanes, scores and labels.
29 | """
30 |
31 | cls_scores = cls_scores.sigmoid()
32 | scores, labels = cls_scores.max(-1)
33 | if self.denormalize:
34 | final_lane_preds = denormalize_3dlane(lane_preds, self.pc_range)
35 | else:
36 | final_lane_preds = lane_preds
37 |
38 | predictions_dict = {
39 | 'lane3d': final_lane_preds.detach().cpu().numpy(),
40 | 'scores': scores.detach().cpu().numpy(),
41 | 'labels': labels.detach().cpu().numpy()
42 | }
43 |
44 | return predictions_dict
45 |
46 | def decode(self, preds_dicts):
47 |         """Decode lanes for a whole batch.
48 |         Args:
49 |             preds_dicts (dict): Prediction results with keys \
50 |                 'all_cls_scores', shape [nb_dec, bs, num_query, cls_out_channels] \
51 |                 (cls_out_channels should include background), and \
52 |                 'all_lanes_preds', flattened lane point coordinates with \
53 |                 shape [nb_dec, bs, num_query, code_size]. Only the outputs \
54 |                 of the last decoder layer are decoded.
55 |         Returns:
56 |             list[dict]: Decoded lanes, one dict per sample.
57 |         """
58 | all_cls_scores = preds_dicts['all_cls_scores'][-1]
59 | all_lanes_preds = preds_dicts['all_lanes_preds'][-1]
60 |
61 | batch_size = all_cls_scores.size()[0]
62 | predictions_list = []
63 | for i in range(batch_size):
64 | predictions_list.append(self.decode_single(all_cls_scores[i], all_lanes_preds[i]))
65 | return predictions_list
66 |
--------------------------------------------------------------------------------
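A hedged usage sketch of `LanePseudoCoder.decode` on dummy head outputs, assuming the repo and its mmdet dependencies are importable; the shapes follow the docstrings above and the subset_A config (num_query=200, code_size=33).

import torch
from projects.toponet.core.lane.coders.lane_coder import LanePseudoCoder

nb_dec, bs, num_query, num_cls, code_size = 6, 1, 200, 1, 33
preds_dicts = dict(
    all_cls_scores=torch.randn(nb_dec, bs, num_query, num_cls),
    all_lanes_preds=torch.randn(nb_dec, bs, num_query, code_size),
)

coder = LanePseudoCoder()
results = coder.decode(preds_dicts)     # list with one dict per sample
lane3d = results[0]['lane3d']           # (num_query, code_size) numpy array
scores = results[0]['scores']           # (num_query,) sigmoid scores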
/projects/toponet/core/lane/match_costs/__init__.py:
--------------------------------------------------------------------------------
1 | from .match_cost import LaneL1Cost
2 |
3 | __all__ = ['LaneL1Cost']
--------------------------------------------------------------------------------
/projects/toponet/core/lane/match_costs/match_cost.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmdet.core.bbox.match_costs.builder import MATCH_COST
3 |
4 |
5 | @MATCH_COST.register_module()
6 | class LaneL1Cost(object):
7 |
8 | def __init__(self, weight=1.):
9 | self.weight = weight
10 |
11 | def __call__(self, lane_pred, gt_lanes):
12 | lane_cost = torch.cdist(lane_pred, gt_lanes, p=1)
13 | return lane_cost * self.weight
14 |
--------------------------------------------------------------------------------
/projects/toponet/core/lane/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from shapely.geometry import LineString
4 |
5 | def normalize_3dlane(lanes, pc_range):
6 | normalized_lanes = lanes.clone()
7 | normalized_lanes[..., 0::3] = (lanes[..., 0::3] - pc_range[0]) / (pc_range[3] - pc_range[0])
8 | normalized_lanes[..., 1::3] = (lanes[..., 1::3] - pc_range[1]) / (pc_range[4] - pc_range[1])
9 | normalized_lanes[..., 2::3] = (lanes[..., 2::3] - pc_range[2]) / (pc_range[5] - pc_range[2])
10 | normalized_lanes = torch.clamp(normalized_lanes, 0, 1)
11 |
12 | return normalized_lanes
13 |
14 | def denormalize_3dlane(normalized_lanes, pc_range):
15 | lanes = normalized_lanes.clone()
16 | lanes[..., 0::3] = (normalized_lanes[..., 0::3] * (pc_range[3] - pc_range[0]) + pc_range[0])
17 | lanes[..., 1::3] = (normalized_lanes[..., 1::3] * (pc_range[4] - pc_range[1]) + pc_range[1])
18 | lanes[..., 2::3] = (normalized_lanes[..., 2::3] * (pc_range[5] - pc_range[2]) + pc_range[2])
19 | return lanes
20 |
21 | def fix_pts_interpolate(lane, n_points):
22 | ls = LineString(lane)
23 | distances = np.linspace(0, ls.length, n_points)
24 | lane = np.array([ls.interpolate(distance).coords[0] for distance in distances])
25 | return lane
26 |
--------------------------------------------------------------------------------
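A small example (illustrative values) of `fix_pts_interpolate`, which resamples an arbitrary polyline to the fixed number of points used by the 'fix_pts_interp' parameterization; it assumes a shapely version that keeps the z coordinate during interpolation, as the repo relies on.

import numpy as np
from projects.toponet.core.lane.util import fix_pts_interpolate

lane = np.array([[0.0, 0.0, 0.0],
                 [4.0, 0.0, 0.1],
                 [10.0, 2.0, 0.2]])
resampled = fix_pts_interpolate(lane, 11)
# resampled: (11, 3) points equally spaced along the polyline's arc length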
/projects/toponet/core/visualizer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/TopoNet/284442bd36e8c0af5dbbf4285ebb79897c1ccd82/projects/toponet/core/visualizer/__init__.py
--------------------------------------------------------------------------------
/projects/toponet/core/visualizer/lane.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mmcv
3 | import cv2
4 | from openlanev2.visualization.utils import COLOR_DICT
5 |
6 | GT_COLOR = (0, 255, 0)
7 | PRED_COLOR = (44, 63, 255)
8 |
9 |
10 | def show_results(image_list, lidar2imgs, gt_lane, pred_lane, gt_te=None, pred_te=None):
11 | res_image_list = []
12 | for idx, (raw_img, lidar2img) in enumerate(zip(image_list, lidar2imgs)):
13 | image = raw_img.copy()
14 | for lane in gt_lane:
15 | xyz1 = np.concatenate([lane, np.ones((lane.shape[0], 1))], axis=1)
16 | xyz1 = xyz1 @ lidar2img.T
17 | xyz1 = xyz1[xyz1[:, 2] > 1e-5]
18 | if xyz1.shape[0] == 0:
19 | continue
20 | points_2d = xyz1[:, :2] / xyz1[:, 2:3]
21 | points_2d = points_2d.astype(int)
22 | image = cv2.polylines(image, points_2d[None], False, GT_COLOR, 2)
23 |
24 | for lane in pred_lane:
25 | xyz1 = np.concatenate([lane, np.ones((lane.shape[0], 1))], axis=1)
26 | xyz1 = xyz1 @ lidar2img.T
27 | xyz1 = xyz1[xyz1[:, 2] > 1e-5]
28 | if xyz1.shape[0] == 0:
29 | continue
30 | points_2d = xyz1[:, :2] / xyz1[:, 2:3]
31 | points_2d = points_2d.astype(int)
32 | image = cv2.polylines(image, points_2d[None], False, PRED_COLOR, 2)
33 |
34 | if idx == 0:
35 | if gt_te is not None:
36 | for bbox, attr in gt_te:
37 | b = bbox.astype(int)
38 | color = COLOR_DICT[attr]
39 | image = draw_corner_rectangle(image, (b[0], b[1]), (b[2], b[3]), color, 3, 1)
40 | if pred_te is not None:
41 | for bbox, attr in pred_te:
42 | b = bbox.astype(int)
43 | color = COLOR_DICT[attr]
44 | image = cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), color, 3)
45 |
46 | res_image_list.append(image)
47 |
48 | return res_image_list
49 |
50 | def show_bev_results(gt_lane, pred_lane, gt_lclc=None, pred_lclc=None, only=None, map_size=[-55, 55, -30, 30], scale=10):
51 | image = np.zeros((int(scale*(map_size[1]-map_size[0])), int(scale*(map_size[3] - map_size[2])), 3), dtype=np.uint8)
52 | if only is None or only == 'gt':
53 | for lane in gt_lane:
54 | draw_coor = (scale * (-lane[:, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
55 | image = cv2.polylines(image, [draw_coor[:, [1,0]]], False, GT_COLOR, max(round(scale * 0.2), 1))
56 | image = cv2.circle(image, (draw_coor[0, 1], draw_coor[0, 0]), max(2, round(scale * 0.5)), GT_COLOR, -1)
57 |             image = cv2.circle(image, (draw_coor[-1, 1], draw_coor[-1, 0]), max(2, round(scale * 0.5)), GT_COLOR, -1)
58 |
59 | if gt_lclc is not None:
60 | for l1_idx, lclc in enumerate(gt_lclc):
61 | for l2_idx, connected in enumerate(lclc):
62 | if connected:
63 | l1 = gt_lane[l1_idx]
64 | l2 = gt_lane[l2_idx]
65 | l1_mid = len(l1) // 2
66 | l2_mid = len(l2) // 2
67 | p1 = (scale * (-l1[l1_mid, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
68 | p2 = (scale * (-l2[l2_mid, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
69 | image = cv2.arrowedLine(image, (p1[1], p1[0]), (p2[1], p2[0]), GT_COLOR, max(round(scale * 0.1), 1), tipLength=0.1)
70 |
71 | if only is None or only == 'pred':
72 | for lane in pred_lane:
73 | draw_coor = (scale * (-lane[:, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
74 | image = cv2.polylines(image, [draw_coor[:, [1,0]]], False, PRED_COLOR, max(round(scale * 0.2), 1))
75 | image = cv2.circle(image, (draw_coor[0, 1], draw_coor[0, 0]), max(2, round(scale * 0.5)), PRED_COLOR, -1)
76 |             image = cv2.circle(image, (draw_coor[-1, 1], draw_coor[-1, 0]), max(2, round(scale * 0.5)), PRED_COLOR, -1)
77 |
78 | if pred_lclc is not None:
79 | for l1_idx, lclc in enumerate(pred_lclc):
80 | for l2_idx, connected in enumerate(lclc):
81 | if connected:
82 | l1 = pred_lane[l1_idx]
83 | l2 = pred_lane[l2_idx]
84 | l1_mid = len(l1) // 2
85 | l2_mid = len(l2) // 2
86 | p1 = (scale * (-l1[l1_mid, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
87 | p2 = (scale * (-l2[l2_mid, :2] + np.array([map_size[1], map_size[3]]))).astype(int)
88 | image = cv2.arrowedLine(image, (p1[1], p1[0]), (p2[1], p2[0]), PRED_COLOR, max(round(scale * 0.1), 1), tipLength=0.1)
89 |
90 | return image
91 |
92 | def draw_corner_rectangle(img: np.ndarray, pt1: tuple, pt2: tuple, color: tuple,
93 | corner_thickness: int = 3, edge_thickness: int = 2,
94 | centre_cross: bool = False, lineType: int = cv2.LINE_8):
95 |
96 | corner_length = min(abs(pt1[0] - pt2[0]), abs(pt1[1] - pt2[1])) // 4
97 | e_args = [color, edge_thickness, lineType]
98 | c_args = [color, corner_thickness, lineType]
99 |
100 | # edges
101 | img = cv2.line(img, (pt1[0] + corner_length, pt1[1]), (pt2[0] - corner_length, pt1[1]), *e_args)
102 | img = cv2.line(img, (pt2[0], pt1[1] + corner_length), (pt2[0], pt2[1] - corner_length), *e_args)
103 | img = cv2.line(img, (pt1[0], pt1[1] + corner_length), (pt1[0], pt2[1] - corner_length), *e_args)
104 | img = cv2.line(img, (pt1[0] + corner_length, pt2[1]), (pt2[0] - corner_length, pt2[1]), *e_args)
105 | # corners
106 | img = cv2.line(img, pt1, (pt1[0] + corner_length, pt1[1]), *c_args)
107 | img = cv2.line(img, pt1, (pt1[0], pt1[1] + corner_length), *c_args)
108 | img = cv2.line(img, (pt2[0], pt1[1]), (pt2[0] - corner_length, pt1[1]), *c_args)
109 | img = cv2.line(img, (pt2[0], pt1[1]), (pt2[0], pt1[1] + corner_length), *c_args)
110 | img = cv2.line(img, (pt1[0], pt2[1]), (pt1[0] + corner_length, pt2[1]), *c_args)
111 | img = cv2.line(img, (pt1[0], pt2[1]), (pt1[0], pt2[1] - corner_length), *c_args)
112 | img = cv2.line(img, pt2, (pt2[0] - corner_length, pt2[1]), *c_args)
113 | img = cv2.line(img, pt2, (pt2[0], pt2[1] - corner_length), *c_args)
114 |
115 | if centre_cross:
116 | cx, cy = int((pt1[0] + pt2[0]) / 2), int((pt1[1] + pt2[1]) / 2)
117 | img = cv2.line(img, (cx - corner_length, cy), (cx + corner_length, cy), *e_args)
118 | img = cv2.line(img, (cx, cy - corner_length), (cx, cy + corner_length), *e_args)
119 |
120 | return img
121 |
--------------------------------------------------------------------------------
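The image-space drawing in show_results is a standard homogeneous pinhole projection: append a 1 to each 3D point, multiply by lidar2img, drop points behind the camera, then divide by depth. The sketch below reproduces that math in isolation; the intrinsics, extrinsics and lane points are all toy values.

import numpy as np

# Rotate the lidar frame (x forward, y left, z up) into a camera frame
# (z forward, x right, y down), then apply hypothetical intrinsics.
R = np.array([[0., -1., 0.],    # cam x = -lidar y
              [0., 0., -1.],    # cam y = -lidar z
              [1., 0., 0.]])    # cam z =  lidar x (depth)
K = np.array([[1000., 0., 960.],
              [0., 1000., 540.],
              [0., 0., 1.]])
lidar2img = np.eye(4)
lidar2img[:3, :3] = K @ R

lane = np.array([[5.0, 0.0, 0.0], [10.0, 0.5, 0.0], [15.0, 1.0, 0.0]])   # points ahead of the ego
xyz1 = np.concatenate([lane, np.ones((lane.shape[0], 1))], axis=1)        # homogeneous coordinates
cam = xyz1 @ lidar2img.T
cam = cam[cam[:, 2] > 1e-5]                      # keep points in front of the camera
uv = (cam[:, :2] / cam[:, 2:3]).astype(int)      # perspective divide -> pixel coordinates
print(uv[0])                                     # [960 540] -- a point straight ahead lands on the principal point
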
/projects/toponet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipelines import *
2 | from .openlanev2_subset_A_dataset import OpenLaneV2_subset_A_Dataset
3 | from .openlanev2_subset_B_dataset import OpenLaneV2_subset_B_Dataset
4 |
--------------------------------------------------------------------------------
/projects/toponet/datasets/openlanev2_subset_B_dataset.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import os
8 | import random
9 | import copy
10 |
11 | import numpy as np
12 | import torch
13 | import mmcv
14 | import cv2
15 |
16 | from pyquaternion import Quaternion
17 | from mmcv.parallel import DataContainer as DC
18 | from mmdet.datasets import DATASETS
19 | from mmdet3d.datasets import Custom3DDataset
20 | from openlanev2.evaluation import evaluate as openlanev2_evaluate
21 | from openlanev2.utils import format_metric
22 | from openlanev2.visualization import draw_annotation_pv, assign_attribute, assign_topology
23 |
24 | from ..core.lane.util import fix_pts_interpolate
25 | from ..core.visualizer.lane import show_bev_results
26 |
27 | from .openlanev2_subset_A_dataset import OpenLaneV2_subset_A_Dataset
28 |
29 | @DATASETS.register_module()
30 | class OpenLaneV2_subset_B_Dataset(OpenLaneV2_subset_A_Dataset):
31 | CAMS = ('CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
32 | 'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT')
33 | MAP_CHANGE_LOGS = [
34 | 'a6daedc3063b421cb3a05019e545f925',
35 | '02f1e5e2fc544798aad223f5ae5e8440',
36 | '55638ae3a8b34572aef756ee7fbce0df',
37 | '20ec831deb0f44e397497198cbe5a97c',
38 | ]
39 |
40 | def get_ann_info(self, index):
41 | """Get annotation info according to the given index.
42 |
43 | Args:
44 | index (int): Index of the annotation data to get.
45 |
46 | Returns:
47 | dict: annotation information
48 | """
49 | info = self.data_infos[index]
50 | ann_info = info['annotation']
51 |
52 | gt_lanes = [np.array(lane['points'][:, :2], dtype=np.float32) for lane in ann_info['lane_centerline']]
53 | gt_lane_labels_3d = np.zeros(len(gt_lanes), dtype=np.int64)
54 | lane_adj = np.array(ann_info['topology_lclc'], dtype=np.float32)
55 |
56 | # only use traffic light attribute
57 | te_bboxes = np.array([np.array(sign['points'], dtype=np.float32).flatten() for sign in ann_info['traffic_element']])
58 | te_labels = np.array([sign['attribute'] for sign in ann_info['traffic_element']], dtype=np.int64)
59 | if len(te_bboxes) == 0:
60 | te_bboxes = np.zeros((0, 4), dtype=np.float32)
61 | te_labels = np.zeros((0, ), dtype=np.int64)
62 |
63 | lane_lcte_adj = np.array(ann_info['topology_lcte'], dtype=np.float32)
64 |
65 | assert len(gt_lanes) == lane_adj.shape[0]
66 | assert len(gt_lanes) == lane_adj.shape[1]
67 | assert len(gt_lanes) == lane_lcte_adj.shape[0]
68 | assert len(te_bboxes) == lane_lcte_adj.shape[1]
69 |
70 | annos = dict(
71 | gt_lanes_3d = gt_lanes,
72 | gt_lane_labels_3d = gt_lane_labels_3d,
73 | gt_lane_adj = lane_adj,
74 | bboxes = te_bboxes,
75 | labels = te_labels,
76 | gt_lane_lcte_adj = lane_lcte_adj
77 | )
78 | return annos
79 |
80 | def format_results(self, results, jsonfile_prefix=None):
81 | pred_dict = {}
82 | pred_dict['method'] = 'TopoNet'
83 | pred_dict['authors'] = []
84 | pred_dict['e-mail'] = 'dummy'
85 | pred_dict['institution / company'] = 'OpenDriveLab'
86 | pred_dict['country / region'] = 'CN'
87 | pred_dict['results'] = {}
88 | for idx, result in enumerate(results):
89 | info = self.data_infos[idx]
90 | key = (self.split, info['segment_id'], str(info['timestamp']))
91 |
92 | pred_info = dict(
93 | lane_centerline=[],
94 | traffic_element=[],
95 | topology_lclc=None,
96 | topology_lcte=None
97 | )
98 |
99 | if result['lane_results'] is not None:
100 | lane_results = result['lane_results']
101 | scores = lane_results[1]
102 | valid_indices = np.argsort(-scores)
103 | lanes = lane_results[0][valid_indices]
104 | lanes = lanes.reshape(-1, lanes.shape[-1] // 2, 2)
105 | lanes = np.concatenate([lanes, np.zeros_like(lanes[..., 0:1])], axis=-1)
106 |
107 | scores = scores[valid_indices]
108 | for pred_idx, (lane, score) in enumerate(zip(lanes, scores)):
109 | points = fix_pts_interpolate(lane, 11)
110 | lc_info = dict(
111 | id = 10000 + pred_idx,
112 | points = points.astype(np.float32),
113 | confidence = score.item()
114 | )
115 | pred_info['lane_centerline'].append(lc_info)
116 |
117 | if result['bbox_results'] is not None:
118 | te_results = result['bbox_results']
119 | scores = te_results[1]
120 | te_valid_indices = np.argsort(-scores)
121 | tes = te_results[0][te_valid_indices]
122 | scores = scores[te_valid_indices]
123 | class_idxs = te_results[2][te_valid_indices]
124 | for pred_idx, (te, score, class_idx) in enumerate(zip(tes, scores, class_idxs)):
125 | te_info = dict(
126 | id = 20000 + pred_idx,
127 | category = 1 if class_idx < 4 else 2,
128 | attribute = class_idx,
129 | points = te.reshape(2, 2).astype(np.float32),
130 | confidence = score
131 | )
132 | pred_info['traffic_element'].append(te_info)
133 |
134 | if result['lclc_results'] is not None:
135 | pred_info['topology_lclc'] = result['lclc_results'].astype(np.float32)[valid_indices][:, valid_indices]
136 | else:
137 | pred_info['topology_lclc'] = np.zeros((len(pred_info['lane_centerline']), len(pred_info['lane_centerline'])), dtype=np.float32)
138 |
139 | if result['lcte_results'] is not None:
140 | pred_info['topology_lcte'] = result['lcte_results'].astype(np.float32)[valid_indices][:, te_valid_indices]
141 | else:
142 | pred_info['topology_lcte'] = np.zeros((len(pred_info['lane_centerline']), len(pred_info['traffic_element'])), dtype=np.float32)
143 |
144 | pred_dict['results'][key] = dict(predictions=pred_info)
145 |
146 | return pred_dict
147 |
148 | @staticmethod
149 | def _render_surround_img(images):
150 | all_image = []
151 | img_height = images[1].shape[0]
152 |
153 | for idx in [2, 0, 1, 5, 3, 4]:
154 | if idx == 4 or idx == 1:
155 | all_image.append(images[idx])
156 | else:
157 | all_image.append(images[idx])
158 | all_image.append(np.full((img_height, 10, 3), (255, 255, 255), dtype=np.uint8))
159 |
160 | surround_img_upper = None
161 | surround_img_upper = np.concatenate(all_image[:5], 1)
162 |
163 | surround_img_down = None
164 | surround_img_down = np.concatenate(all_image[5:], 1)
165 |
166 | surround_img = np.concatenate((surround_img_upper, np.full((10, surround_img_down.shape[1], 3), (255, 255, 255), dtype=np.uint8), surround_img_down), 0)
167 | surround_img = cv2.resize(surround_img, None, fx=0.5, fy=0.5)
168 |
169 | return surround_img
170 |
--------------------------------------------------------------------------------
/projects/toponet/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .transform_3d import (
2 | PadMultiViewImage, NormalizeMultiviewImage,
3 | PhotoMetricDistortionMultiViewImage, CustomCollect3D, RandomScaleImageMultiViewImage,
4 | GridMaskMultiViewImage, CropFrontViewImageForAv2)
5 | from .transform_3d_lane import LaneParameterize3D, LaneLengthFilter
6 | from .formating import CustomFormatBundle3DLane
7 | from .loading import CustomLoadMultiViewImageFromFiles, LoadAnnotations3DLane
8 |
9 | __all__ = [
10 | 'PadMultiViewImage', 'NormalizeMultiviewImage',
11 | 'PhotoMetricDistortionMultiViewImage', 'CustomCollect3D', 'RandomScaleImageMultiViewImage',
12 | 'GridMaskMultiViewImage', 'CropFrontViewImageForAv2',
13 | 'LaneParameterize3D', 'LaneLengthFilter',
14 | 'CustomFormatBundle3DLane',
15 | 'CustomLoadMultiViewImageFromFiles', 'LoadAnnotations3DLane'
16 | ]
17 |
--------------------------------------------------------------------------------
/projects/toponet/datasets/pipelines/formating.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | from mmcv.parallel import DataContainer as DC
9 |
10 | from mmdet.datasets.builder import PIPELINES
11 | from mmdet.datasets.pipelines import to_tensor
12 | from mmdet3d.datasets.pipelines import DefaultFormatBundle3D
13 |
14 |
15 | @PIPELINES.register_module()
16 | class CustomFormatBundle3DLane(DefaultFormatBundle3D):
17 | """Custom formatting bundle for 3D Lane.
18 | """
19 |
20 | def __init__(self, class_names, **kwargs):
21 | super(CustomFormatBundle3DLane, self).__init__(class_names, **kwargs)
22 |
23 | def __call__(self, results):
24 | """Call function to transform and format common fields in results.
25 |
26 | Args:
27 | results (dict): Result dict contains the data to convert.
28 |
29 | Returns:
30 | dict: The result dict contains the data that is formatted with
31 | default bundle.
32 | """
33 | if 'gt_lanes_3d' in results:
34 | results['gt_lanes_3d'] = DC(
35 | to_tensor(results['gt_lanes_3d']))
36 | if 'gt_lane_labels_3d' in results:
37 | results['gt_lane_labels_3d'] = DC(
38 | to_tensor(results['gt_lane_labels_3d']))
39 | if 'gt_lane_adj' in results:
40 | results['gt_lane_adj'] = DC(
41 | to_tensor(results['gt_lane_adj']))
42 | if 'gt_lane_lcte_adj' in results:
43 | results['gt_lane_lcte_adj'] = DC(
44 | to_tensor(results['gt_lane_lcte_adj']))
45 |
46 | results = super(CustomFormatBundle3DLane, self).__call__(results)
47 | return results
48 |
49 | def __repr__(self):
50 | """str: Return a string that describes the module."""
51 | repr_str = self.__class__.__name__
52 | repr_str += f'(class_names={self.class_names}, '
53 | repr_str += f'with_gt={self.with_gt}, with_label={self.with_label})'
54 | return repr_str
55 |
--------------------------------------------------------------------------------
/projects/toponet/datasets/pipelines/loading.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | import mmcv
9 | from mmdet.datasets.builder import PIPELINES
10 | from mmdet3d.datasets.pipelines import LoadAnnotations3D
11 |
12 |
13 | @PIPELINES.register_module()
14 | class CustomLoadMultiViewImageFromFiles(object):
15 |
16 | def __init__(self, to_float32=False, color_type='unchanged'):
17 | self.to_float32 = to_float32
18 | self.color_type = color_type
19 |
20 | def __call__(self, results):
21 | filename = results['img_filename']
22 | # img is of shape (h, w, c, num_views)
23 | img = [mmcv.imread(name, self.color_type) for name in filename]
24 | if self.to_float32:
25 | img = [_.astype(np.float32) for _ in img]
26 | results['filename'] = filename
27 | results['img'] = img
28 | results['img_shape'] = [img_.shape for img_ in img]
29 | results['ori_shape'] = [img_.shape for img_ in img]
30 | # Set initial values for default meta_keys
31 | results['pad_shape'] = [img_.shape for img_ in img]
32 | results['crop_shape'] = [np.zeros(2) for img_ in img]
33 | results['scale_factor'] = [1.0 for img_ in img]
34 | num_channels = 1 if len(img[0].shape) < 3 else img[0].shape[2]
35 | results['img_norm_cfg'] = dict(
36 | mean=np.zeros(num_channels, dtype=np.float32),
37 | std=np.ones(num_channels, dtype=np.float32),
38 | to_rgb=False)
39 | return results
40 |
41 | def __repr__(self):
42 | """str: Return a string that describes the module."""
43 | repr_str = self.__class__.__name__
44 | repr_str += f'(to_float32={self.to_float32}, '
45 | repr_str += f"color_type='{self.color_type}')"
46 | return repr_str
47 |
48 |
49 | @PIPELINES.register_module()
50 | class LoadAnnotations3DLane(LoadAnnotations3D):
51 | """Load Annotations3D Lane.
52 |
53 | Args:
54 | with_lane_3d (bool, optional): Whether to load 3D Lanes.
55 | Defaults to True.
56 | with_lane_label_3d (bool, optional): Whether to load 3D Lanes Labels.
57 | Defaults to True.
58 | with_lane_adj (bool, optional): Whether to load Lane-Lane Adjacency.
59 | Defaults to True.
60 | with_lane_lcte_adj (bool, optional): Whether to load Lane-TE Adjacency.
61 | Defaults to False.
62 | """
63 |
64 | def __init__(self,
65 | with_lane_3d=True,
66 | with_lane_label_3d=True,
67 | with_lane_adj=True,
68 | with_lane_lcte_adj=False,
69 | with_bbox_3d=False,
70 | with_label_3d=False,
71 | **kwargs):
72 | super().__init__(with_bbox_3d, with_label_3d, **kwargs)
73 | self.with_lane_3d = with_lane_3d
74 | self.with_lane_label_3d = with_lane_label_3d
75 | self.with_lane_adj = with_lane_adj
76 | self.with_lane_lcte_adj = with_lane_lcte_adj
77 |
78 | def _load_lanes_3d(self, results):
79 | results['gt_lanes_3d'] = results['ann_info']['gt_lanes_3d']
80 | if self.with_lane_label_3d:
81 | results['gt_lane_labels_3d'] = results['ann_info']['gt_lane_labels_3d']
82 | if self.with_lane_adj:
83 | results['gt_lane_adj'] = results['ann_info']['gt_lane_adj']
84 | if self.with_lane_lcte_adj:
85 | results['gt_lane_lcte_adj'] = results['ann_info']['gt_lane_lcte_adj']
86 | return results
87 |
88 | def __call__(self, results):
89 | """Call function to load multiple types annotations.
90 |
91 | Args:
92 | results (dict): Result dict from :obj:`mmdet3d.CustomDataset`.
93 |
94 | Returns:
95 | dict: The dict containing loaded 3D bounding box, label, mask and
96 | semantic segmentation annotations.
97 | """
98 | results = super().__call__(results)
99 | if self.with_lane_3d:
100 | results = self._load_lanes_3d(results)
101 | return results
102 |
103 | def __repr__(self):
104 | """str: Return a string that describes the module."""
105 | indent_str = ' '
106 | repr_str = super().__repr__()
107 | repr_str += f'{indent_str}with_lane_3d={self.with_lane_3d}, '
108 |         repr_str += f'{indent_str}with_lane_label_3d={self.with_lane_label_3d}, '
109 | repr_str += f'{indent_str}with_lane_adj={self.with_lane_adj}, '
110 | return repr_str
111 |
--------------------------------------------------------------------------------
/projects/toponet/datasets/pipelines/transform_3d_lane.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | from mmdet.datasets.builder import PIPELINES
9 | from shapely.geometry import LineString
10 | from ...core.lane.util import fix_pts_interpolate
11 |
12 |
13 | @PIPELINES.register_module()
14 | class LaneParameterize3D(object):
15 |
16 | def __init__(self, method, method_para):
17 | method_list = ['fix_pts_interp']
18 | self.method = method
19 |         if self.method not in method_list:
20 |             raise NotImplementedError(f'LaneParameterize3D method {self.method} is not implemented!')
21 | self.method_para = method_para
22 |
23 | def __call__(self, results):
24 |         """Call function to parameterize the 3D lanes.
25 | Args:
26 | results (dict): Result dict from loading pipeline.
27 | Returns:
28 |             dict: Results with 'gt_lanes_3d' replaced by the
29 |                 parameterized lane representation.
30 | """
31 | lanes = results['gt_lanes_3d']
32 | para_lanes = getattr(self, self.method)(lanes, **self.method_para)
33 | results['gt_lanes_3d'] = para_lanes
34 |
35 | return results
36 |
37 | def fix_pts_interp(self, input_data, n_points=11):
38 | '''Interpolate the 3D lanes to fix points. The input size is (n_pts, 3).
39 | '''
40 | lane_list = []
41 | for lane in input_data:
42 | if n_points == 11 and lane.shape[0] == 201:
43 | lane_list.append(lane[::20].flatten())
44 | else:
45 | lane = fix_pts_interpolate(lane, n_points).flatten()
46 | lane_list.append(lane)
47 | return np.array(lane_list, dtype=np.float32)
48 |
49 |
50 | @PIPELINES.register_module()
51 | class LaneLengthFilter(object):
52 | """Filter the 3D lanes by lane length (meters).
53 | """
54 |
55 | def __init__(self, min_length):
56 | self.min_length = min_length
57 |
58 | def __call__(self, results):
59 |
60 | if self.min_length <= 0:
61 | return results
62 |
63 |         length_list = np.array(list(map(lambda x: LineString(x).length, results['gt_lanes_3d'])))
64 | masks = length_list > self.min_length
65 | results['gt_lanes_3d'] = [lane for idx, lane in enumerate(results['gt_lanes_3d']) if masks[idx]]
66 | results['gt_lane_labels_3d'] = results['gt_lane_labels_3d'][masks]
67 |
68 | if 'gt_lane_adj' in results.keys():
69 | results['gt_lane_adj'] = results['gt_lane_adj'][masks][:, masks]
70 | if 'gt_lane_lcte_adj' in results.keys():
71 | results['gt_lane_lcte_adj'] = results['gt_lane_lcte_adj'][masks]
72 |
73 | return results
74 |
--------------------------------------------------------------------------------
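For context, the lane-specific transforms above are chained into an mmdet-style data pipeline together with the loading and formatting steps from the previous files. The snippet below is a trimmed, hypothetical sketch; the real pipelines (including image normalization, padding and resizing) live in projects/configs/.

train_pipeline = [
    dict(type='CustomLoadMultiViewImageFromFiles', to_float32=True),
    dict(type='LoadAnnotations3DLane',
         with_lane_3d=True, with_lane_label_3d=True,
         with_lane_adj=True, with_lane_lcte_adj=True),
    dict(type='LaneLengthFilter', min_length=2),              # drop centerlines shorter than 2 m
    dict(type='LaneParameterize3D', method='fix_pts_interp',
         method_para=dict(n_points=11)),                      # resample every lane to 11 points
    dict(type='CustomFormatBundle3DLane', class_names=['centerline']),
    dict(type='CustomCollect3D', keys=[
        'img', 'gt_lanes_3d', 'gt_lane_labels_3d',
        'gt_lane_adj', 'gt_lane_lcte_adj']),
]
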
/projects/toponet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .detectors import *
2 | from .dense_heads import *
3 | from .modules import *
4 |
--------------------------------------------------------------------------------
/projects/toponet/models/dense_heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .toponet_head import TopoNetHead
2 | from .deformable_detr_head import CustomDeformableDETRHead
3 | from .relationship_head import SingleLayerRelationshipHead
4 |
--------------------------------------------------------------------------------
/projects/toponet/models/dense_heads/relationship_head.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | import cv2
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import mmcv
13 | from mmcv.cnn import Linear, bias_init_with_prob, build_activation_layer
14 | from mmcv.cnn.bricks.transformer import build_feedforward_network
15 | from mmcv.runner import auto_fp16, force_fp32
16 | from mmcv.utils import TORCH_VERSION, digit_version
17 | from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean
18 | from mmdet.models.builder import HEADS, build_loss
19 | from mmdet.models.dense_heads import AnchorFreeHead
20 | from mmdet.models.utils import build_transformer
21 | from mmdet.models.utils.transformer import inverse_sigmoid
22 | from mmdet3d.core.bbox.coders import build_bbox_coder
23 |
24 |
25 | class MLP(nn.Module):
26 |
27 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
28 | super().__init__()
29 | self.num_layers = num_layers
30 | h = [hidden_dim] * (num_layers - 1)
31 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
32 |
33 | def forward(self, x):
34 | for i, layer in enumerate(self.layers):
35 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
36 | return x
37 |
38 | @HEADS.register_module()
39 | class SingleLayerRelationshipHead(nn.Module):
40 | def __init__(self,
41 | in_channels_o1,
42 | in_channels_o2=None,
43 | shared_param=True,
44 | loss_rel=dict(
45 | type='FocalLoss',
46 | use_sigmoid=True,
47 | gamma=2.0,
48 | alpha=0.25)):
49 | super().__init__()
50 |
51 | self.MLP_o1 = MLP(in_channels_o1, in_channels_o1, 128, 3)
52 | self.shared_param = shared_param
53 | if shared_param:
54 | self.MLP_o2 = self.MLP_o1
55 | else:
56 | self.MLP_o2 = MLP(in_channels_o2, in_channels_o2, 128, 3)
57 | self.classifier = MLP(256, 256, 1, 3)
58 | self.loss_rel = build_loss(loss_rel)
59 |
60 | def forward(self, o1_feats, o2_feats):
61 | # feats: B, num_query, num_embedding
62 | o1_embeds = self.MLP_o1(o1_feats)
63 | o2_embeds = self.MLP_o2(o2_feats)
64 |
65 | num_query_o1 = o1_embeds.size(1)
66 | num_query_o2 = o2_embeds.size(1)
67 | o1_tensor = o1_embeds.unsqueeze(2).repeat(1, 1, num_query_o2, 1)
68 | o2_tensor = o2_embeds.unsqueeze(1).repeat(1, num_query_o1, 1, 1)
69 |
70 | relationship_tensor = torch.cat([o1_tensor, o2_tensor], dim=-1)
71 | relationship_pred = self.classifier(relationship_tensor)
72 |
73 | return relationship_pred
74 |
75 | def loss(self, rel_preds, gt_adjs, o1_assign_results, o2_assign_results):
76 | B, num_query_o1, num_query_o2, _ = rel_preds.size()
77 | o1_assign = o1_assign_results
78 | o1_pos_inds = o1_assign['pos_inds']
79 | o1_pos_assigned_gt_inds = o1_assign['pos_assigned_gt_inds']
80 |
81 | if self.shared_param:
82 | o2_assign = o1_assign
83 | o2_pos_inds = o1_pos_inds
84 | o2_pos_assigned_gt_inds = o1_pos_assigned_gt_inds
85 | else:
86 | o2_assign = o2_assign_results
87 | o2_pos_inds = o2_assign['pos_inds']
88 | o2_pos_assigned_gt_inds = o2_assign['pos_assigned_gt_inds']
89 |
90 | targets = []
91 | for i in range(B):
92 | gt_adj = gt_adjs[i]
93 | target = torch.zeros_like(rel_preds[i].squeeze(-1), dtype=gt_adj.dtype, device=rel_preds.device)
94 | xs = o1_pos_inds[i].unsqueeze(-1).repeat(1, o2_pos_inds[i].size(0))
95 | ys = o2_pos_inds[i].unsqueeze(0).repeat(o1_pos_inds[i].size(0), 1)
96 | target[xs, ys] = gt_adj[o1_pos_assigned_gt_inds[i]][:, o2_pos_assigned_gt_inds[i]]
97 | targets.append(target)
98 | targets = torch.stack(targets, dim=0)
99 |
100 | targets = 1 - targets.view(-1).long()
101 | rel_preds = rel_preds.view(-1, 1)
102 |
103 | loss_rel = self.loss_rel(rel_preds, targets)
104 |
105 | if digit_version(TORCH_VERSION) >= digit_version('1.8'):
106 | loss_rel = torch.nan_to_num(loss_rel)
107 |
108 | return dict(loss_rel=loss_rel)
109 |
110 |
--------------------------------------------------------------------------------
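Shape-wise, SingleLayerRelationshipHead scores every (o1, o2) query pair by concatenating their embeddings and running an MLP over the resulting grid. A minimal sketch with random tensors in place of decoder features and a single Linear standing in for the MLP classifier:

import torch

B, Q1, Q2, C = 1, 4, 3, 128
o1 = torch.rand(B, Q1, C)                          # e.g. lane query embeddings
o2 = torch.rand(B, Q2, C)                          # e.g. traffic-element query embeddings
pairs = torch.cat([o1.unsqueeze(2).expand(B, Q1, Q2, C),
                   o2.unsqueeze(1).expand(B, Q1, Q2, C)], dim=-1)
logits = torch.nn.Linear(2 * C, 1)(pairs)          # one logit per (o1, o2) pair
adj = logits.squeeze(-1).sigmoid()                 # soft adjacency, shape (B, Q1, Q2)
print(pairs.shape, adj.shape)                      # torch.Size([1, 4, 3, 256]) torch.Size([1, 4, 3])
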
/projects/toponet/models/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | from .toponet import TopoNet
2 |
--------------------------------------------------------------------------------
/projects/toponet/models/detectors/toponet.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import time
8 | import copy
9 | import numpy as np
10 | import torch
11 |
12 | from mmcv.runner import force_fp32, auto_fp16
13 | from mmdet.models import DETECTORS
14 | from mmdet.models.builder import build_head
15 | from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
16 |
17 | from ...utils.builder import build_bev_constructor
18 |
19 |
20 | @DETECTORS.register_module()
21 | class TopoNet(MVXTwoStageDetector):
22 |
23 | def __init__(self,
24 | bev_constructor=None,
25 | bbox_head=None,
26 | lane_head=None,
27 | video_test_mode=False,
28 | **kwargs):
29 |
30 | super(TopoNet, self).__init__(**kwargs)
31 |
32 | if bev_constructor is not None:
33 | self.bev_constructor = build_bev_constructor(bev_constructor)
34 |
35 | if bbox_head is not None:
36 | bbox_head.update(train_cfg=self.train_cfg.bbox)
37 | self.bbox_head = build_head(bbox_head)
38 | else:
39 | self.bbox_head = None
40 |
41 | if lane_head is not None:
42 | lane_head.update(train_cfg=self.train_cfg.lane)
43 | self.pts_bbox_head = build_head(lane_head)
44 | else:
45 | self.pts_bbox_head = None
46 |
47 | self.fp16_enabled = False
48 |
49 | # temporal
50 | self.video_test_mode = video_test_mode
51 | self.prev_frame_info = {
52 | 'prev_bev': None,
53 | 'scene_token': None,
54 | 'prev_pos': 0,
55 | 'prev_angle': 0,
56 | }
57 |
58 | def extract_img_feat(self, img, img_metas, len_queue=None):
59 | """Extract features of images."""
60 | B = img.size(0)
61 | if img is not None:
62 |
63 | if img.dim() == 5 and img.size(0) == 1:
64 | img.squeeze_()
65 | elif img.dim() == 5 and img.size(0) > 1:
66 | B, N, C, H, W = img.size()
67 | img = img.reshape(B * N, C, H, W)
68 | img_feats = self.img_backbone(img)
69 |
70 | if isinstance(img_feats, dict):
71 | img_feats = list(img_feats.values())
72 | else:
73 | return None
74 | if self.with_img_neck:
75 | img_feats = self.img_neck(img_feats)
76 |
77 | img_feats_reshaped = []
78 | for img_feat in img_feats:
79 | BN, C, H, W = img_feat.size()
80 | if len_queue is not None:
81 | img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
82 | else:
83 | img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
84 | return img_feats_reshaped
85 |
86 |     @auto_fp16(apply_to=('img', ))
87 | def extract_feat(self, img, img_metas=None, len_queue=None):
88 | img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
89 | return img_feats
90 |
91 | def forward_dummy(self, img):
92 | dummy_metas = None
93 | return self.forward_test(img=img, img_metas=[[dummy_metas]])
94 |
95 | def forward(self, return_loss=True, **kwargs):
96 | if return_loss:
97 | return self.forward_train(**kwargs)
98 | else:
99 | return self.forward_test(**kwargs)
100 |
101 | def obtain_history_bev(self, imgs_queue, img_metas_list):
102 | """Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
103 | """
104 | self.eval()
105 |
106 | with torch.no_grad():
107 | prev_bev = None
108 | bs, len_queue, num_cams, C, H, W = imgs_queue.shape
109 | imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
110 | img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
111 | for i in range(len_queue):
112 | img_metas = [each[i] for each in img_metas_list]
113 | img_feats = [each_scale[:, i] for each_scale in img_feats_list]
114 | prev_bev = self.bev_constructor(img_feats, img_metas, prev_bev)
115 | self.train()
116 | return prev_bev
117 |
118 |     @auto_fp16(apply_to=('img', ))
119 | def forward_train(self,
120 | img=None,
121 | img_metas=None,
122 | gt_labels=None,
123 | gt_bboxes=None,
124 | gt_lanes_3d=None,
125 | gt_lane_labels_3d=None,
126 | gt_lane_adj=None,
127 | gt_lane_lcte_adj=None,
128 | gt_bboxes_ignore=None,
129 | ):
130 |
131 | len_queue = img.size(1)
132 | prev_img = img[:, :-1, ...]
133 | img = img[:, -1, ...]
134 |
135 | if self.video_test_mode:
136 | prev_img_metas = copy.deepcopy(img_metas)
137 | prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
138 | else:
139 | prev_bev = None
140 |
141 | img_metas = [each[len_queue-1] for each in img_metas]
142 | img_feats = self.extract_feat(img=img, img_metas=img_metas)
143 |
144 | front_view_img_feats = [lvl[:, 0] for lvl in img_feats]
145 | batch_input_shape = tuple(img[0, 0].size()[-2:])
146 | bbox_img_metas = []
147 | for img_meta in img_metas:
148 | bbox_img_metas.append(
149 | dict(
150 | batch_input_shape=batch_input_shape,
151 | img_shape=img_meta['img_shape'][0],
152 | scale_factor=img_meta['scale_factor'][0],
153 | crop_shape=img_meta['crop_shape'][0]))
154 | img_meta['batch_input_shape'] = batch_input_shape
155 |
156 | te_losses = {}
157 | bbox_outs = self.bbox_head(front_view_img_feats, bbox_img_metas)
158 | bbox_losses, te_assign_result = self.bbox_head.loss(bbox_outs, gt_bboxes, gt_labels, bbox_img_metas, gt_bboxes_ignore)
159 | te_feats = bbox_outs['history_states']
160 | te_cls_scores = bbox_outs['all_cls_scores']
161 |
162 | for loss in bbox_losses:
163 | te_losses['bbox_head.' + loss] = bbox_losses[loss]
164 |
165 | num_gt_bboxes = sum([len(gt) for gt in gt_labels])
166 | if num_gt_bboxes == 0:
167 | for loss in te_losses:
168 | te_losses[loss] *= 0
169 |
170 | losses = dict()
171 | bev_feats = self.bev_constructor(img_feats, img_metas, prev_bev)
172 | outs = self.pts_bbox_head(img_feats, bev_feats, img_metas, te_feats, te_cls_scores)
173 | loss_inputs = [outs, gt_lanes_3d, gt_lane_labels_3d, gt_lane_adj, gt_lane_lcte_adj, te_assign_result]
174 | lane_losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
175 | for loss in lane_losses:
176 | losses['lane_head.' + loss] = lane_losses[loss]
177 |
178 | losses.update(te_losses)
179 |
180 | return losses
181 |
182 | def forward_test(self, img_metas, img=None, **kwargs):
183 | for var, name in [(img_metas, 'img_metas')]:
184 | if not isinstance(var, list):
185 | raise TypeError('{} must be a list, but got {}'.format(
186 | name, type(var)))
187 | img = [img] if img is None else img
188 |
189 | if img_metas[0]['scene_token'] != self.prev_frame_info['scene_token']:
190 | # the first sample of each scene is truncated
191 | self.prev_frame_info['prev_bev'] = None
192 | # update idx
193 | self.prev_frame_info['scene_token'] = img_metas[0]['scene_token']
194 |
195 | # do not use temporal information
196 | if not self.video_test_mode:
197 | self.prev_frame_info['prev_bev'] = None
198 |
199 | # Get the delta of ego position and angle between two timestamps.
200 | tmp_pos = copy.deepcopy(img_metas[0]['can_bus'][:3])
201 | tmp_angle = copy.deepcopy(img_metas[0]['can_bus'][-1])
202 | if self.prev_frame_info['prev_bev'] is not None:
203 | img_metas[0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
204 | img_metas[0]['can_bus'][-1] -= self.prev_frame_info['prev_angle']
205 | else:
206 | img_metas[0]['can_bus'][-1] = 0
207 | img_metas[0]['can_bus'][:3] = 0
208 |
209 | new_prev_bev, results_list = self.simple_test(
210 | img_metas, img, prev_bev=self.prev_frame_info['prev_bev'], **kwargs)
211 | # During inference, we save the BEV features and ego motion of each timestamp.
212 | self.prev_frame_info['prev_pos'] = tmp_pos
213 | self.prev_frame_info['prev_angle'] = tmp_angle
214 | self.prev_frame_info['prev_bev'] = new_prev_bev
215 | return results_list
216 |
217 | def simple_test_pts(self, x, img_metas, img=None, prev_bev=None, rescale=False):
218 | """Test function"""
219 | batchsize = len(img_metas)
220 |
221 | front_view_img_feats = [lvl[:, 0] for lvl in x]
222 | batch_input_shape = tuple(img[0, 0].size()[-2:])
223 | bbox_img_metas = []
224 | for img_meta in img_metas:
225 | bbox_img_metas.append(
226 | dict(
227 | batch_input_shape=batch_input_shape,
228 | img_shape=img_meta['img_shape'][0],
229 | scale_factor=img_meta['scale_factor'][0],
230 | crop_shape=img_meta['crop_shape'][0]))
231 | img_meta['batch_input_shape'] = batch_input_shape
232 | bbox_outs = self.bbox_head(front_view_img_feats, bbox_img_metas)
233 | bbox_results = self.bbox_head.get_bboxes(bbox_outs, bbox_img_metas, rescale=rescale)
234 | te_feats = bbox_outs['history_states']
235 | te_cls_scores = bbox_outs['all_cls_scores']
236 | bev_feats = self.bev_constructor(x, img_metas, prev_bev)
237 |
238 | outs = self.pts_bbox_head(x, bev_feats, img_metas, te_feats, te_cls_scores)
239 | lane_results, lclc_results, lcte_results = self.pts_bbox_head.get_lanes(
240 | outs, img_metas, rescale=rescale)
241 |
242 | return bev_feats, bbox_results, lane_results, lclc_results, lcte_results
243 |
244 | def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
245 |         """Test function without augmentation."""
246 | img_feats = self.extract_feat(img=img, img_metas=img_metas)
247 |
248 | results_list = [dict() for i in range(len(img_metas))]
249 | new_prev_bev, bbox_results, lane_results, lclc_results, lcte_results = self.simple_test_pts(
250 | img_feats, img_metas, img, prev_bev, rescale=rescale)
251 | for result_dict, bbox, lane, lclc, lcte in zip(results_list, bbox_results, lane_results, lclc_results, lcte_results):
252 | result_dict['bbox_results'] = bbox
253 | result_dict['lane_results'] = lane
254 | result_dict['lclc_results'] = lclc
255 | result_dict['lcte_results'] = lcte
256 | return new_prev_bev, results_list
257 |
--------------------------------------------------------------------------------
/projects/toponet/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .bevformer_constructer import BEVFormerConstructer
2 | from .transformer_decoder_only import TopoNetTransformerDecoderOnly
3 | from .sgnn_decoder import TopoNetSGNNDecoder, SGNNDecoderLayer, FFN_SGNN
4 |
--------------------------------------------------------------------------------
/projects/toponet/models/modules/bevformer_constructer.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.init import normal_
11 | from torchvision.transforms.functional import rotate
12 |
13 | from mmcv.cnn import xavier_init
14 | from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence, build_positional_encoding
15 | from mmcv.runner.base_module import BaseModule
16 | from mmcv.runner import force_fp32, auto_fp16
17 |
18 | from ...utils.builder import BEV_CONSTRUCTOR
19 | from projects.bevformer.modules.temporal_self_attention import TemporalSelfAttention
20 | from projects.bevformer.modules.spatial_cross_attention import MSDeformableAttention3D
21 | from projects.bevformer.modules.decoder import CustomMSDeformableAttention
22 |
23 |
24 | @BEV_CONSTRUCTOR.register_module()
25 | class BEVFormerConstructer(BaseModule):
26 |     """Constructs BEV features with the BEVFormer encoder.
27 |     Args:
28 |         num_feature_levels (int): Number of feature maps from FPN.
29 |             Default: 4.
30 |         num_cams (int): Number of camera views. Default: 6.
31 |         embed_dims (int): Embedding dimension of the BEV queries.
32 |             Default: 256.
33 |         bev_h, bev_w (int): Spatial shape of the BEV grid. Default: 200.
34 |     """
35 |
36 | def __init__(self,
37 | num_feature_levels=4,
38 | num_cams=6,
39 | embed_dims=256,
40 | rotate_prev_bev=True,
41 | use_shift=True,
42 | use_can_bus=True,
43 | can_bus_norm=True,
44 | use_cams_embeds=True,
45 | pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
46 | bev_h=200,
47 | bev_w=200,
48 | rotate_center=[100, 100],
49 | encoder=None,
50 | positional_encoding=None,
51 | **kwargs):
52 | super(BEVFormerConstructer, self).__init__(**kwargs)
53 | self.embed_dims = embed_dims
54 | self.num_feature_levels = num_feature_levels
55 | self.num_cams = num_cams
56 | self.fp16_enabled = False
57 |
58 | self.rotate_prev_bev = rotate_prev_bev
59 | self.use_shift = use_shift
60 | self.use_can_bus = use_can_bus
61 | self.can_bus_norm = can_bus_norm
62 | self.use_cams_embeds = use_cams_embeds
63 | self.encoder = build_transformer_layer_sequence(encoder)
64 | self.positional_encoding = build_positional_encoding(positional_encoding)
65 |
66 | self.pc_range = pc_range
67 | self.real_w = self.pc_range[3] - self.pc_range[0]
68 | self.real_h = self.pc_range[4] - self.pc_range[1]
69 | self.bev_h = bev_h
70 | self.bev_w = bev_w
71 | self.rotate_center = rotate_center
72 |
73 | self.init_layers()
74 |
75 | def init_layers(self):
76 | self.bev_embedding = nn.Embedding(
77 | self.bev_h * self.bev_w, self.embed_dims)
78 | self.level_embeds = nn.Parameter(torch.Tensor(
79 | self.num_feature_levels, self.embed_dims))
80 | self.cams_embeds = nn.Parameter(
81 | torch.Tensor(self.num_cams, self.embed_dims))
82 | self.can_bus_mlp = nn.Sequential(
83 | nn.Linear(18, self.embed_dims // 2),
84 | nn.ReLU(inplace=True),
85 | nn.Linear(self.embed_dims // 2, self.embed_dims),
86 | nn.ReLU(inplace=True),
87 | )
88 | if self.can_bus_norm:
89 | self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
90 |
91 | def init_weights(self):
92 | """Initialize the transformer weights."""
93 | for p in self.parameters():
94 | if p.dim() > 1:
95 | nn.init.xavier_uniform_(p)
96 | for m in self.modules():
97 | if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
98 | or isinstance(m, CustomMSDeformableAttention):
99 | try:
100 | m.init_weight()
101 | except AttributeError:
102 | m.init_weights()
103 | normal_(self.level_embeds)
104 | normal_(self.cams_embeds)
105 | xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)
106 |
107 | # @auto_fp16(apply_to=('mlvl_feats', 'prev_bev'))
108 | def forward(self, mlvl_feats, img_metas, prev_bev=None, **kwargs):
109 | """
110 | obtain bev features.
111 | """
112 | bs, num_cam, _, _, _ = mlvl_feats[0].shape
113 | dtype = mlvl_feats[0].dtype
114 |
115 | bev_queries = self.bev_embedding.weight.to(dtype)
116 | bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
117 |
118 | bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
119 | device=bev_queries.device).to(dtype)
120 | bev_pos = self.positional_encoding(bev_mask).to(dtype)
121 | bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
122 |
123 |         # BEVFormer assumes x-right and y-forward coords (the nuScenes lidar frame),
124 |         # but OpenLane-V2 uses x-forward and y-left.
125 |         # The fix below works for any lidar frame: the ego shift is rotated into the lidar frame via lidar2global_rotation.
126 | delta_global = np.array([each['can_bus'][:3] for each in img_metas])
127 | lidar2global_rotation = np.array([each['lidar2global_rotation'] for each in img_metas])
128 | delta_lidar = []
129 | for i in range(bs):
130 | delta_lidar.append(np.linalg.inv(lidar2global_rotation[i]) @ delta_global[i])
131 | delta_lidar = np.array(delta_lidar)
132 | shift_y = delta_lidar[:, 1] / self.real_h
133 | shift_x = delta_lidar[:, 0] / self.real_w
134 | shift_y = shift_y * self.use_shift
135 | shift_x = shift_x * self.use_shift
136 | shift = bev_queries.new_tensor([shift_x, shift_y]).permute(1, 0) # xy, bs -> bs, xy
137 |
138 | if prev_bev is not None:
139 | if prev_bev.shape[1] == self.bev_h * self.bev_w:
140 | prev_bev = prev_bev.permute(1, 0, 2)
141 | if self.rotate_prev_bev:
142 | for i in range(bs):
143 | # num_prev_bev = prev_bev.size(1)
144 | rotation_angle = img_metas[i]['can_bus'][-1]
145 | tmp_prev_bev = prev_bev[:, i].reshape(
146 | self.bev_h, self.bev_w, -1).permute(2, 0, 1)
147 | tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
148 | center=self.rotate_center)
149 | tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
150 | self.bev_h * self.bev_w, 1, -1)
151 | prev_bev[:, i] = tmp_prev_bev[:, 0]
152 |
153 | # add can bus signals
154 | can_bus = bev_queries.new_tensor(
155 | [each['can_bus'] for each in img_metas]) # [:, :]
156 | can_bus = self.can_bus_mlp(can_bus)[None, :, :]
157 | bev_queries = bev_queries + can_bus * self.use_can_bus
158 |
159 | feat_flatten = []
160 | spatial_shapes = []
161 | for lvl, feat in enumerate(mlvl_feats):
162 | bs, num_cam, c, h, w = feat.shape
163 | spatial_shape = (h, w)
164 | feat = feat.flatten(3).permute(1, 0, 3, 2)
165 | if self.use_cams_embeds:
166 | feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
167 | feat = feat + self.level_embeds[None,
168 | None, lvl:lvl + 1, :].to(feat.dtype)
169 | spatial_shapes.append(spatial_shape)
170 | feat_flatten.append(feat)
171 |
172 | feat_flatten = torch.cat(feat_flatten, 2)
173 | spatial_shapes = torch.as_tensor(
174 | spatial_shapes, dtype=torch.long, device=bev_pos.device)
175 | level_start_index = torch.cat((spatial_shapes.new_zeros(
176 | (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
177 |
178 | feat_flatten = feat_flatten.permute(
179 | 0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims)
180 |
181 | bev_embed = self.encoder(
182 | bev_queries,
183 | feat_flatten,
184 | feat_flatten,
185 | bev_h=self.bev_h,
186 | bev_w=self.bev_w,
187 | bev_pos=bev_pos,
188 | spatial_shapes=spatial_shapes,
189 | level_start_index=level_start_index,
190 | prev_bev=prev_bev,
191 | shift=shift,
192 | img_metas=img_metas,
193 | **kwargs
194 | )
195 |
196 | return bev_embed
197 |
198 |
--------------------------------------------------------------------------------
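The shift computed in forward() above is the ego translation between two frames (taken from can_bus, expressed in the global frame) rotated into the current lidar frame and normalized by the metric extent of the BEV grid. A standalone sketch with hypothetical numbers:

import numpy as np

pc_range = [-51.2, -25.6, -5.0, 51.2, 25.6, 3.0]
real_w, real_h = pc_range[3] - pc_range[0], pc_range[4] - pc_range[1]

delta_global = np.array([1.0, 0.5, 0.0])            # ego translation in the global frame
yaw = np.deg2rad(90.0)                               # hypothetical lidar-to-global yaw
lidar2global_rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0.],
                                  [np.sin(yaw),  np.cos(yaw), 0.],
                                  [0., 0., 1.]])
delta_lidar = np.linalg.inv(lidar2global_rotation) @ delta_global
shift_x = delta_lidar[0] / real_w                    # fraction of the BEV extent along x
shift_y = delta_lidar[1] / real_h                    # fraction of the BEV extent along y
print(shift_x, shift_y)
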
/projects/toponet/models/modules/sgnn_decoder.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import copy
8 | import warnings
9 | import math
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import mmcv
15 | from mmcv.cnn import Linear, build_activation_layer
16 | from mmcv.cnn.bricks.drop import build_dropout
17 | from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER, FEEDFORWARD_NETWORK,
18 | TRANSFORMER_LAYER_SEQUENCE)
19 | from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence
20 | from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
21 |
22 |
23 | @TRANSFORMER_LAYER_SEQUENCE.register_module()
24 | class TopoNetSGNNDecoder(TransformerLayerSequence):
25 |
26 | def __init__(self, *args, return_intermediate=False, **kwargs):
27 | super(TopoNetSGNNDecoder, self).__init__(*args, **kwargs)
28 | self.return_intermediate = return_intermediate
29 | self.fp16_enabled = False
30 |
31 | def forward(self,
32 | query,
33 | *args,
34 | reference_points=None,
35 | lclc_branches=None,
36 | lcte_branches=None,
37 | key_padding_mask=None,
38 | te_feats=None,
39 | te_cls_scores=None,
40 | **kwargs):
41 |
42 | output = query
43 | intermediate = []
44 | intermediate_reference_points = []
45 | intermediate_lclc_rel = []
46 | intermediate_lcte_rel = []
47 | num_query = query.size(0)
48 | num_te_query = te_feats.size(2)
49 |
50 | prev_lclc_adj = torch.zeros((query.size(1), num_query, num_query),
51 | dtype=query.dtype, device=query.device)
52 | prev_lcte_adj = torch.zeros((query.size(1), num_query, num_te_query),
53 | dtype=query.dtype, device=query.device)
54 | for lid, layer in enumerate(self.layers):
55 | reference_points_input = reference_points[..., :2].unsqueeze(
56 | 2) # BS NUM_QUERY NUM_LEVEL 2
57 | output = layer(
58 | output,
59 | *args,
60 | reference_points=reference_points_input,
61 | key_padding_mask=key_padding_mask,
62 | te_query=te_feats[lid],
63 | te_cls_scores=te_cls_scores[lid],
64 | lclc_adj=prev_lclc_adj,
65 | lcte_adj=prev_lcte_adj,
66 | **kwargs)
67 | output = output.permute(1, 0, 2)
68 |
69 | lclc_rel_out = lclc_branches[lid](output, output)
70 | lclc_rel_adj = lclc_rel_out.squeeze(-1).sigmoid()
71 | prev_lclc_adj = lclc_rel_adj.detach()
72 |
73 | lcte_rel_out = lcte_branches[lid](output, te_feats[lid])
74 | lcte_rel_adj = lcte_rel_out.squeeze(-1).sigmoid()
75 | prev_lcte_adj = lcte_rel_adj.detach()
76 |
77 | output = output.permute(1, 0, 2)
78 |
79 | if self.return_intermediate:
80 | intermediate.append(output)
81 | intermediate_reference_points.append(reference_points)
82 | intermediate_lclc_rel.append(lclc_rel_out)
83 | intermediate_lcte_rel.append(lcte_rel_out)
84 |
85 | if self.return_intermediate:
86 | return torch.stack(intermediate), torch.stack(
87 | intermediate_reference_points), torch.stack(
88 | intermediate_lclc_rel), torch.stack(
89 | intermediate_lcte_rel)
90 |
91 | return output, reference_points, lclc_rel_out, lcte_rel_out
92 |
93 |
94 | @TRANSFORMER_LAYER.register_module()
95 | class SGNNDecoderLayer(BaseTransformerLayer):
96 |     """Implements a decoder layer of TopoNet's SGNN decoder.
97 |
98 | Args:
99 |         attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
100 |             Configs for self_attention or cross_attention; the order
101 |             should be consistent with `operation_order`. If it is
102 |             a dict, it is expanded to the number of attention modules in
103 | `operation_order`.
104 | feedforward_channels (int): The hidden dimension for FFNs.
105 | ffn_dropout (float): Probability of an element to be zeroed
106 | in ffn. Default 0.0.
107 |         operation_order (tuple[str]): The execution order of operations
108 |             in the layer, e.g. ('self_attn', 'norm', 'cross_attn', 'norm',
109 |             'ffn', 'norm'). Default: None.
110 |         act_cfg (dict): The activation config for FFNs. Default: dict(type='ReLU').
111 | norm_cfg (dict): Config dict for normalization layer.
112 | Default: `LN`.
113 | ffn_num_fcs (int): The number of fully-connected layers in FFNs.
114 | Default:2.
115 | """
116 |
117 | def __init__(self,
118 | attn_cfgs,
119 | ffn_cfgs,
120 | operation_order=None,
121 | norm_cfg=dict(type='LN'),
122 | **kwargs):
123 | super(SGNNDecoderLayer, self).__init__(
124 | attn_cfgs=attn_cfgs,
125 | ffn_cfgs=ffn_cfgs,
126 | operation_order=operation_order,
127 | norm_cfg=norm_cfg,
128 | **kwargs)
129 | assert len(operation_order) == 6
130 | assert set(operation_order) == set(
131 | ['self_attn', 'norm', 'cross_attn', 'ffn'])
132 |
133 | def forward(self,
134 | query,
135 | key=None,
136 | value=None,
137 | query_pos=None,
138 | key_pos=None,
139 | attn_masks=None,
140 | query_key_padding_mask=None,
141 | key_padding_mask=None,
142 | te_query=None,
143 | te_cls_scores=None,
144 | lclc_adj=None,
145 | lcte_adj=None,
146 | **kwargs):
147 |
148 | norm_index = 0
149 | attn_index = 0
150 | ffn_index = 0
151 | identity = query
152 | if attn_masks is None:
153 | attn_masks = [None for _ in range(self.num_attn)]
154 | elif isinstance(attn_masks, torch.Tensor):
155 | attn_masks = [
156 | copy.deepcopy(attn_masks) for _ in range(self.num_attn)
157 | ]
158 | warnings.warn(f'Use same attn_mask in all attentions in '
159 | f'{self.__class__.__name__} ')
160 | else:
161 | assert len(attn_masks) == self.num_attn, f'The length of ' \
162 | f'attn_masks {len(attn_masks)} must be equal ' \
163 | f'to the number of attention in ' \
164 | f'operation_order {self.num_attn}'
165 |
166 | for layer in self.operation_order:
167 | if layer == 'self_attn':
168 | temp_key = temp_value = query
169 | query = self.attentions[attn_index](
170 | query,
171 | temp_key,
172 | temp_value,
173 | identity if self.pre_norm else None,
174 | query_pos=query_pos,
175 | key_pos=query_pos,
176 | attn_mask=attn_masks[attn_index],
177 | key_padding_mask=query_key_padding_mask,
178 | **kwargs)
179 | attn_index += 1
180 | identity = query
181 |
182 | elif layer == 'norm':
183 | query = self.norms[norm_index](query)
184 | norm_index += 1
185 |
186 | elif layer == 'cross_attn':
187 | query = self.attentions[attn_index](
188 | query,
189 | key,
190 | value,
191 | identity if self.pre_norm else None,
192 | query_pos=query_pos,
193 | key_pos=key_pos,
194 | attn_mask=attn_masks[attn_index],
195 | key_padding_mask=key_padding_mask,
196 | **kwargs)
197 | attn_index += 1
198 | identity = query
199 |
200 | elif layer == 'ffn':
201 | query = self.ffns[ffn_index](
202 | query, te_query, lclc_adj, lcte_adj, te_cls_scores, identity=identity if self.pre_norm else None)
203 | ffn_index += 1
204 |
205 | return query
206 |
207 |
208 | @FEEDFORWARD_NETWORK.register_module()
209 | class FFN_SGNN(BaseModule):
210 |
211 | def __init__(self,
212 | embed_dims=256,
213 | feedforward_channels=512,
214 | num_fcs=2,
215 | act_cfg=dict(type='ReLU', inplace=True),
216 | ffn_drop=0.1,
217 | dropout_layer=None,
218 | add_identity=True,
219 | init_cfg=None,
220 | edge_weight=0.5,
221 | num_te_classes=13,
222 | **kwargs):
223 | super(FFN_SGNN, self).__init__(init_cfg)
224 | assert num_fcs >= 2, 'num_fcs should be no less ' \
225 | f'than 2. got {num_fcs}.'
226 | self.embed_dims = embed_dims
227 | self.feedforward_channels = feedforward_channels
228 | self.num_fcs = num_fcs
229 | self.act_cfg = act_cfg
230 | self.activate = build_activation_layer(act_cfg)
231 |
232 | layers = []
233 | in_channels = embed_dims
234 | for _ in range(num_fcs - 1):
235 | layers.append(
236 | Sequential(
237 | Linear(in_channels, feedforward_channels), self.activate,
238 | nn.Dropout(ffn_drop)))
239 | in_channels = feedforward_channels
240 | layers.append(
241 | Sequential(
242 | Linear(feedforward_channels, embed_dims), self.activate,
243 | nn.Dropout(ffn_drop)))
244 | self.layers = Sequential(*layers)
245 | self.edge_weight = edge_weight
246 |
247 | self.lclc_gnn_layer = LclcSkgGCNLayer(embed_dims, embed_dims, edge_weight=edge_weight)
248 | self.lcte_gnn_layer = LcteSkgGCNLayer(embed_dims, embed_dims, num_te_classes=num_te_classes, edge_weight=edge_weight)
249 |
250 | self.downsample = nn.Linear(embed_dims * 2, embed_dims)
251 |
252 | self.gnn_dropout1 = nn.Dropout(ffn_drop)
253 | self.gnn_dropout2 = nn.Dropout(ffn_drop)
254 |
255 | self.dropout_layer = build_dropout(
256 | dropout_layer) if dropout_layer else torch.nn.Identity()
257 | self.add_identity = add_identity
258 |
259 | def forward(self, lc_query, te_query, lclc_adj, lcte_adj, te_cls_scores, identity=None):
260 |
261 | out = self.layers(lc_query)
262 | out = out.permute(1, 0, 2)
263 |
264 | lclc_features = self.lclc_gnn_layer(out, lclc_adj)
265 | lcte_features = self.lcte_gnn_layer(te_query, lcte_adj, te_cls_scores)
266 |
267 | out = torch.cat([lclc_features, lcte_features], dim=-1)
268 |
269 | out = self.activate(out)
270 | out = self.gnn_dropout1(out)
271 | out = self.downsample(out)
272 | out = self.gnn_dropout2(out)
273 |
274 | out = out.permute(1, 0, 2)
275 | if not self.add_identity:
276 | return self.dropout_layer(out)
277 | if identity is None:
278 | identity = lc_query
279 | return identity + self.dropout_layer(out)
280 |
281 |
282 | class LclcSkgGCNLayer(nn.Module):
283 |
284 | def __init__(self, in_features, out_features, edge_weight=0.5):
285 | super(LclcSkgGCNLayer, self).__init__()
286 | self.edge_weight = edge_weight
287 |
288 | if self.edge_weight != 0:
289 | self.weight_forward = torch.Tensor(in_features, out_features)
290 | self.weight_forward = nn.Parameter(nn.init.xavier_uniform_(self.weight_forward))
291 | self.weight_backward = torch.Tensor(in_features, out_features)
292 | self.weight_backward = nn.Parameter(nn.init.xavier_uniform_(self.weight_backward))
293 |
294 | self.weight = torch.Tensor(in_features, out_features)
295 | self.weight = nn.Parameter(nn.init.xavier_uniform_(self.weight))
296 | self.edge_weight = edge_weight
297 |
298 | def forward(self, input, adj):
299 |
300 | support_loop = torch.matmul(input, self.weight)
301 | output = support_loop
302 |
303 | if self.edge_weight != 0:
304 | support_forward = torch.matmul(input, self.weight_forward)
305 | output_forward = torch.matmul(adj, support_forward)
306 | output += self.edge_weight * output_forward
307 |
308 | support_backward = torch.matmul(input, self.weight_backward)
309 | output_backward = torch.matmul(adj.permute(0, 2, 1), support_backward)
310 | output += self.edge_weight * output_backward
311 |
312 | return output
313 |
314 |
315 | class LcteSkgGCNLayer(nn.Module):
316 |
317 | def __init__(self, in_features, out_features, num_te_classes=13, edge_weight=0.5):
318 | super(LcteSkgGCNLayer, self).__init__()
319 | self.weight = torch.Tensor(num_te_classes, in_features, out_features)
320 | self.weight = nn.Parameter(nn.init.xavier_uniform_(self.weight))
321 | self.edge_weight = edge_weight
322 |
323 | def forward(self, te_query, lcte_adj, te_cls_scores):
324 | # te_cls_scores: (bs, num_te_query, num_te_classes)
325 | cls_scores = te_cls_scores.detach().sigmoid().unsqueeze(3)
326 | # te_query: (bs, num_te_query, embed_dims)
327 | # (bs, num_te_query, 1, embed_dims) * (bs, num_te_query, num_te_classes, 1)
328 | te_feats = te_query.unsqueeze(2) * cls_scores
329 | # (bs, num_te_classes, num_te_query, embed_dims)
330 | te_feats = te_feats.permute(0, 2, 1, 3)
331 |
332 | support = torch.matmul(te_feats, self.weight).sum(1)
333 | adj = lcte_adj * self.edge_weight
334 | output = torch.matmul(adj, support)
335 | return output
336 |
--------------------------------------------------------------------------------
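Note: the tensor bookkeeping in FFN_SGNN and the two GCN layers above is easy to lose track of: FFN_SGNN receives sequence-first lane-centerline queries, both GCN layers operate batch-first, and the lc-te branch reduces over traffic-element classes. The following shape-only sketch uses hypothetical sizes (bs, num_lc, num_te) and random stand-in adjacencies; it only assumes the two layer classes defined above.

import torch
from projects.toponet.models.modules.sgnn_decoder import (LclcSkgGCNLayer,
                                                          LcteSkgGCNLayer)

bs, num_lc, num_te, num_cls, dims = 2, 200, 100, 13, 256    # hypothetical sizes

lclc = LclcSkgGCNLayer(dims, dims, edge_weight=0.5)
lcte = LcteSkgGCNLayer(dims, dims, num_te_classes=num_cls, edge_weight=0.5)

lc_feats = torch.rand(bs, num_lc, dims)       # batch-first lane-centerline features
te_query = torch.rand(bs, num_te, dims)       # traffic-element decoder embeddings
lclc_adj = torch.rand(bs, num_lc, num_lc)     # lc-lc adjacency scores (random stand-in)
lcte_adj = torch.rand(bs, num_lc, num_te)     # lc-te adjacency scores (random stand-in)
te_scores = torch.rand(bs, num_te, num_cls)   # te classification logits

out_lclc = lclc(lc_feats, lclc_adj)             # -> (bs, num_lc, dims)
out_lcte = lcte(te_query, lcte_adj, te_scores)  # -> (bs, num_lc, dims)
# FFN_SGNN concatenates the two outputs, projects back to embed_dims and adds the residual.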
/projects/toponet/models/modules/transformer_decoder_only.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from mmcv.cnn import xavier_init
11 | from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
12 | from mmcv.runner import auto_fp16, force_fp32
13 | from mmcv.runner.base_module import BaseModule
14 | from mmdet.models.utils.builder import TRANSFORMER
15 |
16 | from projects.bevformer.modules.decoder import CustomMSDeformableAttention
17 | from projects.bevformer.modules.spatial_cross_attention import \
18 | MSDeformableAttention3D
19 | from projects.bevformer.modules.temporal_self_attention import \
20 | TemporalSelfAttention
21 |
22 |
23 | @TRANSFORMER.register_module()
24 | class TopoNetTransformerDecoderOnly(BaseModule):
25 |     """Decoder-only transformer for TopoNet, adapted from the Detr3D
26 |     transformer. The BEV features are built by a separate BEV constructor;
27 |     only the deformable decoder is kept here.
28 | 
29 |     Args:
30 |         decoder (dict): Config of the transformer decoder. Default: None.
31 |         embed_dims (int): Embedding dimension of the queries. Default: 256.
32 |         pts_dim (int): Dimension of the predicted reference points. Default: 3.
33 |     """
34 |
35 | def __init__(self,
36 | decoder=None,
37 | embed_dims=256,
38 | pts_dim=3,
39 | **kwargs):
40 | super(TopoNetTransformerDecoderOnly, self).__init__(**kwargs)
41 | self.decoder = build_transformer_layer_sequence(decoder)
42 | self.embed_dims = embed_dims
43 | self.fp16_enabled = False
44 | self.pts_dim = pts_dim
45 | self.init_layers()
46 |
47 | def init_layers(self):
48 |         """Initialize layers of the TopoNetTransformerDecoderOnly."""
49 | self.reference_points = nn.Linear(self.embed_dims, self.pts_dim)
50 |
51 | def init_weights(self):
52 | """Initialize the transformer weights."""
53 | for p in self.parameters():
54 | if p.dim() > 1:
55 | nn.init.xavier_uniform_(p)
56 | for m in self.modules():
57 | if isinstance(m, MSDeformableAttention3D) or isinstance(m, TemporalSelfAttention) \
58 | or isinstance(m, CustomMSDeformableAttention):
59 | try:
60 | m.init_weight()
61 | except AttributeError:
62 | m.init_weights()
63 | xavier_init(self.reference_points, distribution='uniform', bias=0.)
64 |
65 |     @auto_fp16(apply_to=('mlvl_feats', 'bev_embed', 'object_query_embed'))
66 | def forward(self,
67 | mlvl_feats,
68 | bev_embed,
69 | object_query_embed,
70 | bev_h,
71 | bev_w,
72 | lclc_branches=None,
73 | lcte_branches=None,
74 | te_feats=None,
75 | te_cls_scores=None,
76 | **kwargs):
77 |
78 | bs = mlvl_feats[0].size(0)
79 | query_pos, query = torch.split(
80 | object_query_embed, self.embed_dims, dim=1)
81 | query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
82 | query = query.unsqueeze(0).expand(bs, -1, -1)
83 | reference_points = self.reference_points(query_pos)
84 | reference_points = reference_points.sigmoid()
85 | init_reference_out = reference_points
86 |
87 | query = query.permute(1, 0, 2)
88 | query_pos = query_pos.permute(1, 0, 2)
89 | bev_embed = bev_embed.permute(1, 0, 2)
90 | inter_states, inter_references, inter_lclc_rel, inter_lcte_rel = self.decoder(
91 | query=query,
92 | key=None,
93 | value=bev_embed,
94 | query_pos=query_pos,
95 | reference_points=reference_points,
96 | lclc_branches=lclc_branches,
97 | lcte_branches=lcte_branches,
98 | te_feats=te_feats,
99 | te_cls_scores=te_cls_scores,
100 | spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
101 | level_start_index=torch.tensor([0], device=query.device),
102 | **kwargs)
103 |
104 | inter_references_out = inter_references
105 |
106 | return inter_states, init_reference_out, inter_references_out, inter_lclc_rel, inter_lcte_rel
107 |
--------------------------------------------------------------------------------
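Note: the query bookkeeping in TopoNetTransformerDecoderOnly.forward can be reproduced in isolation; the sketch below uses hypothetical sizes and mirrors the torch.split/expand calls above.

import torch

embed_dims, num_query, bs = 256, 300, 2         # hypothetical sizes
object_query_embed = torch.rand(num_query, 2 * embed_dims)

# each row holds [positional embedding | content embedding]
query_pos, query = torch.split(object_query_embed, embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)   # (bs, num_query, embed_dims)
query = query.unsqueeze(0).expand(bs, -1, -1)           # (bs, num_query, embed_dims)
# reference points are regressed from the positional half and squashed to [0, 1] by sigmoid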
/projects/toponet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_bev_constructor
2 |
--------------------------------------------------------------------------------
/projects/toponet/utils/builder.py:
--------------------------------------------------------------------------------
1 | #---------------------------------------------------------------------------------------#
2 | # Graph-based Topology Reasoning for Driving Scenes (https://arxiv.org/abs/2304.05277) #
3 | # Source code: https://github.com/OpenDriveLab/TopoNet #
4 | # Copyright (c) OpenDriveLab. All rights reserved. #
5 | #---------------------------------------------------------------------------------------#
6 |
7 | import torch.nn as nn
8 | from mmcv.utils import Registry, build_from_cfg
9 |
10 | BEV_CONSTRUCTOR = Registry('BEV Constructor')
11 |
12 | def build_bev_constructor(cfg, default_args=None):
13 | """Builder for BEV Constructor."""
14 | return build_from_cfg(cfg, BEV_CONSTRUCTOR, default_args)
15 |
--------------------------------------------------------------------------------
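Note: build_bev_constructor is a thin wrapper around mmcv's Registry/build_from_cfg pair. A minimal usage sketch, with a hypothetical constructor class standing in for the real BEV constructor module:

from projects.toponet.utils.builder import BEV_CONSTRUCTOR, build_bev_constructor

@BEV_CONSTRUCTOR.register_module()
class DummyBEVConstructor:                      # hypothetical, for illustration only
    def __init__(self, bev_h=200, bev_w=200):
        self.bev_h, self.bev_w = bev_h, bev_w

cfg = dict(type='DummyBEVConstructor', bev_h=100, bev_w=100)
bev_constructor = build_bev_constructor(cfg)    # instantiates DummyBEVConstructor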
/requirements.txt:
--------------------------------------------------------------------------------
1 | similaritymeasures==0.6.0
2 | numpy==1.22.4
3 | scipy==1.8.0
4 | ortools==9.2.9972
5 | setuptools==59.5.0
6 | openlanev2==1.1.0
7 |
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 |
4 | timestamp=`date +"%y%m%d.%H%M%S"`
5 |
6 | WORK_DIR=work_dirs/toponet
7 | CONFIG=projects/configs/toponet_r50_8x1_24e_olv2_subset_A.py
8 |
9 | CHECKPOINT=${WORK_DIR}/latest.pth
10 |
11 | GPUS=$1
12 | PORT=${PORT:-28510}
13 |
14 | python -m torch.distributed.run --nproc_per_node=$GPUS --master_port=$PORT \
15 | tools/test.py $CONFIG $CHECKPOINT --launcher pytorch \
16 | --out-dir ${WORK_DIR}/test --eval openlane_v2 ${@:2} \
17 | 2>&1 | tee ${WORK_DIR}/test.${timestamp}.log
18 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -x
3 |
4 | timestamp=`date +"%y%m%d.%H%M%S"`
5 |
6 | WORK_DIR=work_dirs/toponet
7 | CONFIG=projects/configs/toponet_r50_8x1_24e_olv2_subset_A.py
8 |
9 | GPUS=$1
10 | PORT=${PORT:-28510}
11 |
12 | python -m torch.distributed.run --nproc_per_node=$GPUS --master_port=$PORT \
13 | tools/train.py $CONFIG --launcher pytorch --work-dir ${WORK_DIR} --deterministic ${@:2} \
14 | 2>&1 | tee ${WORK_DIR}/train.${timestamp}.log
15 |
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Tianyu Li
5 | # ---------------------------------------------
6 | import argparse
7 | import os
8 | import os.path as osp
9 | import time
10 | import warnings
11 |
12 | import torch
13 | import mmcv
14 | from mmcv import Config, DictAction
15 | from mmcv.cnn import fuse_conv_bn
16 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
17 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
18 | wrap_fp16_model)
19 | from mmcv.utils import get_logger
20 | from mmdet.apis import multi_gpu_test, set_random_seed
21 | from mmdet.datasets import replace_ImageToTensor
22 | from mmdet3d.apis import single_gpu_test
23 | from mmdet3d.datasets import build_dataloader, build_dataset
24 | from mmdet3d.models import build_model
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser(
28 | description='MMDet test (and eval) a model')
29 | parser.add_argument('config', help='test config file path')
30 | parser.add_argument('checkpoint', help='checkpoint file')
31 |     parser.add_argument('--out',
32 |                         action='store_true',
33 |                         help='whether to dump results to results.pkl in --out-dir')
34 |     parser.add_argument('--input',
35 |                         action='store_true',
36 |                         help='load results from results.pkl in --out-dir instead of running inference')
37 | parser.add_argument(
38 | '--fuse-conv-bn',
39 | action='store_true',
40 |         help='Whether to fuse conv and bn. This will slightly increase '
41 |         'the inference speed')
42 | parser.add_argument(
43 | '--format-only',
44 | action='store_true',
45 |         help='Format the output results without performing evaluation. It is '
46 |         'useful when you want to format the results to a specific format and '
47 |         'submit them to the test server')
48 | parser.add_argument(
49 | '--eval',
50 | type=str,
51 | nargs='+',
52 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
53 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
54 |     parser.add_argument('--show', action='store_true', help='show results')
55 |     parser.add_argument('--show-dir', help='directory where painted results will be saved')
56 |     parser.add_argument('--out-dir', help='directory where results will be saved')
57 | parser.add_argument(
58 | '--gpu-collect',
59 | action='store_true',
60 | help='whether to use gpu to collect results.')
61 | parser.add_argument('--seed', type=int, default=0, help='random seed')
62 | parser.add_argument(
63 | '--deterministic',
64 | action='store_true',
65 | help='whether to set deterministic options for CUDNN backend.')
66 | parser.add_argument(
67 | '--cfg-options',
68 | nargs='+',
69 | action=DictAction,
70 | help='override some settings in the used config, the key-value pair '
71 | 'in xxx=yyy format will be merged into config file. If the value to '
72 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
73 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
74 | 'Note that the quotation marks are necessary and that no white space '
75 | 'is allowed.')
76 | parser.add_argument(
77 | '--options',
78 | nargs='+',
79 | action=DictAction,
80 | help='custom options for evaluation, the key-value pair in xxx=yyy '
81 | 'format will be kwargs for dataset.evaluate() function (deprecate), '
82 | 'change to --eval-options instead.')
83 | parser.add_argument(
84 | '--eval-options',
85 | nargs='+',
86 | action=DictAction,
87 | help='custom options for evaluation, the key-value pair in xxx=yyy '
88 | 'format will be kwargs for dataset.evaluate() function')
89 | parser.add_argument(
90 | '--launcher',
91 | choices=['none', 'pytorch', 'slurm', 'mpi'],
92 | default='none',
93 | help='job launcher')
94 | parser.add_argument('--local_rank', type=int, default=0)
95 | args = parser.parse_args()
96 | if 'LOCAL_RANK' not in os.environ:
97 | os.environ['LOCAL_RANK'] = str(args.local_rank)
98 |
99 | if args.options and args.eval_options:
100 | raise ValueError(
101 | '--options and --eval-options cannot be both specified, '
102 | '--options is deprecated in favor of --eval-options')
103 | if args.options:
104 | warnings.warn('--options is deprecated in favor of --eval-options')
105 | args.eval_options = args.options
106 | return args
107 |
108 |
109 | def main():
110 | args = parse_args()
111 |
112 |     assert args.out or args.eval or args.format_only or args.show, \
113 |         ('Please specify at least one operation (save/eval/format/show the '
114 |          'results) with the argument "--out", "--eval"'
115 |          ', "--format-only" or "--show"')
116 |
117 | if args.eval and args.format_only:
118 | raise ValueError('--eval and --format_only cannot be both specified')
119 |
120 | cfg = Config.fromfile(args.config)
121 | if args.cfg_options is not None:
122 | cfg.merge_from_dict(args.cfg_options)
123 | # import modules from string list.
124 | if cfg.get('custom_imports', None):
125 | from mmcv.utils import import_modules_from_strings
126 | import_modules_from_strings(**cfg['custom_imports'])
127 |
128 |     # import modules from plugin/xx, the registry will be updated
129 | if hasattr(cfg, 'plugin'):
130 | if cfg.plugin:
131 | import importlib
132 | if hasattr(cfg, 'plugin_dir'):
133 | plugin_dir = cfg.plugin_dir
134 | _module_dir = os.path.dirname(plugin_dir)
135 | _module_dir = _module_dir.split('/')
136 | _module_path = _module_dir[0]
137 |
138 | for m in _module_dir[1:]:
139 | _module_path = _module_path + '.' + m
140 | print(_module_path)
141 | plg_lib = importlib.import_module(_module_path)
142 | else:
143 | # import dir is the dirpath for the config file
144 | _module_dir = os.path.dirname(args.config)
145 | _module_dir = _module_dir.split('/')
146 | _module_path = _module_dir[0]
147 | for m in _module_dir[1:]:
148 | _module_path = _module_path + '.' + m
149 | print(_module_path)
150 | plg_lib = importlib.import_module(_module_path)
151 |
152 | # set cudnn_benchmark
153 | if cfg.get('cudnn_benchmark', False):
154 | torch.backends.cudnn.benchmark = True
155 |
156 | cfg.model.pretrained = None
157 | # in case the test dataset is concatenated
158 | samples_per_gpu = 1
159 | if isinstance(cfg.data.test, dict):
160 | cfg.data.test.test_mode = True
161 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
162 | if samples_per_gpu > 1:
163 | # Replace 'ImageToTensor' to 'DefaultFormatBundle'
164 | cfg.data.test.pipeline = replace_ImageToTensor(
165 | cfg.data.test.pipeline)
166 | elif isinstance(cfg.data.test, list):
167 | for ds_cfg in cfg.data.test:
168 | ds_cfg.test_mode = True
169 | samples_per_gpu = max(
170 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
171 | if samples_per_gpu > 1:
172 | for ds_cfg in cfg.data.test:
173 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
174 |
175 | # init distributed env first, since logger depends on the dist info.
176 | if args.launcher == 'none':
177 | distributed = False
178 | else:
179 | distributed = True
180 | init_dist(args.launcher, **cfg.dist_params)
181 |
182 | logger = get_logger(name='mmdet', log_level=cfg.log_level)
183 | mmcv.mkdir_or_exist(osp.abspath(args.out_dir))
184 | # set random seeds
185 | if args.seed is not None:
186 | set_random_seed(args.seed, deterministic=args.deterministic)
187 |
188 | # build the dataloader
189 | dataset = build_dataset(cfg.data.test)
190 |
191 | if args.input:
192 |         logger.info(f'Loading results from {os.path.join(args.out_dir, "results.pkl")}')
193 | outputs = mmcv.load(os.path.join(args.out_dir, 'results.pkl'))
194 | else:
195 | data_loader = build_dataloader(
196 | dataset,
197 | samples_per_gpu=samples_per_gpu,
198 | workers_per_gpu=cfg.data.workers_per_gpu,
199 | dist=distributed,
200 | shuffle=False,
201 | )
202 |
203 | # build the model and load checkpoint
204 | # cfg.model.train_cfg = None
205 | model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
206 | fp16_cfg = cfg.get('fp16', None)
207 | if fp16_cfg is not None:
208 | wrap_fp16_model(model)
209 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
210 | if args.fuse_conv_bn:
211 | model = fuse_conv_bn(model)
212 |     # old versions did not save class info in checkpoints, this workaround is
213 | # for backward compatibility
214 | if 'CLASSES' in checkpoint.get('meta', {}):
215 | model.CLASSES = checkpoint['meta']['CLASSES']
216 | else:
217 | model.CLASSES = dataset.CLASSES
218 | # palette for visualization in segmentation tasks
219 | if 'PALETTE' in checkpoint.get('meta', {}):
220 | model.PALETTE = checkpoint['meta']['PALETTE']
221 | elif hasattr(dataset, 'PALETTE'):
222 | # segmentation dataset has `PALETTE` attribute
223 | model.PALETTE = dataset.PALETTE
224 |
225 | if not distributed:
226 | model = MMDataParallel(model, device_ids=cfg.gpu_ids)
227 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
228 | else:
229 | model = MMDistributedDataParallel(
230 | model.cuda(),
231 | device_ids=[torch.cuda.current_device()],
232 | broadcast_buffers=False)
233 | outputs = multi_gpu_test(model, data_loader,
234 | tmpdir=os.path.join(args.out_dir, '.dist_test'),
235 | gpu_collect=args.gpu_collect)
236 |
237 | rank, _ = get_dist_info()
238 | if rank == 0:
239 | if args.out and not args.input:
240 |             logger.info(f'Writing results to {os.path.join(args.out_dir, "results.pkl")}')
241 | mmcv.dump(outputs, os.path.join(args.out_dir, 'results.pkl'))
242 | kwargs = {} if args.eval_options is None else args.eval_options
243 | kwargs['logger'] = logger
244 | kwargs['show'] = args.show
245 | kwargs['out_dir'] = os.path.join(args.out_dir, 'vis/')
246 | if args.format_only:
247 | dataset.format_results(outputs, **kwargs)
248 |
249 | if args.eval:
250 | eval_kwargs = cfg.get('evaluation', {}).copy()
251 | # hard-code way to remove EvalHook args
252 | for key in [
253 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
254 | 'rule'
255 | ]:
256 | eval_kwargs.pop(key, None)
257 | eval_kwargs.update(dict(metric=args.eval, **kwargs))
258 |
259 | print(dataset.evaluate(outputs, **eval_kwargs))
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Tianyu Li
5 | # ---------------------------------------------
6 | from __future__ import division
7 | import argparse
8 | import copy
9 | import os
10 | import time
11 | import warnings
12 | from os import path as osp
13 |
14 | import mmcv
15 | import torch
16 | import torch.distributed as dist
17 | from mmcv import Config, DictAction
18 | from mmcv.runner import get_dist_info, init_dist
19 |
20 | from mmdet import __version__ as mmdet_version
21 | from mmdet3d import __version__ as mmdet3d_version
22 | from mmdet3d.apis import init_random_seed, train_model
23 | from mmdet3d.datasets import build_dataset
24 | from mmdet3d.models import build_model
25 | from mmdet3d.utils import collect_env, get_root_logger
26 | from mmdet.apis import set_random_seed
27 | from mmseg import __version__ as mmseg_version
28 |
29 | try:
30 | # If mmdet version > 2.20.0, setup_multi_processes would be imported and
31 | # used from mmdet instead of mmdet3d.
32 | from mmdet.utils import setup_multi_processes
33 | except ImportError:
34 | from mmdet3d.utils import setup_multi_processes
35 |
36 |
37 | def parse_args():
38 | parser = argparse.ArgumentParser(description='Train a detector')
39 | parser.add_argument('config', help='train config file path')
40 | parser.add_argument('--work-dir', help='the dir to save logs and models')
41 | parser.add_argument(
42 | '--resume-from', help='the checkpoint file to resume from')
43 | parser.add_argument(
44 | '--auto-resume',
45 | action='store_true',
46 | help='resume from the latest checkpoint automatically')
47 | parser.add_argument(
48 | '--no-validate',
49 | action='store_true',
50 | help='whether not to evaluate the checkpoint during training')
51 | group_gpus = parser.add_mutually_exclusive_group()
52 | group_gpus.add_argument(
53 | '--gpus',
54 | type=int,
55 | help='(Deprecated, please use --gpu-id) number of gpus to use '
56 | '(only applicable to non-distributed training)')
57 | group_gpus.add_argument(
58 | '--gpu-ids',
59 | type=int,
60 | nargs='+',
61 | help='(Deprecated, please use --gpu-id) ids of gpus to use '
62 | '(only applicable to non-distributed training)')
63 | group_gpus.add_argument(
64 | '--gpu-id',
65 | type=int,
66 | default=0,
67 |         help='id of gpu to use '
68 | '(only applicable to non-distributed training)')
69 | parser.add_argument('--seed', type=int, default=0, help='random seed')
70 | parser.add_argument(
71 | '--diff-seed',
72 | action='store_true',
73 | help='Whether or not set different seeds for different ranks')
74 | parser.add_argument(
75 | '--deterministic',
76 | action='store_true',
77 | help='whether to set deterministic options for CUDNN backend.')
78 | parser.add_argument(
79 | '--options',
80 | nargs='+',
81 | action=DictAction,
82 | help='override some settings in the used config, the key-value pair '
83 | 'in xxx=yyy format will be merged into config file (deprecate), '
84 | 'change to --cfg-options instead.')
85 | parser.add_argument(
86 | '--cfg-options',
87 | nargs='+',
88 | action=DictAction,
89 | help='override some settings in the used config, the key-value pair '
90 | 'in xxx=yyy format will be merged into config file. If the value to '
91 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
92 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
93 | 'Note that the quotation marks are necessary and that no white space '
94 | 'is allowed.')
95 | parser.add_argument(
96 | '--launcher',
97 | choices=['none', 'pytorch', 'slurm', 'mpi'],
98 | default='none',
99 | help='job launcher')
100 | parser.add_argument('--local_rank', type=int, default=0)
101 | parser.add_argument(
102 | '--autoscale-lr',
103 | action='store_true',
104 | help='automatically scale lr with the number of gpus')
105 | args = parser.parse_args()
106 | if 'LOCAL_RANK' not in os.environ:
107 | os.environ['LOCAL_RANK'] = str(args.local_rank)
108 |
109 | if args.options and args.cfg_options:
110 | raise ValueError(
111 | '--options and --cfg-options cannot be both specified, '
112 | '--options is deprecated in favor of --cfg-options')
113 | if args.options:
114 | warnings.warn('--options is deprecated in favor of --cfg-options')
115 | args.cfg_options = args.options
116 |
117 | return args
118 |
119 |
120 | def auto_scale_lr(cfg, distributed, logger):
121 |     """Automatically scale LR according to the number of GPUs and samples per GPU.
122 |
123 | Args:
124 | cfg (config): Training config.
125 | distributed (bool): Using distributed or not.
126 | logger (logging.Logger): Logger.
127 | """
128 | # Get flag from config
129 | if ('auto_scale_lr' not in cfg) or \
130 | (not cfg.auto_scale_lr.get('enable', False)):
131 | logger.info('Automatic scaling of learning rate (LR)'
132 | ' has been disabled.')
133 | return
134 |
135 | # Get base batch size from config
136 | base_batch_size = cfg.auto_scale_lr.get('base_batch_size', None)
137 | if base_batch_size is None:
138 | return
139 |
140 | # Get gpu number
141 | if distributed:
142 | _, world_size = get_dist_info()
143 |         num_gpus = world_size
144 | else:
145 | num_gpus = len(cfg.gpu_ids)
146 |
147 | # calculate the batch size
148 | samples_per_gpu = cfg.data.samples_per_gpu
149 | batch_size = num_gpus * samples_per_gpu
150 | logger.info(f'Training with {num_gpus} GPU(s) with {samples_per_gpu} '
151 | f'samples per GPU. The total batch size is {batch_size}.')
152 |
153 | if batch_size != base_batch_size:
154 | # scale LR with
155 | # [linear scaling rule](https://arxiv.org/abs/1706.02677)
156 | scaled_lr = (batch_size / base_batch_size) * cfg.optimizer.lr
157 | logger.info('LR has been automatically scaled '
158 | f'from {cfg.optimizer.lr} to {scaled_lr}')
159 | cfg.optimizer.lr = scaled_lr
160 | else:
161 |         logger.info('The batch size matches the '
162 |                     f'base batch size: {base_batch_size}, '
163 |                     f'the LR will not be scaled ({cfg.optimizer.lr}).')
164 |
165 |
166 | def main():
167 | args = parse_args()
168 |
169 | cfg = Config.fromfile(args.config)
170 | if args.cfg_options is not None:
171 | cfg.merge_from_dict(args.cfg_options)
172 |
173 | # set multi-process settings
174 | setup_multi_processes(cfg)
175 |
176 | # set cudnn_benchmark
177 | if cfg.get('cudnn_benchmark', False):
178 | torch.backends.cudnn.benchmark = True
179 |
180 | # work_dir is determined in this priority: CLI > segment in file > filename
181 | if args.work_dir is not None:
182 | # update configs according to CLI args if args.work_dir is not None
183 | cfg.work_dir = args.work_dir
184 | elif cfg.get('work_dir', None) is None:
185 | # use config filename as default work_dir if cfg.work_dir is None
186 | cfg.work_dir = osp.join('./work_dirs',
187 | osp.splitext(osp.basename(args.config))[0])
188 | if args.resume_from is not None:
189 | cfg.resume_from = args.resume_from
190 |
191 | if args.auto_resume:
192 | cfg.auto_resume = args.auto_resume
193 |         warnings.warn('`--auto-resume` is only supported when mmdet '
194 |                       'version >= 2.20.0 for 3D detection models or '
195 |                       'mmsegmentation version >= 0.21.0 for 3D '
196 |                       'segmentation models')
197 |
198 | if args.gpus is not None:
199 | cfg.gpu_ids = range(1)
200 | warnings.warn('`--gpus` is deprecated because we only support '
201 | 'single GPU mode in non-distributed training. '
202 | 'Use `gpus=1` now.')
203 | if args.gpu_ids is not None:
204 | cfg.gpu_ids = args.gpu_ids[0:1]
205 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
206 | 'Because we only support single GPU mode in '
207 | 'non-distributed training. Use the first GPU '
208 | 'in `gpu_ids` now.')
209 | if args.gpus is None and args.gpu_ids is None:
210 | cfg.gpu_ids = [args.gpu_id]
211 |
212 | if args.autoscale_lr:
213 | if 'auto_scale_lr' in cfg and \
214 | 'base_batch_size' in cfg.auto_scale_lr:
215 | cfg.auto_scale_lr.enable = True
216 | else:
217 | warnings.warn('Can not find "auto_scale_lr" or '
218 | '"auto_scale_lr.base_batch_size" in your'
219 | ' configuration file.')
220 |
221 | # init distributed env first, since logger depends on the dist info.
222 | if args.launcher == 'none':
223 | distributed = False
224 | else:
225 | distributed = True
226 | init_dist(args.launcher, **cfg.dist_params)
227 | # re-set gpu_ids with distributed training mode
228 | _, world_size = get_dist_info()
229 | cfg.gpu_ids = range(world_size)
230 |
231 | # create work_dir
232 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
233 | # dump config
234 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
235 | # init the logger before other steps
236 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
237 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
238 | # specify logger name, if we still use 'mmdet', the output info will be
239 | # filtered and won't be saved in the log_file
240 | logger = get_root_logger(
241 | log_file=log_file, log_level=cfg.log_level, name='mmdet')
242 |
243 | # init the meta dict to record some important information such as
244 | # environment info and seed, which will be logged
245 | meta = dict()
246 | # log env info
247 | env_info_dict = collect_env()
248 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
249 | dash_line = '-' * 60 + '\n'
250 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
251 | dash_line)
252 | meta['env_info'] = env_info
253 | meta['config'] = cfg.pretty_text
254 |
255 | # log some basic info
256 | logger.info(f'Distributed training: {distributed}')
257 | logger.info(f'Config:\n{cfg.pretty_text}')
258 |
259 | # set random seeds
260 | seed = init_random_seed(args.seed)
261 | seed = seed + dist.get_rank() if args.diff_seed else seed
262 | logger.info(f'Set random seed to {seed}, '
263 | f'deterministic: {args.deterministic}')
264 | set_random_seed(seed, deterministic=args.deterministic)
265 | cfg.seed = seed
266 | meta['seed'] = seed
267 | meta['exp_name'] = osp.basename(args.config)
268 |
269 | model = build_model(
270 | cfg.model,
271 | train_cfg=cfg.get('train_cfg'),
272 | test_cfg=cfg.get('test_cfg'))
273 | model.init_weights()
274 |
275 | logger.info(f'Model:\n{model}')
276 | datasets = [build_dataset(cfg.data.train)]
277 | if len(cfg.workflow) == 2:
278 | val_dataset = copy.deepcopy(cfg.data.val)
279 | # in case we use a dataset wrapper
280 | if 'dataset' in cfg.data.train:
281 | val_dataset.pipeline = cfg.data.train.dataset.pipeline
282 | else:
283 | val_dataset.pipeline = cfg.data.train.pipeline
284 | # set test_mode=False here in deep copied config
285 |         # which does not affect the AP/AR calculation later
286 | # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa
287 | val_dataset.test_mode = False
288 | datasets.append(build_dataset(val_dataset))
289 | if cfg.checkpoint_config is not None:
290 | # save mmdet version, config file content and class names in
291 | # checkpoints as meta data
292 | cfg.checkpoint_config.meta = dict(
293 | mmdet_version=mmdet_version,
294 | mmseg_version=mmseg_version,
295 | mmdet3d_version=mmdet3d_version,
296 | config=cfg.pretty_text,
297 | CLASSES=datasets[0].CLASSES,
298 | PALETTE=datasets[0].PALETTE # for segmentors
299 | if hasattr(datasets[0], 'PALETTE') else None)
300 | # add an attribute for visualization convenience
301 | model.CLASSES = datasets[0].CLASSES
302 | auto_scale_lr(cfg, distributed=distributed, logger=logger)
303 | train_model(
304 | model,
305 | datasets,
306 | cfg,
307 | distributed=distributed,
308 | validate=(not args.no_validate),
309 | timestamp=timestamp,
310 | meta=meta)
311 |
312 |
313 | if __name__ == '__main__':
314 | main()
315 |
--------------------------------------------------------------------------------
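Note: as a sanity check on auto_scale_lr in tools/train.py, the linear scaling rule works out as follows; base_batch_size and the base LR below are assumed placeholder values, not read from the shipped configs.

base_batch_size = 8          # assumed cfg.auto_scale_lr.base_batch_size
base_lr = 2e-4               # assumed cfg.optimizer.lr
num_gpus, samples_per_gpu = 16, 1

batch_size = num_gpus * samples_per_gpu                  # 16
scaled_lr = (batch_size / base_batch_size) * base_lr     # (16 / 8) * 2e-4 = 4e-4
# with 8 GPUs x 1 sample per GPU the batch size equals base_batch_size and the LR is left unchanged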