├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── assets
│   └── result1.png
├── configs
│   ├── train
│   │   └── train_scalable.yaml
│   └── validation
│       └── validation_scalable.yaml
├── data
│   └── valid_demo
│       ├── 1a146468873b7871.pkl
│       ├── 1a37753b91c30e16.pkl
│       ├── 1b076986122dc653.pkl
│       ├── 1c83f56236e33b4.pkl
│       ├── 1ce0b4bbd35a6ad1.pkl
│       ├── 1d3daf744e65dd7c.pkl
│       ├── 1d71990efc3c6c0e.pkl
│       ├── 1d8ff1c6acd0f266.pkl
│       ├── 1ea00a2a7b936853.pkl
│       ├── 2a716fad0956ffcf.pkl
│       └── 2ab47807431834c1.pkl
├── data_preprocess.py
├── environment.yml
├── requirements.txt
├── scripts
│   ├── install_pyg.sh
│   └── traj_clstering.py
├── smart
│   ├── __init__.py
│   ├── datamodules
│   │   ├── __init__.py
│   │   └── scalable_datamodule.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── preprocess.py
│   │   └── scalable_dataset.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── attention_layer.py
│   │   ├── fourier_embedding.py
│   │   └── mlp_layer.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── average_meter.py
│   │   ├── min_ade.py
│   │   ├── min_fde.py
│   │   ├── next_token_cls.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   └── smart.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── agent_decoder.py
│   │   ├── map_decoder.py
│   │   └── smart_decoder.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── tokens
│   │   ├── __init__.py
│   │   ├── cluster_frame_5_2048.pkl
│   │   └── map_traj_token5.pkl
│   ├── transforms
│   │   ├── __init__.py
│   │   └── target_builder.py
│   └── utils
│       ├── __init__.py
│       ├── cluster_reader.py
│       ├── config.py
│       ├── geometry.py
│       ├── graph.py
│       ├── list.py
│       ├── log.py
│       ├── nan_checker.py
│       └── weight_init.py
├── train.py
└── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .github
6 | ckpt/
7 | # assets/
8 | # C extensions
9 | *.so
10 | # /assets
11 | /data
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # SageMath parsed files
85 | *.sage.py
86 |
87 | # Environments
88 | .env
89 | .venv
90 | *.jpg
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 | *.jpg
97 | pyg_depend/
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | # IDEs
112 | .idea
113 | .vscode
114 |
115 | # seed project
116 | av2/
117 | lightning_logs/
118 | lightning_logs_/
119 | lightning_l/
120 | .DS_Store
121 | data/argo
122 | data/res
123 | data/waymo*
124 | fig*/
125 | data/waymo_token
126 | data/submission
127 | data/token_seq_emb_nuplan
128 | data/token_seq_emb_waymo
129 | data/nuplan*
130 | submission.tar.gz
131 | data/feat*
132 | data/scalable
133 | data/pos_data
134 | res_metrics*
135 | gathered*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction
4 |
5 | [Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)
6 |
7 |
8 |
9 | - **Champion (ranked 1st)** of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
11 |
12 | ## News
13 | - **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**
14 | - **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024
15 | - **[August 31, 2024]** Code released
16 | - **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
17 | - **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)
18 |
19 |
20 | ## Introduction
21 | This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel motion generation paradigm for autonomous driving that tokenizes vectorized map and agent trajectory data into discrete sequence tokens.
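
In code terms, this tokenization is implemented by the `TokenProcessor` class in `smart/datasets/preprocess.py`. A minimal sketch of how it can be invoked (assuming `data` is a single preprocessed scenario dict, e.g. as produced by `data_preprocess.py`):

```python
from smart.datasets.preprocess import TokenProcessor

# token_size=2048 matches the released vocabulary smart/tokens/cluster_frame_5_2048.pkl
processor = TokenProcessor(token_size=2048)
data = processor.preprocess(data)  # adds data['agent']['token_idx'], data['agent']['token_contour'], ...
```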
22 |
23 | ## Requirements
24 |
25 | To set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:
26 |
27 | ```bash
28 | conda env create -f environment.yml
29 | conda activate SMART
30 | pip install -r requirements.txt
31 | ```
32 |
33 | If you encounter issues while installing the PyG dependencies, run the provided script:
34 | ```bash
35 | bash scripts/install_pyg.sh
36 | ```
37 |
38 | Alternatively, you can configure the environment in your preferred way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.
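
For example, a minimal sketch of such a setup with recent releases (not the pinned versions above; choose wheels that match your CUDA version):

```bash
pip install torch torchvision
pip install torch_geometric pytorch-lightning
```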
39 |
40 | ## Data Installation
41 |
42 | **Step 1: Download the Dataset**
43 |
44 | Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:
45 | ```
46 | SMART
47 | ├── data
48 | │ ├── waymo
49 | │ │ ├── scenario
50 | │   │   │   ├── training
51 | │   │   │   ├── validation
52 | │   │   │   └── testing
53 | ├── model
54 | ├── tools
55 | ```
56 |
57 | **Step 2: Install the Waymo Open Dataset API**
58 |
59 | Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.
60 |
61 | **Step 3: Preprocess the Dataset**
62 |
63 | Preprocess the dataset by running:
64 | ```
65 | python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
66 | ```
67 | The first path is the raw data path, and the second is the output data path.
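
The remaining splits can be processed the same way (shown here as an assumption, mirroring the training command):

```bash
python data_preprocess.py --input_dir ./data/waymo/scenario/validation --output_dir ./data/waymo_processed/validation
python data_preprocess.py --input_dir ./data/waymo/scenario/testing --output_dir ./data/waymo_processed/testing
```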
68 |
69 | The processed data will be saved to the `data/waymo_processed/` directory as follows:
70 |
71 | ```
72 | SMART
73 | ├── data
74 | │ ├── waymo_processed
75 | │ │ ├── training
76 | │ │ ├── validation
77 | │   │   └── testing
78 | ├── model
79 | ├── utils
80 | ```
81 |
82 | ## Training
83 |
84 | To train the model, run the following command:
85 |
86 | ```bash
87 | python train.py --config ${config_path}
88 | ```
89 |
90 | The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.
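
For example, with the default config spelled out explicitly:

```bash
python train.py --config configs/train/train_scalable.yaml
```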
91 |
92 | ## Evaluation
93 |
94 | To evaluate the model, run:
95 |
96 | ```bash
97 | python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path}
98 | ```
99 | This will evaluate the model using the configuration and checkpoint provided.
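
For example, with the validation config (the checkpoint path below is a placeholder for your own checkpoint):

```bash
python val.py --config configs/validation/validation_scalable.yaml --pretrain_ckpt ckpt/your_model.ckpt
```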
100 |
101 |
102 | ## Pre-trained Models
103 |
104 | To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed.
105 |
106 | ## Results
107 |
108 | ### Waymo Open Motion Dataset Sim Agents Challenge
109 |
110 | Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/):
111 |
112 | | Model name | Metric Score |
113 | | :-----------: | ------------ |
114 | | SMART-tiny | 0.7591 |
115 | | SMART-large | 0.7614 |
116 | | SMART-zeroshot| 0.7210 |
117 |
118 | ### nuPlan Closed-loop Planning
119 |
120 | **SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below:
121 |
122 | 
123 |
124 | ## Citation
125 |
126 | If you find this repository useful, please consider citing our work and giving us a star:
127 |
128 | ```bibtex
129 | @article{wu2024smart,
130 | title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction},
131 | author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng},
132 | journal={arXiv preprint arXiv:2405.15677},
133 | year={2024}
134 | }
135 | ```
136 |
137 | ## Acknowledgements
138 | Special thanks to the [QCNet](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work.
139 |
140 | ## License
141 | All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
142 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/assets/result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/assets/result1.png
--------------------------------------------------------------------------------
/configs/train/train_scalable.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema. The YAML supports cases sourced from different datasets.
2 | time_info: &time_info
3 | num_historical_steps: 11
4 | num_future_steps: 80
5 | use_intention: True
6 | token_size: 2048
7 |
8 | Dataset:
9 | root:
10 | train_batch_size: 1
11 | val_batch_size: 1
12 | test_batch_size: 1
13 | shuffle: True
14 | num_workers: 1
15 | pin_memory: True
16 | persistent_workers: True
17 | train_raw_dir: ["data/valid_demo"]
18 | val_raw_dir: ["data/valid_demo"]
19 | test_raw_dir:
20 | transform: WaymoTargetBuilder
21 | train_processed_dir:
22 | val_processed_dir:
23 | test_processed_dir:
24 | dataset: "scalable"
25 | <<: *time_info
26 |
27 | Trainer:
28 | strategy: ddp_find_unused_parameters_false
29 | accelerator: "gpu"
30 | devices: 1
31 | max_epochs: 32
32 | save_ckpt_path:
33 | num_nodes: 1
34 | mode:
35 | ckpt_path:
36 | precision: 32
37 | accumulate_grad_batches: 1
38 |
39 | Model:
40 | mode: "train"
41 | predictor: "smart"
42 | dataset: "waymo"
43 | input_dim: 2
44 | hidden_dim: 128
45 | output_dim: 2
46 | output_head: False
47 | num_heads: 8
48 | <<: *time_info
49 | head_dim: 16
50 | dropout: 0.1
51 | num_freq_bands: 64
52 | lr: 0.0005
53 | warmup_steps: 0
54 | total_steps: 32
55 | decoder:
56 | <<: *time_info
57 | num_map_layers: 3
58 | num_agent_layers: 6
59 | a2a_radius: 60
60 | pl2pl_radius: 10
61 | pl2a_radius: 30
62 | time_span: 30
63 |
--------------------------------------------------------------------------------
/configs/validation/validation_scalable.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema. The YAML supports cases sourced from different datasets.
2 | time_info: &time_info
3 | num_historical_steps: 11
4 | num_future_steps: 80
5 | token_size: 2048
6 |
7 | Dataset:
8 | root:
9 | batch_size: 1
10 | shuffle: True
11 | num_workers: 1
12 | pin_memory: True
13 | persistent_workers: True
14 | train_raw_dir:
15 | val_raw_dir: ["data/valid_demo"]
16 | test_raw_dir:
17 | TargetBuilder: WaymoTargetBuilder
18 | train_processed_dir:
19 | val_processed_dir:
20 | test_processed_dir:
21 | dataset: "scalable"
22 | <<: *time_info
23 |
24 | Trainer:
25 | strategy: ddp_find_unused_parameters_false
26 | accelerator: "gpu"
27 | devices: 1
28 | max_epochs: 32
29 | save_ckpt_path:
30 | num_nodes: 1
31 | mode:
32 | ckpt_path:
33 | precision: 32
34 | accumulate_grad_batches: 1
35 |
36 | Model:
37 | mode: "validation"
38 | predictor: "smart"
39 | dataset: "waymo"
40 | input_dim: 2
41 | hidden_dim: 128
42 | output_dim: 2
43 | output_head: False
44 | num_heads: 8
45 | <<: *time_info
46 | head_dim: 16
47 | dropout: 0.1
48 | num_freq_bands: 64
49 | lr: 0.0005
50 | warmup_steps: 0
51 | total_steps: 32
52 | decoder:
53 | <<: *time_info
54 | num_map_layers: 3
55 | num_agent_layers: 6
56 | a2a_radius: 60
57 | pl2pl_radius: 10
58 | pl2a_radius: 30
59 | time_span: 30
60 |
61 |
--------------------------------------------------------------------------------
/data/valid_demo/1a146468873b7871.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1a146468873b7871.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1a37753b91c30e16.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1a37753b91c30e16.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1b076986122dc653.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1b076986122dc653.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1c83f56236e33b4.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1c83f56236e33b4.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1ce0b4bbd35a6ad1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1ce0b4bbd35a6ad1.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1d3daf744e65dd7c.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d3daf744e65dd7c.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1d71990efc3c6c0e.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d71990efc3c6c0e.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1d8ff1c6acd0f266.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d8ff1c6acd0f266.pkl
--------------------------------------------------------------------------------
/data/valid_demo/1ea00a2a7b936853.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1ea00a2a7b936853.pkl
--------------------------------------------------------------------------------
/data/valid_demo/2a716fad0956ffcf.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/2a716fad0956ffcf.pkl
--------------------------------------------------------------------------------
/data/valid_demo/2ab47807431834c1.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/2ab47807431834c1.pkl
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: smart
2 | channels:
3 | - pytorch
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=5.1=1_gnu
13 | - blas=1.0=mkl
14 | - brotli-python=1.0.9=py39h6a678d5_8
15 | - bzip2=1.0.8=h5eee18b_6
16 | - ca-certificates=2024.9.24=h06a4308_0
17 | - certifi=2024.8.30=py39h06a4308_0
18 | - charset-normalizer=3.3.2=pyhd3eb1b0_0
19 | - cudatoolkit=11.3.1=h2bc3f7f_2
20 | - ffmpeg=4.3=hf484d3e_0
21 | - freetype=2.12.1=h4a9f257_0
22 | - gmp=6.2.1=h295c915_3
23 | - gnutls=3.6.15=he1e5248_0
24 | - idna=3.7=py39h06a4308_0
25 | - intel-openmp=2023.1.0=hdb19cb5_46306
26 | - jpeg=9e=h5eee18b_3
27 | - lame=3.100=h7b6447c_0
28 | - lcms2=2.12=h3be6417_0
29 | - ld_impl_linux-64=2.40=h12ee557_0
30 | - lerc=3.0=h295c915_0
31 | - libdeflate=1.17=h5eee18b_1
32 | - libffi=3.4.4=h6a678d5_1
33 | - libgcc-ng=11.2.0=h1234567_1
34 | - libgomp=11.2.0=h1234567_1
35 | - libiconv=1.14=0
36 | - libidn2=2.3.4=h5eee18b_0
37 | - libpng=1.6.39=h5eee18b_0
38 | - libstdcxx-ng=11.2.0=h1234567_1
39 | - libtasn1=4.19.0=h5eee18b_0
40 | - libtiff=4.5.1=h6a678d5_0
41 | - libunistring=0.9.10=h27cfd23_0
42 | - libwebp-base=1.3.2=h5eee18b_1
43 | - lz4-c=1.9.4=h6a678d5_1
44 | - mkl=2023.1.0=h213fc3f_46344
45 | - mkl-service=2.4.0=py39h5eee18b_1
46 | - mkl_fft=1.3.10=py39h5eee18b_0
47 | - mkl_random=1.2.7=py39h1128e8f_0
48 | - ncurses=6.4=h6a678d5_0
49 | - nettle=3.7.3=hbbd107a_1
50 | - openh264=2.1.1=h4ff587b_0
51 | - openjpeg=2.5.2=he7f1fd0_0
52 | - openssl=3.0.15=h5eee18b_0
53 | - pillow=10.4.0=py39h5eee18b_0
54 | - pip=24.2=py39h06a4308_0
55 | - pysocks=1.7.1=py39h06a4308_0
56 | - python=3.9.19=h955ad1f_1
57 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
58 | - pytorch-mutex=1.0=cuda
59 | - readline=8.2=h5eee18b_0
60 | - requests=2.32.3=py39h06a4308_0
61 | - setuptools=75.1.0=py39h06a4308_0
62 | - sqlite=3.45.3=h5eee18b_0
63 | - tbb=2021.8.0=hdb19cb5_0
64 | - tk=8.6.14=h39e8969_0
65 | - torchvision=0.13.1=py39_cu113
66 | - typing_extensions=4.11.0=py39h06a4308_0
67 | - urllib3=2.2.3=py39h06a4308_0
68 | - wheel=0.44.0=py39h06a4308_0
69 | - xz=5.4.6=h5eee18b_1
70 | - zlib=1.2.13=h5eee18b_1
71 | - zstd=1.5.6=hc292b87_0
72 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.4.3
2 | aiohttp==3.10.10
3 | aiosignal==1.3.1
4 | async-timeout==4.0.3
5 | attrs==24.2.0
6 | contourpy==1.3.0
7 | cycler==0.12.1
8 | easydict==1.13
9 | fonttools==4.54.1
10 | frozenlist==1.4.1
11 | fsspec==2024.10.0
12 | importlib-resources==6.4.5
13 | jinja2==3.1.4
14 | kiwisolver==1.4.7
15 | lightning-utilities==0.11.8
16 | markupsafe==3.0.2
17 | matplotlib==3.9.2
18 | multidict==6.1.0
19 | numpy==1.26.4
20 | packaging==24.1
21 | pandas==2.0.3
22 | propcache==0.2.0
23 | psutil==6.1.0
24 | pyparsing==3.2.0
25 | python-dateutil==2.9.0.post0
26 | pytorch-lightning==2.0.3
27 | pytz==2024.2
28 | pyyaml==6.0.1
29 | scipy==1.10.1
30 | shapely==2.0.6
31 | six==1.16.0
32 | torch-cluster==1.6.0+pt112cu113
33 | torch-geometric==2.6.1
34 | torch-scatter==2.1.0+pt112cu113
35 | torch-sparse==0.6.16+pt112cu113
36 | torch-spline-conv==1.2.1+pt112cu113
37 | torchmetrics==1.5.0
38 | tqdm==4.66.5
39 | tzdata==2024.2
40 | yarl==1.16.0
41 | zipp==3.20.2
42 | waymo-open-dataset-tf-2-12-0==1.6.4
43 |
--------------------------------------------------------------------------------
/scripts/install_pyg.sh:
--------------------------------------------------------------------------------
1 | mkdir pyg_depend && cd pyg_depend
2 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
3 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
4 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
5 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
6 | python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl
7 | python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl
8 | python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl
9 | python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl
10 | python3 -m pip install torch_geometric
11 |
--------------------------------------------------------------------------------
/scripts/traj_clstering.py:
--------------------------------------------------------------------------------
1 | from smart.utils.geometry import wrap_angle
2 | import numpy as np
3 |
4 |
5 | def average_distance_vectorized(point_set1, centroids):
6 | dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1))
7 | return np.mean(dists, axis=2)
8 |
9 |
10 | def assign_clusters(sub_X, centroids):
11 | distances = average_distance_vectorized(sub_X, centroids)
12 | return np.argmin(distances, axis=1)
13 |
14 |
15 | def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
16 | S = []
17 | ret_traj_list = []
18 | while len(S) < N:
19 | num_all = X.shape[0]
20 |         # Randomly select a candidate cluster center
21 | choice_index = np.random.choice(num_all)
22 | x0 = X[choice_index]
23 | if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
24 | continue
25 | res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2)
26 | del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2)
27 |         if cal_mean_heading:  # module-level flag set in the __main__ block below
28 | del_contour = X[del_mask]
29 | diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]
30 | del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
31 | x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
32 | del_traj = a_pos[del_mask]
33 | ret_traj = del_traj.mean(0)[None, ...]
34 | if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
35 | print(ret_traj)
36 | print('1')
37 | else:
38 | x0 = x0[None, ...]
39 | ret_traj = a_pos[choice_index][None, ...]
40 | X = X[res_mask]
41 | a_pos = a_pos[res_mask]
42 | S.append(x0)
43 | ret_traj_list.append(ret_traj)
44 | centroids = np.concatenate(S, axis=0)
45 | ret_traj = np.concatenate(ret_traj_list, axis=0)
46 |
47 | # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2))
48 |
49 | # for k in range(1, K):
50 | # new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2))
51 | # closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
52 | # probabilities = closest_dist_sq / np.sum(closest_dist_sq)
53 | # centroids[k] = X[np.random.choice(N, p=probabilities)]
54 |
55 | return centroids, ret_traj
56 |
57 |
58 | def cal_polygon_contour(x, y, theta, width, length):
59 |
60 | left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
61 | left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
62 | left_front = np.column_stack((left_front_x, left_front_y))
63 |
64 | right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
65 | right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
66 | right_front = np.column_stack((right_front_x, right_front_y))
67 |
68 | right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
69 | right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
70 | right_back = np.column_stack((right_back_x, right_back_y))
71 |
72 | left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
73 | left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
74 | left_back = np.column_stack((left_back_x, left_back_y))
75 |
76 | polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)
77 |
78 | return polygon_contour
79 |
80 |
81 | if __name__ == '__main__':
82 | shift = 5 # motion token time dimension
83 | num_cluster = 6 # vocabulary size
84 | cal_mean_heading = True
85 | data = {
86 | "veh": np.random.rand(1000, 6, 3),
87 | "cyc": np.random.rand(1000, 6, 3),
88 | "ped": np.random.rand(1000, 6, 3)
89 | }
90 | # Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]
91 | nms_res = {}
92 | res = {'token': {}, 'traj': {}, 'token_all': {}}
93 | for k, v in data.items():
94 | # if k != 'veh':
95 | # continue
96 | a_pos = v
97 | print(a_pos.shape)
98 | # a_pos = a_pos[:, shift:1+shift, :]
99 | cal_num = min(int(1e6), a_pos.shape[0])
100 | a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]
101 | a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])
102 | print(a_pos.shape)
103 | if shift <= 2:
104 | if k == 'veh':
105 | width = 1.0
106 | length = 2.4
107 | elif k == 'cyc':
108 | width = 0.5
109 | length = 1.5
110 | else:
111 | width = 0.5
112 | length = 0.5
113 | else:
114 | if k == 'veh':
115 | width = 2.0
116 | length = 4.8
117 | elif k == 'cyc':
118 | width = 1.0
119 | length = 2.0
120 | else:
121 | width = 1.0
122 | length = 1.0
123 | contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length)
124 |
125 | # plt.figure(figsize=(10, 10))
126 | # for rect in contour:
127 | # rect_closed = np.vstack([rect, rect[0]])
128 | # plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1)
129 |
130 | # plt.title("Plot of 256 Rectangles")
131 | # plt.xlabel("x")
132 | # plt.ylabel("y")
133 | # plt.axis('equal')
134 | # plt.savefig(f'src_{k}_new.jpg', dpi=300)
135 |
136 | if k == 'veh':
137 | tol = 0.05
138 | elif k == 'cyc':
139 | tol = 0.004
140 | else:
141 | tol = 0.004
142 | centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1])
143 | # plt.figure(figsize=(10, 10))
144 | contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)),
145 | ret_traj[:, :, 1].reshape(num_cluster*(shift+1)),
146 | ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length)
147 |
148 | res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2)
149 | res['token'][k] = centroids
150 | res['traj'][k] = ret_traj
151 |
--------------------------------------------------------------------------------
/smart/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/__init__.py
--------------------------------------------------------------------------------
/smart/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from smart.datamodules.scalable_datamodule import MultiDataModule
2 |
--------------------------------------------------------------------------------
/smart/datamodules/scalable_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytorch_lightning as pl
4 | from torch_geometric.loader import DataLoader
5 | from smart.datasets.scalable_dataset import MultiDataset
6 | from smart.transforms import WaymoTargetBuilder
7 |
8 |
9 | class MultiDataModule(pl.LightningDataModule):
10 | transforms = {
11 | "WaymoTargetBuilder": WaymoTargetBuilder,
12 | }
13 |
14 | dataset = {
15 | "scalable": MultiDataset,
16 | }
17 |
18 | def __init__(self,
19 | root: str,
20 | train_batch_size: int,
21 | val_batch_size: int,
22 | test_batch_size: int,
23 | shuffle: bool = False,
24 | num_workers: int = 0,
25 | pin_memory: bool = True,
26 | persistent_workers: bool = True,
27 | train_raw_dir: Optional[str] = None,
28 | val_raw_dir: Optional[str] = None,
29 | test_raw_dir: Optional[str] = None,
30 | train_processed_dir: Optional[str] = None,
31 | val_processed_dir: Optional[str] = None,
32 | test_processed_dir: Optional[str] = None,
33 | transform: Optional[str] = None,
34 | dataset: Optional[str] = None,
35 | num_historical_steps: int = 50,
36 | num_future_steps: int = 60,
37 | processor='ntp',
38 | use_intention=False,
39 | token_size=512,
40 | **kwargs) -> None:
41 | super(MultiDataModule, self).__init__()
42 | self.root = root
43 | self.dataset_class = dataset
44 | self.train_batch_size = train_batch_size
45 | self.val_batch_size = val_batch_size
46 | self.test_batch_size = test_batch_size
47 | self.shuffle = shuffle
48 | self.num_workers = num_workers
49 | self.pin_memory = pin_memory
50 | self.persistent_workers = persistent_workers and num_workers > 0
51 | self.train_raw_dir = train_raw_dir
52 | self.val_raw_dir = val_raw_dir
53 | self.test_raw_dir = test_raw_dir
54 | self.train_processed_dir = train_processed_dir
55 | self.val_processed_dir = val_processed_dir
56 | self.test_processed_dir = test_processed_dir
57 | self.processor = processor
58 | self.use_intention = use_intention
59 | self.token_size = token_size
60 |
61 | train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train")
62 | val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val")
63 | test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps)
64 |
65 | self.train_transform = train_transform
66 | self.val_transform = val_transform
67 | self.test_transform = test_transform
68 |
69 | def setup(self, stage: Optional[str] = None) -> None:
70 | self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir,
71 | raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size)
72 | self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir,
73 | raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size)
74 | self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir,
75 | raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size)
76 |
77 | def train_dataloader(self):
78 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
79 | num_workers=self.num_workers, pin_memory=self.pin_memory,
80 | persistent_workers=self.persistent_workers)
81 |
82 | def val_dataloader(self):
83 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
84 | num_workers=self.num_workers, pin_memory=self.pin_memory,
85 | persistent_workers=self.persistent_workers)
86 |
87 | def test_dataloader(self):
88 | return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
89 | num_workers=self.num_workers, pin_memory=self.pin_memory,
90 | persistent_workers=self.persistent_workers)
91 |
--------------------------------------------------------------------------------
/smart/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from smart.datasets.scalable_dataset import MultiDataset
2 |
--------------------------------------------------------------------------------
/smart/datasets/preprocess.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.interpolate import interp1d
4 | from scipy.spatial.distance import euclidean
5 | import math
6 | import pickle
7 | from smart.utils import wrap_angle
8 | import os
9 |
10 | def cal_polygon_contour(x, y, theta, width, length):
11 | left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
12 | left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
13 | left_front = np.column_stack((left_front_x, left_front_y))
14 |
15 | right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
16 | right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
17 | right_front = np.column_stack((right_front_x, right_front_y))
18 |
19 | right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
20 | right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
21 | right_back = np.column_stack((right_back_x, right_back_y))
22 |
23 | left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
24 | left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
25 | left_back = np.column_stack((left_back_x, left_back_y))
26 |
27 | polygon_contour = np.concatenate(
28 | (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)
29 |
30 | return polygon_contour
31 |
32 |
33 | def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
34 | # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
35 | dist_along_path_list = [[0]]
36 | polylines_list = [[polylines[0]]]
37 | for i in range(1, polylines.shape[0]):
38 | euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
39 |         heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
40 |                            abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
41 | if heading_diff > math.pi / 4 and euclidean_dist > 3:
42 | dist_along_path_list.append([0])
43 | polylines_list.append([polylines[i]])
44 | elif heading_diff > math.pi / 8 and euclidean_dist > 3:
45 | dist_along_path_list.append([0])
46 | polylines_list.append([polylines[i]])
47 | elif heading_diff > 0.1 and euclidean_dist > 3:
48 | dist_along_path_list.append([0])
49 | polylines_list.append([polylines[i]])
50 | elif euclidean_dist > 10:
51 | dist_along_path_list.append([0])
52 | polylines_list.append([polylines[i]])
53 | else:
54 | dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
55 | polylines_list[-1].append(polylines[i])
56 | # plt.plot(polylines[:, 0], polylines[:, 1])
57 | # plt.savefig('tmp.jpg')
58 | new_x_list = []
59 | new_y_list = []
60 | multi_polylines_list = []
61 | for idx in range(len(dist_along_path_list)):
62 | if len(dist_along_path_list[idx]) < 2:
63 | continue
64 | dist_along_path = np.array(dist_along_path_list[idx])
65 | polylines_cur = np.array(polylines_list[idx])
66 | # Create interpolation functions for x and y coordinates
67 | fx = interp1d(dist_along_path, polylines_cur[:, 0])
68 | fy = interp1d(dist_along_path, polylines_cur[:, 1])
69 | # fyaw = interp1d(dist_along_path, heading)
70 |
71 | # Create an array of distances at which to interpolate
72 | new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
73 | new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
74 | # Use the interpolation functions to generate new x and y coordinates
75 | new_x = fx(new_dist_along_path)
76 | new_y = fy(new_dist_along_path)
77 | # new_yaw = fyaw(new_dist_along_path)
78 | new_x_list.append(new_x)
79 | new_y_list.append(new_y)
80 |
81 | # Combine the new x and y coordinates into a single array
82 | new_polylines = np.vstack((new_x, new_y)).T
83 | polyline_size = int(split_distace / distance)
84 | if new_polylines.shape[0] >= (polyline_size + 1):
85 | padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
86 | final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
87 | else:
88 | padding_size = new_polylines.shape[0]
89 | final_index = 0
90 | multi_polylines = None
91 | new_polylines = torch.from_numpy(new_polylines)
92 | new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
93 | new_polylines[1:, 0] - new_polylines[:-1, 0])
94 | new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
95 | new_polylines = torch.cat([new_polylines, new_heading], -1)
96 | if new_polylines.shape[0] >= (polyline_size + 1):
97 | multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
98 | multi_polylines = multi_polylines.transpose(1, 2)
99 | multi_polylines = multi_polylines[:, ::5, :]
100 | if padding_size >= 3:
101 | last_polyline = new_polylines[final_index * polyline_size:]
102 | last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
103 | if multi_polylines is not None:
104 | multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
105 | else:
106 | multi_polylines = last_polyline.unsqueeze(0)
107 | if multi_polylines is None:
108 | continue
109 | multi_polylines_list.append(multi_polylines)
110 | if len(multi_polylines_list) > 0:
111 | multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
112 | else:
113 | multi_polylines_list = None
114 | return multi_polylines_list
115 |
116 |
117 | def average_distance_vectorized(point_set1, centroids):
118 | dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))
119 | return np.mean(dists, axis=2)
120 |
121 |
122 | def assign_clusters(sub_X, centroids):
123 | distances = average_distance_vectorized(sub_X, centroids)
124 | return np.argmin(distances, axis=1)
125 |
126 |
127 | class TokenProcessor:
128 |
129 | def __init__(self, token_size):
130 | module_dir = os.path.dirname(os.path.dirname(__file__))
131 | self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl')
132 | self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
133 | self.noise = False
134 | self.disturb = False
135 | self.shift = 5
136 | self.get_trajectory_token()
137 | self.training = False
138 | self.current_step = 10
139 |
140 | def preprocess(self, data):
141 | data = self.tokenize_agent(data)
142 | data = self.tokenize_map(data)
143 | del data['city']
144 | if 'polygon_is_intersection' in data['map_polygon']:
145 | del data['map_polygon']['polygon_is_intersection']
146 | if 'route_type' in data['map_polygon']:
147 | del data['map_polygon']['route_type']
148 | return data
149 |
150 | def get_trajectory_token(self):
151 | agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
152 | map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
153 | self.trajectory_token = agent_token_data['token']
154 | self.trajectory_token_all = agent_token_data['token_all']
155 | self.map_token = {'traj_src': map_token_traj['traj_src'], }
156 | self.token_last = {}
157 | for k, v in self.trajectory_token_all.items():
158 | token_last = torch.from_numpy(v[:, -2:]).to(torch.float)
159 | diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]
160 | theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
161 | cos, sin = theta.cos(), theta.sin()
162 | rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
163 | rot_mat[:, 0, 0] = cos
164 | rot_mat[:, 0, 1] = -sin
165 | rot_mat[:, 1, 0] = sin
166 | rot_mat[:, 1, 1] = cos
167 | agent_token = torch.bmm(token_last[:, 1], rot_mat)
168 | agent_token -= token_last[:, 0].mean(1)[:, None, :]
169 | self.token_last[k] = agent_token.numpy()
170 |
171 | def clean_heading(self, data):
172 | heading = data['agent']['heading']
173 | valid = data['agent']['valid_mask']
174 | pi = torch.tensor(torch.pi)
175 | n_vehicles, n_frames = heading.shape
176 |
177 | heading_diff_raw = heading[:, :-1] - heading[:, 1:]
178 | heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
179 | heading_diff[heading_diff > pi] -= 2 * pi
180 | heading_diff[heading_diff < -pi] += 2 * pi
181 |
182 | valid_pairs = valid[:, :-1] & valid[:, 1:]
183 |
184 | for i in range(n_frames - 1):
185 | change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]
186 |
187 | heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]
188 |
189 | if i < n_frames - 2:
190 | heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
191 | heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
192 | heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
193 | heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi
194 |
195 | def tokenize_agent(self, data):
196 | if data['agent']["velocity"].shape[1] == 90:
197 | print(data['scenario_id'], data['agent']["velocity"].shape)
198 | interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (
199 | data['agent']['position'][:, self.current_step, 0] != 0)
200 | if data['agent']["velocity"].shape[-1] == 2:
201 | data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
202 | torch.zeros(data['agent']["velocity"].shape[0],
203 | data['agent']["velocity"].shape[1], 1)], dim=-1)
204 | vel = data['agent']["velocity"][interplote_mask, self.current_step]
205 | data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][
206 | interplote_mask, self.current_step,
207 | :3] - vel * 0.1
208 | data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True
209 | data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][
210 | interplote_mask, self.current_step]
211 | data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][
212 | interplote_mask, self.current_step]
213 |
214 | data['agent']['type'] = data['agent']['type'].to(torch.uint8)
215 |
216 | self.clean_heading(data)
217 | matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (
218 | data['agent']['valid_mask'][:, self.current_step - 5] == False)
219 |
220 | interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
221 | data['agent']['valid_mask'][interplote_mask_first, 0] = True
222 |
223 | agent_pos = data['agent']['position'][:, :, :2]
224 | valid_mask = data['agent']['valid_mask']
225 |
226 | valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
227 | token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
228 | agent_type = data['agent']['type']
229 | agent_category = data['agent']['category']
230 | agent_heading = data['agent']['heading']
231 | vehicle_mask = agent_type == 0
232 | cyclist_mask = agent_type == 2
233 | ped_mask = agent_type == 1
234 |
235 | veh_pos = agent_pos[vehicle_mask, :, :]
236 | veh_valid_mask = valid_mask[vehicle_mask, :]
237 | cyc_pos = agent_pos[cyclist_mask, :, :]
238 | cyc_valid_mask = valid_mask[cyclist_mask, :]
239 | ped_pos = agent_pos[ped_mask, :, :]
240 | ped_valid_mask = valid_mask[ped_mask, :]
241 |
242 | veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
243 | 'veh', agent_category[vehicle_mask],
244 | matching_extra_mask[vehicle_mask])
245 | ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',
246 | agent_category[ped_mask], matching_extra_mask[ped_mask])
247 | cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
248 | 'cyc', agent_category[cyclist_mask],
249 | matching_extra_mask[cyclist_mask])
250 |
251 | token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
252 | token_index[vehicle_mask] = veh_token_index
253 | token_index[ped_mask] = ped_token_index
254 | token_index[cyclist_mask] = cyc_token_index
255 |
256 | token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
257 | veh_token_contour.shape[2], veh_token_contour.shape[3]))
258 | token_contour[vehicle_mask] = veh_token_contour
259 | token_contour[ped_mask] = ped_token_contour
260 | token_contour[cyclist_mask] = cyc_token_contour
261 |
262 | trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)
263 | trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)
264 | trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)
265 |
266 | agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
267 | agent_token_traj[vehicle_mask] = trajectory_token_veh
268 | agent_token_traj[ped_mask] = trajectory_token_ped
269 | agent_token_traj[cyclist_mask] = trajectory_token_cyc
270 |
271 | if not self.training:
272 | token_valid_mask[matching_extra_mask, 1] = True
273 |
274 | data['agent']['token_idx'] = token_index
275 | data['agent']['token_contour'] = token_contour
276 | token_pos = token_contour.mean(dim=2)
277 | data['agent']['token_pos'] = token_pos
278 | diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]  # left_front - left_back gives the box's forward direction
279 | data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
280 | data['agent']['agent_valid_mask'] = token_valid_mask
281 |
282 | vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
283 | ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)
284 | vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
285 | (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
286 | vel[~vel_valid_mask] = 0
287 | # Overwrite the velocity at token step 1 with the observed velocity at the current timestep.
288 | cur_valid = data['agent']['valid_mask'][:, self.current_step]
289 | vel[cur_valid, 1] = data['agent']['velocity'][cur_valid, self.current_step, :2]
290 |
291 | data['agent']['token_velocity'] = vel
292 |
293 | return data
294 |
295 | def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):
296 | agent_token_src = self.trajectory_token[category]
297 | token_last = self.token_last[category]
298 | if self.shift <= 2:
299 | if category == 'veh':
300 | width = 1.0
301 | length = 2.4
302 | elif category == 'cyc':
303 | width = 0.5
304 | length = 1.5
305 | else:
306 | width = 0.5
307 | length = 0.5
308 | else:
309 | if category == 'veh':
310 | width = 2.0
311 | length = 4.8
312 | elif category == 'cyc':
313 | width = 1.0
314 | length = 2.0
315 | else:
316 | width = 1.0
317 | length = 1.0
318 |
319 | prev_heading = heading[:, 0]
320 | prev_pos = pos[:, 0]
321 | agent_num, num_step, feat_dim = pos.shape
322 | token_num, token_contour_dim, feat_dim = agent_token_src.shape
323 | agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
324 | token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
325 | token_index_list = []
326 | token_contour_list = []
327 | prev_token_idx = None
328 |
329 | for i in range(self.shift, pos.shape[1], self.shift):
330 | theta = prev_heading
331 | cur_heading = heading[:, i]
332 | cur_pos = pos[:, i]
333 | cos, sin = theta.cos(), theta.sin()
334 | rot_mat = theta.new_zeros(agent_num, 2, 2)
335 | rot_mat[:, 0, 0] = cos
336 | rot_mat[:, 0, 1] = sin
337 | rot_mat[:, 1, 0] = -sin
338 | rot_mat[:, 1, 1] = cos
339 | agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,
340 | token_num,
341 | token_contour_dim,
342 | feat_dim)
343 | agent_token_world += prev_pos[:, None, None, :]
344 |
345 | cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
346 | agent_token_index = torch.from_numpy(np.argmin(
347 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
348 | axis=-1))
349 | if prev_token_idx is not None and self.noise:
350 | same_idx = prev_token_idx == agent_token_index
351 | same_idx[:] = True  # overridden so that every agent resamples its token from the top-5 matches below
352 | topk_indices = np.argsort(
353 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
354 | axis=2), axis=-1)[:, :5]
355 | sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
356 | agent_token_index[same_idx] = \
357 | torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]
358 |
359 | token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]
360 |
361 | diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]
362 |
363 | prev_heading = heading[:, i].clone()
364 | prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
365 | valid_mask[:, i - self.shift]]
366 |
367 | prev_pos = pos[:, i].clone()
368 | prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]
369 | prev_token_idx = agent_token_index
370 | token_index_list.append(agent_token_index[:, None])
371 | token_contour_list.append(token_contour_select[:, None, ...])
372 |
373 | token_index = torch.cat(token_index_list, dim=1)
374 | token_contour = torch.cat(token_contour_list, dim=1)
375 |
376 | # extra matching
377 | if not self.training:
378 | theta = heading[extra_mask, self.current_step - 1]
379 | prev_pos = pos[extra_mask, self.current_step - 1]
380 | cur_pos = pos[extra_mask, self.current_step]
381 | cur_heading = heading[extra_mask, self.current_step]
382 | cos, sin = theta.cos(), theta.sin()
383 | rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
384 | rot_mat[:, 0, 0] = cos
385 | rot_mat[:, 0, 1] = sin
386 | rot_mat[:, 1, 0] = -sin
387 | rot_mat[:, 1, 1] = cos
388 | agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
389 | extra_mask.sum(), token_num, token_contour_dim, feat_dim)
390 | agent_token_world += prev_pos[:, None, None, :]
391 |
392 | cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
393 | agent_token_index = torch.from_numpy(np.argmin(
394 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
395 | axis=-1))
396 | token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]
397 |
398 | token_index[extra_mask, 1] = agent_token_index
399 | token_contour[extra_mask, 1] = token_contour_select
400 |
401 | return token_index, token_contour
402 |
403 | def tokenize_map(self, data):
404 | data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
405 | data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
406 | pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
407 | pt_type = data['map_point']['type'].to(torch.uint8)
408 | pt_side = torch.zeros_like(pt_type)
409 | pt_pos = data['map_point']['position'][:, :2]
410 | data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
411 | pt_heading = data['map_point']['orientation']
412 | split_polyline_type = []
413 | split_polyline_pos = []
414 | split_polyline_theta = []
415 | split_polyline_side = []
416 | pl_idx_list = []
417 | split_polygon_type = []
419 |
420 | for i in sorted(np.unique(pt2pl[1])):
421 | index = pt2pl[0, pt2pl[1] == i]
422 | polygon_type = data['map_polygon']["type"][i]
423 | cur_side = pt_side[index]
424 | cur_type = pt_type[index]
425 | cur_pos = pt_pos[index]
426 | cur_heading = pt_heading[index]
427 |
428 | for side_val in np.unique(cur_side):
429 | for type_val in np.unique(cur_type):
430 | if type_val == 13:
431 | continue
432 | indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
433 | if len(indices) <= 2:
434 | continue
435 | split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
436 | if split_polyline is None:
437 | continue
438 | new_cur_type = cur_type[indices][0]
439 | new_cur_side = cur_side[indices][0]
440 | map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
441 | new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
442 | new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
443 | cur_pl_idx = torch.Tensor([i])
444 | new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
445 | split_polyline_pos.append(split_polyline[..., :2])
446 | split_polyline_theta.append(split_polyline[..., 2])
447 | split_polyline_type.append(new_cur_type)
448 | split_polyline_side.append(new_cur_side)
449 | pl_idx_list.append(new_cur_pl_idx)
450 | split_polygon_type.append(map_polygon_type)
451 |
452 | split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
453 | split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
454 | split_polyline_type = torch.cat(split_polyline_type, dim=0)
455 | split_polyline_side = torch.cat(split_polyline_side, dim=0)
456 | split_polygon_type = torch.cat(split_polygon_type, dim=0)
457 | pl_idx_list = torch.cat(pl_idx_list, dim=0)
458 | vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]
459 | data['map_save'] = {}
460 | data['pt_token'] = {}
461 | data['map_save']['traj_pos'] = split_polyline_pos
462 | data['map_save']['traj_theta'] = split_polyline_theta[:, 0]  # alternative: torch.arctan2(vec[:, 1], vec[:, 0])
463 | data['map_save']['pl_idx_list'] = pl_idx_list
464 | data['pt_token']['type'] = split_polyline_type
465 | data['pt_token']['side'] = split_polyline_side
466 | data['pt_token']['pl_type'] = split_polygon_type
467 | data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
468 | return data
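469 |
470 | if __name__ == '__main__':
471 |     # Minimal sketch with synthetic tensors (not part of the pipeline): it
472 |     # illustrates the frame transform used in match_token, where a token
473 |     # template defined in the agent frame is rotated by the previous heading
474 |     # and translated to the previous position before nearest-token matching.
475 |     prev_heading = torch.tensor([0.0, torch.pi / 2])    # one heading per agent
476 |     prev_pos = torch.tensor([[0.0, 0.0], [10.0, 5.0]])  # one position per agent
477 |     template = torch.tensor([[1.0, 0.0], [0.0, 1.0]])   # a tiny 2-point token
478 |     cos, sin = prev_heading.cos(), prev_heading.sin()
479 |     rot_mat = prev_heading.new_zeros(2, 2, 2)
480 |     rot_mat[:, 0, 0] = cos
481 |     rot_mat[:, 0, 1] = sin
482 |     rot_mat[:, 1, 0] = -sin
483 |     rot_mat[:, 1, 1] = cos
484 |     world = torch.bmm(template.unsqueeze(0).repeat(2, 1, 1), rot_mat) + prev_pos[:, None, :]
485 |     print(world)  # the second agent's points are rotated 90 degrees, then shifted
486 |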
--------------------------------------------------------------------------------
/smart/datasets/scalable_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from typing import Callable, List, Optional, Tuple, Union
4 | import pandas as pd
5 | from torch_geometric.data import Dataset
6 | from smart.utils.log import Logging
7 | import numpy as np
8 | from .preprocess import TokenProcessor
9 |
10 |
11 | def distance(point1, point2):
12 | return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)
13 |
14 |
15 | class MultiDataset(Dataset):
16 | def __init__(self,
17 | root: str,
18 | split: str,
19 | raw_dir: List[str] = None,
20 | processed_dir: List[str] = None,
21 | transform: Optional[Callable] = None,
22 | dim: int = 3,
23 | num_historical_steps: int = 50,
24 | num_future_steps: int = 60,
25 | predict_unseen_agents: bool = False,
26 | vector_repr: bool = True,
27 | cluster: bool = False,
28 | processor=None,
29 | use_intention=False,
30 | token_size=512) -> None:
31 | self.logger = Logging().log(level='DEBUG')
32 | self.root = root
33 | self.well_done = [0]
34 | if split not in ('train', 'val', 'test'):
35 | raise ValueError(f'{split} is not a valid split')
36 | self.split = split
37 | self.training = split == 'train'
38 | self.logger.debug("Starting loading dataset")
39 | self._raw_file_names = []
40 | self._raw_paths = []
41 | self._raw_file_dataset = []
42 | if raw_dir is not None:
43 | self._raw_dir = raw_dir
44 | for raw_dir in self._raw_dir:
45 | raw_dir = os.path.expanduser(os.path.normpath(raw_dir))
46 | dataset = "waymo"
47 | file_list = os.listdir(raw_dir)
48 | self._raw_file_names.extend(file_list)
49 | self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list])
50 | self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])
51 | if self.root is not None:
52 | split_datainfo = os.path.join(root, "split_datainfo.pkl")
53 | with open(split_datainfo, 'rb') as f:
54 | split_datainfo = pickle.load(f)
55 | if split == "test":
56 | split = "val"
57 | self._processed_file_names = split_datainfo[split]
58 | self.dim = dim
59 | self.num_historical_steps = num_historical_steps
60 | self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)
61 | self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples))
62 | self.token_processor = TokenProcessor(2048)  # fixed token vocabulary size; the token_size argument is not used here
63 | super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
64 |
65 | @property
66 | def raw_dir(self) -> str:
67 | return self._raw_dir
68 |
69 | @property
70 | def raw_paths(self) -> List[str]:
71 | return self._raw_paths
72 |
73 | @property
74 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
75 | return self._raw_file_names
76 |
77 | @property
78 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
79 | return self._processed_file_names
80 |
81 | def len(self) -> int:
82 | return self._num_samples
83 |
84 | def generate_ref_token(self):
85 | pass
86 |
87 | def get(self, idx: int):
88 | with open(self.raw_paths[idx], 'rb') as handle:
89 | data = pickle.load(handle)
90 | data = self.token_processor.preprocess(data)
91 | return data
92 |
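93 | if __name__ == '__main__':
94 |     # Usage sketch: reads raw scene pickles straight from a directory. The
95 |     # path below points at the demo pickles shipped with the repo and is an
96 |     # assumption about your working directory; with root=None the
97 |     # split_datainfo.pkl lookup is skipped entirely.
98 |     dataset = MultiDataset(root=None, split='val', raw_dir=['data/valid_demo'])
99 |     print(len(dataset))
100 |     sample = dataset[0]  # tokenized scene produced by TokenProcessor.preprocess
101 |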
--------------------------------------------------------------------------------
/smart/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from smart.layers.attention_layer import AttentionLayer
3 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
4 | from smart.layers.mlp_layer import MLPLayer
5 |
--------------------------------------------------------------------------------
/smart/layers/attention_layer.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch_geometric.nn.conv import MessagePassing
7 | from torch_geometric.utils import softmax
8 |
9 | from smart.utils import weight_init
10 |
11 |
12 | class AttentionLayer(MessagePassing):
13 |
14 | def __init__(self,
15 | hidden_dim: int,
16 | num_heads: int,
17 | head_dim: int,
18 | dropout: float,
19 | bipartite: bool,
20 | has_pos_emb: bool,
21 | **kwargs) -> None:
22 | super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
23 | self.num_heads = num_heads
24 | self.head_dim = head_dim
25 | self.has_pos_emb = has_pos_emb
26 | self.scale = head_dim ** -0.5
27 |
28 | self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
29 | self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
30 | self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
31 | if has_pos_emb:
32 | self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
33 | self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
34 | self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
35 | self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
36 | self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
37 | self.attn_drop = nn.Dropout(dropout)
38 | self.ff_mlp = nn.Sequential(
39 | nn.Linear(hidden_dim, hidden_dim * 4),
40 | nn.ReLU(inplace=True),
41 | nn.Dropout(dropout),
42 | nn.Linear(hidden_dim * 4, hidden_dim),
43 | )
44 | if bipartite:
45 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
46 | self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
47 | else:
48 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
49 | self.attn_prenorm_x_dst = self.attn_prenorm_x_src
50 | if has_pos_emb:
51 | self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
52 | self.attn_postnorm = nn.LayerNorm(hidden_dim)
53 | self.ff_prenorm = nn.LayerNorm(hidden_dim)
54 | self.ff_postnorm = nn.LayerNorm(hidden_dim)
55 | self.apply(weight_init)
56 |
57 | def forward(self,
58 | x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
59 | r: Optional[torch.Tensor],
60 | edge_index: torch.Tensor) -> torch.Tensor:
61 | if isinstance(x, torch.Tensor):
62 | x_src = x_dst = self.attn_prenorm_x_src(x)
63 | else:
64 | x_src, x_dst = x
65 | x_src = self.attn_prenorm_x_src(x_src)
66 | x_dst = self.attn_prenorm_x_dst(x_dst)
67 | x = x[1]
68 | if self.has_pos_emb and r is not None:
69 | r = self.attn_prenorm_r(r)
70 | x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
71 | x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
72 | return x
73 |
74 | def message(self,
75 | q_i: torch.Tensor,
76 | k_j: torch.Tensor,
77 | v_j: torch.Tensor,
78 | r: Optional[torch.Tensor],
79 | index: torch.Tensor,
80 | ptr: Optional[torch.Tensor]) -> torch.Tensor:
81 | if self.has_pos_emb and r is not None:
82 | k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
83 | v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
84 | sim = (q_i * k_j).sum(dim=-1) * self.scale
85 | attn = softmax(sim, index, ptr)
86 | self.attention_weight = attn.sum(-1).detach()
87 | attn = self.attn_drop(attn)
88 | return v_j * attn.unsqueeze(-1)
89 |
90 | def update(self,
91 | inputs: torch.Tensor,
92 | x_dst: torch.Tensor) -> torch.Tensor:
93 | inputs = inputs.view(-1, self.num_heads * self.head_dim)
94 | g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
95 | return inputs + g * (self.to_s(x_dst) - inputs)
96 |
97 | def _attn_block(self,
98 | x_src: torch.Tensor,
99 | x_dst: torch.Tensor,
100 | r: Optional[torch.Tensor],
101 | edge_index: torch.Tensor) -> torch.Tensor:
102 | q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
103 | k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
104 | v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
105 | agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
106 | return self.to_out(agg)
107 |
108 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
109 | return self.ff_mlp(x)
110 |
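111 | if __name__ == '__main__':
112 |     # Smoke-test sketch on a tiny synthetic graph; all sizes are arbitrary.
113 |     # Messages flow along edge_index (source -> target); each head applies
114 |     # scaled dot-product attention, followed by the gated update and the MLP.
115 |     layer = AttentionLayer(hidden_dim=64, num_heads=4, head_dim=16,
116 |                            dropout=0.1, bipartite=False, has_pos_emb=False)
117 |     x = torch.randn(8, 64)
118 |     edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
119 |     print(layer(x, None, edge_index).shape)  # torch.Size([8, 64])
120 |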
--------------------------------------------------------------------------------
/smart/layers/fourier_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Optional
3 | import torch
4 | import torch.nn as nn
5 |
6 | from smart.utils import weight_init
7 |
8 |
9 | class FourierEmbedding(nn.Module):
10 |
11 | def __init__(self,
12 | input_dim: int,
13 | hidden_dim: int,
14 | num_freq_bands: int) -> None:
15 | super(FourierEmbedding, self).__init__()
16 | self.input_dim = input_dim
17 | self.hidden_dim = hidden_dim
18 |
19 | self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
20 | self.mlps = nn.ModuleList(
21 | [nn.Sequential(
22 | nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
23 | nn.LayerNorm(hidden_dim),
24 | nn.ReLU(inplace=True),
25 | nn.Linear(hidden_dim, hidden_dim),
26 | )
27 | for _ in range(input_dim)])
28 | self.to_out = nn.Sequential(
29 | nn.LayerNorm(hidden_dim),
30 | nn.ReLU(inplace=True),
31 | nn.Linear(hidden_dim, hidden_dim),
32 | )
33 | self.apply(weight_init)
34 |
35 | def forward(self,
36 | continuous_inputs: Optional[torch.Tensor] = None,
37 | categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
38 | if continuous_inputs is None:
39 | if categorical_embs is not None:
40 | x = torch.stack(categorical_embs).sum(dim=0)
41 | else:
42 | raise ValueError('Both continuous_inputs and categorical_embs are None')
43 | else:
44 | x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
45 | # Warning: if your data are noisy, don't use learnable sinusoidal embedding
46 | x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
47 | continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
48 | for i in range(self.input_dim):
49 | continuous_embs[i] = self.mlps[i](x[:, i])
50 | x = torch.stack(continuous_embs).sum(dim=0)
51 | if categorical_embs is not None:
52 | x = x + torch.stack(categorical_embs).sum(dim=0)
53 | return self.to_out(x)
54 |
55 |
56 | class MLPEmbedding(nn.Module):
57 | def __init__(self,
58 | input_dim: int,
59 | hidden_dim: int) -> None:
60 | super(MLPEmbedding, self).__init__()
61 | self.input_dim = input_dim
62 | self.hidden_dim = hidden_dim
63 | self.mlp = nn.Sequential(
64 | nn.Linear(input_dim, 128),
65 | nn.LayerNorm(128),
66 | nn.ReLU(inplace=True),
67 | nn.Linear(128, hidden_dim),
68 | nn.LayerNorm(hidden_dim),
69 | nn.ReLU(inplace=True),
70 | nn.Linear(hidden_dim, hidden_dim))
71 | self.apply(weight_init)
72 |
73 | def forward(self,
74 | continuous_inputs: Optional[torch.Tensor] = None,
75 | categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
76 | if continuous_inputs is None:
77 | if categorical_embs is not None:
78 | x = torch.stack(categorical_embs).sum(dim=0)
79 | else:
80 | raise ValueError('Both continuous_inputs and categorical_embs are None')
81 | else:
82 | x = self.mlp(continuous_inputs)
83 | if categorical_embs is not None:
84 | x = x + torch.stack(categorical_embs).sum(dim=0)
85 | return x
86 |
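87 | if __name__ == '__main__':
88 |     # Sketch: embed a batch of 2-D continuous features. Each input channel is
89 |     # expanded into learnable sine/cosine frequencies plus the raw value,
90 |     # projected to hidden_dim, and summed across channels. Sizes are arbitrary.
91 |     emb = FourierEmbedding(input_dim=2, hidden_dim=64, num_freq_bands=8)
92 |     print(emb(continuous_inputs=torch.randn(10, 2)).shape)  # torch.Size([10, 64])
93 |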
--------------------------------------------------------------------------------
/smart/layers/mlp_layer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 |
5 | from smart.utils import weight_init
6 |
7 |
8 | class MLPLayer(nn.Module):
9 |
10 | def __init__(self,
11 | input_dim: int,
12 | hidden_dim: int,
13 | output_dim: int) -> None:
14 | super(MLPLayer, self).__init__()
15 | self.mlp = nn.Sequential(
16 | nn.Linear(input_dim, hidden_dim),
17 | nn.LayerNorm(hidden_dim),
18 | nn.ReLU(inplace=True),
19 | nn.Linear(hidden_dim, output_dim),
20 | )
21 | self.apply(weight_init)
22 |
23 | def forward(self, x: torch.Tensor) -> torch.Tensor:
24 | return self.mlp(x)
25 |
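26 | if __name__ == '__main__':
27 |     # Sketch: a small prediction head on random features; sizes are arbitrary.
28 |     head = MLPLayer(input_dim=64, hidden_dim=128, output_dim=10)
29 |     print(head(torch.randn(8, 64)).shape)  # torch.Size([8, 10])
30 |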
--------------------------------------------------------------------------------
/smart/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from smart.metrics.average_meter import AverageMeter
3 | from smart.metrics.min_ade import minADE
4 | from smart.metrics.min_fde import minFDE
5 | from smart.metrics.next_token_cls import TokenCls
6 |
--------------------------------------------------------------------------------
/smart/metrics/average_meter.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torchmetrics import Metric
4 |
5 |
6 | class AverageMeter(Metric):
7 |
8 | def __init__(self, **kwargs) -> None:
9 | super(AverageMeter, self).__init__(**kwargs)
10 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
11 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
12 |
13 | def update(self, val: torch.Tensor) -> None:
14 | self.sum += val.sum()
15 | self.count += val.numel()
16 |
17 | def compute(self) -> torch.Tensor:
18 | return self.sum / self.count
19 |
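20 | if __name__ == '__main__':
21 |     # Sketch: the meter keeps a running mean over every element it has seen;
22 |     # storing sums and counts (rather than the mean) keeps the distributed
23 |     # 'sum' reduction exact across processes.
24 |     meter = AverageMeter()
25 |     meter.update(torch.tensor([1.0, 2.0, 3.0]))
26 |     meter.update(torch.tensor([4.0]))
27 |     print(meter.compute())  # tensor(2.5)
28 |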
--------------------------------------------------------------------------------
/smart/metrics/min_ade.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Optional
3 |
4 | import torch
5 | from torchmetrics import Metric
6 |
7 | from smart.metrics.utils import topk
8 | from smart.metrics.utils import valid_filter
9 |
10 |
11 | class minMultiADE(Metric):
12 |
13 | def __init__(self,
14 | max_guesses: int = 6,
15 | **kwargs) -> None:
16 | super(minMultiADE, self).__init__(**kwargs)
17 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
18 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
19 | self.max_guesses = max_guesses
20 |
21 | def update(self,
22 | pred: torch.Tensor,
23 | target: torch.Tensor,
24 | prob: Optional[torch.Tensor] = None,
25 | valid_mask: Optional[torch.Tensor] = None,
26 | keep_invalid_final_step: bool = True,
27 | min_criterion: str = 'FDE') -> None:
28 | pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
29 | pred_topk, _ = topk(self.max_guesses, pred, prob)
30 | if min_criterion == 'FDE':
31 | inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
32 | inds_best = torch.norm(
33 | pred_topk[torch.arange(pred.size(0)), :, inds_last] -
34 | target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
35 | self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
36 | valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
37 | elif min_criterion == 'ADE':
38 | self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
39 | valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
40 | else:
41 | raise ValueError('{} is not a valid criterion'.format(min_criterion))
42 | self.count += pred.size(0)
43 |
44 | def compute(self) -> torch.Tensor:
45 | return self.sum / self.count
46 |
47 |
48 | class minADE(Metric):
49 |
50 | def __init__(self,
51 | max_guesses: int = 6,
52 | **kwargs) -> None:
53 | super(minADE, self).__init__(**kwargs)
54 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
55 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
56 | self.max_guesses = max_guesses
57 | self.eval_timestep = 70
58 |
59 | def update(self,
60 | pred: torch.Tensor,
61 | target: torch.Tensor,
62 | prob: Optional[torch.Tensor] = None,
63 | valid_mask: Optional[torch.Tensor] = None,
64 | keep_invalid_final_step: bool = True,
65 | min_criterion: str = 'ADE') -> None:
66 | # Top-k / FDE-based mode selection (cf. minMultiADE above) is skipped here; predictions are evaluated directly as a single rollout below.
80 | eval_timestep = min(self.eval_timestep, pred.shape[1])
81 | self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
82 | self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()
83 |
84 | def compute(self) -> torch.Tensor:
85 | return self.sum / self.count
86 |
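87 | if __name__ == '__main__':
88 |     # Sketch with synthetic tensors: for a single rollout the metric is a
89 |     # masked average displacement over at most eval_timestep (70) steps; note
90 |     # that it normalizes by the full horizon length, not the valid-step count.
91 |     metric = minADE(max_guesses=1)
92 |     pred = torch.randn(4, 80, 2)
93 |     metric.update(pred=pred, target=pred + 0.1,
94 |                   valid_mask=torch.ones(4, 80, dtype=torch.bool))
95 |     print(metric.compute())  # ~= 0.1 * sqrt(2) * 70 / 80 ~= 0.124
96 |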
--------------------------------------------------------------------------------
/smart/metrics/min_fde.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 | from smart.metrics.utils import topk
7 | from smart.metrics.utils import valid_filter
8 |
9 |
10 | class minMultiFDE(Metric):
11 |
12 | def __init__(self,
13 | max_guesses: int = 6,
14 | **kwargs) -> None:
15 | super(minMultiFDE, self).__init__(**kwargs)
16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
18 | self.max_guesses = max_guesses
19 |
20 | def update(self,
21 | pred: torch.Tensor,
22 | target: torch.Tensor,
23 | prob: Optional[torch.Tensor] = None,
24 | valid_mask: Optional[torch.Tensor] = None,
25 | keep_invalid_final_step: bool = True) -> None:
26 | pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
27 | pred_topk, _ = topk(self.max_guesses, pred, prob)
28 | inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
29 | self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
30 | target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
31 | p=2, dim=-1).min(dim=-1)[0].sum()
32 | self.count += pred.size(0)
33 |
34 | def compute(self) -> torch.Tensor:
35 | return self.sum / self.count
36 |
37 |
38 | class minFDE(Metric):
39 |
40 | def __init__(self,
41 | max_guesses: int = 6,
42 | **kwargs) -> None:
43 | super(minFDE, self).__init__(**kwargs)
44 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
45 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
46 | self.max_guesses = max_guesses
47 | self.eval_timestep = 70
48 |
49 | def update(self,
50 | pred: torch.Tensor,
51 | target: torch.Tensor,
52 | prob: Optional[torch.Tensor] = None,
53 | valid_mask: Optional[torch.Tensor] = None,
54 | keep_invalid_final_step: bool = True) -> None:
55 | eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
56 | self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
57 | valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
58 | self.count += valid_mask[:, eval_timestep-1].sum()
59 |
60 | def compute(self) -> torch.Tensor:
61 | return self.sum / self.count
62 |
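63 | if __name__ == '__main__':
64 |     # Sketch with synthetic tensors: for a single rollout the metric reduces
65 |     # to the displacement at a fixed evaluation endpoint (index 68 here, from
66 |     # eval_timestep = 70).
67 |     metric = minFDE(max_guesses=1)
68 |     pred = torch.randn(4, 80, 2)
69 |     metric.update(pred=pred, target=pred + 0.3,
70 |                   valid_mask=torch.ones(4, 80, dtype=torch.bool))
71 |     print(metric.compute())  # ~= 0.3 * sqrt(2) ~= 0.42
72 |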
--------------------------------------------------------------------------------
/smart/metrics/next_token_cls.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 | from smart.metrics.utils import topk
7 | from smart.metrics.utils import valid_filter
8 |
9 |
10 | class TokenCls(Metric):
11 |
12 | def __init__(self,
13 | max_guesses: int = 6,
14 | **kwargs) -> None:
15 | super(TokenCls, self).__init__(**kwargs)
16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
18 | self.max_guesses = max_guesses
19 |
20 | def update(self,
21 | pred: torch.Tensor,
22 | target: torch.Tensor,
23 | valid_mask: Optional[torch.Tensor] = None) -> None:
24 | target = target[..., None]
25 | acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
26 | self.sum += acc.sum()
27 | self.count += valid_mask.sum()
28 |
29 | def compute(self) -> torch.Tensor:
30 | return self.sum / self.count
31 |
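32 | if __name__ == '__main__':
33 |     # Sketch: pred holds ranked token indices per sample; a sample counts as
34 |     # correct when the target index appears among the first max_guesses
35 |     # entries and its mask bit is set.
36 |     metric = TokenCls(max_guesses=1)
37 |     pred = torch.tensor([[3], [7], [7]])
38 |     target = torch.tensor([3, 7, 5])
39 |     metric.update(pred=pred, target=target, valid_mask=torch.ones(3, dtype=torch.bool))
40 |     print(metric.compute())  # tensor(0.6667)
41 |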
--------------------------------------------------------------------------------
/smart/metrics/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from torch_scatter import gather_csr
5 | from torch_scatter import segment_csr
6 |
7 |
8 | def topk(
9 | max_guesses: int,
10 | pred: torch.Tensor,
11 | prob: Optional[torch.Tensor] = None,
12 | ptr: Optional[torch.Tensor] = None,
13 | joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
14 | max_guesses = min(max_guesses, pred.size(1))
15 | if max_guesses == pred.size(1):
16 | if prob is not None:
17 | prob = prob / prob.sum(dim=-1, keepdim=True)
18 | else:
19 | prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
20 | return pred, prob
21 | else:
22 | if prob is not None:
23 | if joint:
24 | if ptr is None:
25 | inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
26 | k=max_guesses, dim=-1, largest=True, sorted=True)[1]
27 | inds_topk = inds_topk.repeat(pred.size(0), 1)
28 | else:
29 | inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
30 | reduce='mean'),
31 | k=max_guesses, dim=-1, largest=True, sorted=True)[1]
32 | inds_topk = gather_csr(src=inds_topk, indptr=ptr)
33 | else:
34 | inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
35 | pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
36 | prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
37 | prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
38 | else:
39 | pred_topk = pred[:, :max_guesses]
40 | prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
41 | return pred_topk, prob_topk
42 |
43 |
44 | def topkind(
45 | max_guesses: int,
46 | pred: torch.Tensor,
47 | prob: Optional[torch.Tensor] = None,
48 | ptr: Optional[torch.Tensor] = None,
49 | joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
50 | max_guesses = min(max_guesses, pred.size(1))
51 | if max_guesses == pred.size(1):
52 | if prob is not None:
53 | prob = prob / prob.sum(dim=-1, keepdim=True)
54 | else:
55 | prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
56 | return pred, prob, None
57 | else:
58 | if prob is not None:
59 | if joint:
60 | if ptr is None:
61 | inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
62 | k=max_guesses, dim=-1, largest=True, sorted=True)[1]
63 | inds_topk = inds_topk.repeat(pred.size(0), 1)
64 | else:
65 | inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
66 | reduce='mean'),
67 | k=max_guesses, dim=-1, largest=True, sorted=True)[1]
68 | inds_topk = gather_csr(src=inds_topk, indptr=ptr)
69 | else:
70 | inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
71 | pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
72 | prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
73 | prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
74 | else:
75 | pred_topk = pred[:, :max_guesses]
76 | prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
77 | return pred_topk, prob_topk, inds_topk
78 |
79 |
80 | def valid_filter(
81 | pred: torch.Tensor,
82 | target: torch.Tensor,
83 | prob: Optional[torch.Tensor] = None,
84 | valid_mask: Optional[torch.Tensor] = None,
85 | ptr: Optional[torch.Tensor] = None,
86 | keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
87 | torch.Tensor, torch.Tensor]:
88 | if valid_mask is None:
89 | valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
90 | if keep_invalid_final_step:
91 | filter_mask = valid_mask.any(dim=-1)
92 | else:
93 | filter_mask = valid_mask[:, -1]
94 | pred = pred[filter_mask]
95 | target = target[filter_mask]
96 | if prob is not None:
97 | prob = prob[filter_mask]
98 | valid_mask = valid_mask[filter_mask]
99 | if ptr is not None:
100 | num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
101 | ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
102 | torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
103 | else:
104 | ptr = target.new_tensor([0, target.size(0)])
105 | return pred, target, prob, valid_mask, ptr
106 |
107 |
108 | def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
109 | """
110 |
111 | Args:
112 | pred_trajs (batch_size, num_modes, num_timestamps, 7)
113 | pred_scores (batch_size, num_modes):
114 | dist_thresh (float):
115 | num_ret_modes (int, optional): Defaults to 6.
116 |
117 | Returns:
118 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
119 | ret_scores (batch_size, num_ret_modes)
120 | ret_idxs (batch_size, num_ret_modes)
121 | """
122 | batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
123 | pred_goals = pred_trajs[:, :, -1, :]
124 | dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
125 | nearby_neighbor = dist < dist_thresh
126 | pred_scores = nearby_neighbor.sum(dim=-1) / num_modes
127 |
128 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
129 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
130 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
131 | sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
132 | sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
133 |
134 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
135 | point_cover_mask = (dist < dist_thresh)
136 |
137 | point_val = sorted_pred_scores.clone() # (batch_size, N)
138 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
139 |
140 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
141 | ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
142 | ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
143 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
144 |
145 | for k in range(num_ret_modes):
146 | cur_idx = point_val.argmax(dim=-1) # (batch_size)
147 | ret_idxs[:, k] = cur_idx
148 |
149 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
150 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
151 | point_val_selected[bs_idxs, cur_idx] = -1
152 | point_val += point_val_selected
153 |
154 | ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
155 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
156 |
157 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
158 |
159 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
160 | return ret_trajs, ret_scores, ret_idxs
161 |
162 |
163 | def batch_nms(pred_trajs, pred_scores,
164 | dist_thresh, num_ret_modes=6,
165 | mode='static', speed=None):
166 | """
167 |
168 | Args:
169 | pred_trajs (batch_size, num_modes, num_timestamps, 7)
170 | pred_scores (batch_size, num_modes):
171 | dist_thresh (float):
172 | num_ret_modes (int, optional): Defaults to 6.
173 |
174 | Returns:
175 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
176 | ret_scores (batch_size, num_ret_modes)
177 | ret_idxs (batch_size, num_ret_modes)
178 | """
179 | batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
180 |
181 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
182 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
183 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
184 | sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
185 | sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
186 |
187 | if mode == "speed":
188 | scale = torch.ones(batch_size).to(sorted_pred_goals.device)
189 | lon_dist_thresh = 4 * scale
190 | lat_dist_thresh = 0.5 * scale
191 | lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
192 | lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
193 | point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
194 | else:
195 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
196 | point_cover_mask = (dist < dist_thresh)
197 |
198 | point_val = sorted_pred_scores.clone() # (batch_size, N)
199 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
200 |
201 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
202 | ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
203 | ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
204 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
205 |
206 | for k in range(num_ret_modes):
207 | cur_idx = point_val.argmax(dim=-1) # (batch_size)
208 | ret_idxs[:, k] = cur_idx
209 |
210 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
211 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
212 | point_val_selected[bs_idxs, cur_idx] = -1
213 | point_val += point_val_selected
214 |
215 | ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
216 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
217 |
218 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
219 |
220 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
221 | return ret_trajs, ret_scores, ret_idxs
222 |
223 |
224 | def batch_nms_token(pred_trajs, pred_scores,
225 | dist_thresh, num_ret_modes=6,
226 | mode='static', speed=None):
227 | """
228 | Args:
229 | pred_trajs (batch_size, num_modes, num_timestamps, 7)
230 | pred_scores (batch_size, num_modes):
231 | dist_thresh (float):
232 | num_ret_modes (int, optional): Defaults to 6.
233 |
234 | Returns:
235 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
236 | ret_scores (batch_size, num_ret_modes)
237 | ret_idxs (batch_size, num_ret_modes)
238 | """
239 | batch_size, num_modes, num_feat_dim = pred_trajs.shape
240 |
241 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
242 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
243 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
244 | sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
245 |
246 | if mode == "nearby":
247 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
248 | values, indices = torch.topk(dist, 5, dim=-1, largest=False)
249 | threshold = values[..., -1]
250 | point_cover_mask = dist < threshold[..., None]
251 | else:
252 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
253 | point_cover_mask = (dist < dist_thresh)
254 |
255 | point_val = sorted_pred_scores.clone() # (batch_size, N)
256 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
257 |
258 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
259 | ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
260 | ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
261 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
262 |
263 | for k in range(num_ret_modes):
264 | cur_idx = point_val.argmax(dim=-1) # (batch_size)
265 | ret_idxs[:, k] = cur_idx
266 |
267 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
268 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
269 | point_val_selected[bs_idxs, cur_idx] = -1
270 | point_val += point_val_selected
271 |
272 | ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
273 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
274 |
275 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
276 |
277 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
278 | return ret_goals, ret_scores, ret_idxs
279 |
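280 | if __name__ == '__main__':
281 |     # Sketch: greedy NMS on synthetic modes. Each iteration keeps the highest
282 |     # scoring remaining mode and suppresses every mode whose trajectory
283 |     # endpoint lies within dist_thresh of the selected one.
284 |     pred_trajs = torch.randn(2, 8, 10, 7)
285 |     pred_scores = torch.softmax(torch.randn(2, 8), dim=-1)
286 |     trajs, scores, idxs = batch_nms(pred_trajs, pred_scores,
287 |                                     dist_thresh=2.0, num_ret_modes=4)
288 |     print(trajs.shape, scores.shape, idxs.shape)
289 |     # torch.Size([2, 4, 10, 7]) torch.Size([2, 4]) torch.Size([2, 4])
290 |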
--------------------------------------------------------------------------------
/smart/model/__init__.py:
--------------------------------------------------------------------------------
1 | from smart.model.smart import SMART
2 |
--------------------------------------------------------------------------------
/smart/model/smart.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import pytorch_lightning as pl
3 | import torch
4 | import torch.nn as nn
5 | from torch_geometric.data import Batch
6 | from torch_geometric.data import HeteroData
7 | from smart.metrics import minADE
8 | from smart.metrics import minFDE
9 | from smart.metrics import TokenCls
10 | from smart.modules import SMARTDecoder
11 | from torch.optim.lr_scheduler import LambdaLR
12 | import math
13 | import numpy as np
14 | import pickle
15 | from collections import defaultdict
16 | import os
17 | from waymo_open_dataset.protos import sim_agents_submission_pb2
18 |
19 |
20 | def cal_polygon_contour(x, y, theta, width, length):
21 | left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
22 | left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
23 | left_front = (left_front_x, left_front_y)
24 |
25 | right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
26 | right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
27 | right_front = (right_front_x, right_front_y)
28 |
29 | right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
30 | right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
31 | right_back = (right_back_x, right_back_y)
32 |
33 | left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
34 | left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
35 | left_back = (left_back_x, left_back_y)
36 | polygon_contour = [left_front, right_front, right_back, left_back]
37 |
38 | return polygon_contour
39 |
40 |
41 | def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:
42 | states = states.numpy()
43 | simulated_trajectories = []
44 | for i_object in range(len(object_ids)):
45 | simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(
46 | center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],
47 | center_z=states[i_object, :, 2], heading=states[i_object, :, 3],
48 | object_id=object_ids[i_object].item()
49 | ))
50 | return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)
51 |
52 |
53 | class SMART(pl.LightningModule):
54 |
55 | def __init__(self, model_config) -> None:
56 | super(SMART, self).__init__()
57 | self.save_hyperparameters()
58 | self.model_config = model_config
59 | self.warmup_steps = model_config.warmup_steps
60 | self.lr = model_config.lr
61 | self.total_steps = model_config.total_steps
62 | self.dataset = model_config.dataset
63 | self.input_dim = model_config.input_dim
64 | self.hidden_dim = model_config.hidden_dim
65 | self.output_dim = model_config.output_dim
66 | self.output_head = model_config.output_head
67 | self.num_historical_steps = model_config.num_historical_steps
68 | self.num_future_steps = model_config.decoder.num_future_steps
69 | self.num_freq_bands = model_config.num_freq_bands
70 | self.vis_map = False
71 | self.noise = True
72 | module_dir = os.path.dirname(os.path.dirname(__file__))
73 | self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
74 | self.init_map_token()
75 | self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')
76 | token_data = self.get_trajectory_token()
77 | self.encoder = SMARTDecoder(
78 | dataset=model_config.dataset,
79 | input_dim=model_config.input_dim,
80 | hidden_dim=model_config.hidden_dim,
81 | num_historical_steps=model_config.num_historical_steps,
82 | num_freq_bands=model_config.num_freq_bands,
83 | num_heads=model_config.num_heads,
84 | head_dim=model_config.head_dim,
85 | dropout=model_config.dropout,
86 | num_map_layers=model_config.decoder.num_map_layers,
87 | num_agent_layers=model_config.decoder.num_agent_layers,
88 | pl2pl_radius=model_config.decoder.pl2pl_radius,
89 | pl2a_radius=model_config.decoder.pl2a_radius,
90 | a2a_radius=model_config.decoder.a2a_radius,
91 | time_span=model_config.decoder.time_span,
92 | map_token={'traj_src': self.map_token['traj_src']},
93 | token_data=token_data,
94 | token_size=model_config.decoder.token_size
95 | )
96 | self.minADE = minADE(max_guesses=1)
97 | self.minFDE = minFDE(max_guesses=1)
98 | self.TokenCls = TokenCls(max_guesses=1)
99 |
100 | self.test_predictions = dict()
101 | self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
102 | self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
103 | self.inference_token = False
104 | self.rollout_num = 1
105 |
106 | def get_trajectory_token(self):
107 | token_data = pickle.load(open(self.token_path, 'rb'))
108 | self.trajectory_token = token_data['token']
109 | self.trajectory_token_traj = token_data['traj']
110 | self.trajectory_token_all = token_data['token_all']
111 | return token_data
112 |
113 | def init_map_token(self):
114 | self.argmin_sample_len = 3
115 | map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
116 | self.map_token = {'traj_src': map_token_traj['traj_src'], }
117 | traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
118 | self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
119 | indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
120 | self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
121 | self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
122 | self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)
123 |
124 | def forward(self, data: HeteroData):
125 | res = self.encoder(data)
126 | return res
127 |
128 | def inference(self, data: HeteroData):
129 | res = self.encoder.inference(data)
130 | return res
131 |
132 | def maybe_autocast(self, dtype=torch.float16):
133 | enable_autocast = self.device != torch.device("cpu")
134 |
135 | if enable_autocast:
136 | return torch.cuda.amp.autocast(dtype=dtype)
137 | else:
138 | return contextlib.nullcontext()
139 |
140 | def training_step(self,
141 | data,
142 | batch_idx):
143 | data = self.match_token_map(data)
144 | data = self.sample_pt_pred(data)
145 | if isinstance(data, Batch):
146 | data['agent']['av_index'] += data['agent']['ptr'][:-1]
147 | pred = self(data)
148 | next_token_prob = pred['next_token_prob']
149 | next_token_idx_gt = pred['next_token_idx_gt']
150 | next_token_eval_mask = pred['next_token_eval_mask']
151 | cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
152 | loss = cls_loss
153 | self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
154 | self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
155 | return loss
156 |
157 | def validation_step(self,
158 | data,
159 | batch_idx):
160 | data = self.match_token_map(data)
161 | data = self.sample_pt_pred(data)
162 | if isinstance(data, Batch):
163 | data['agent']['av_index'] += data['agent']['ptr'][:-1]
164 | pred = self(data)
165 | next_token_idx = pred['next_token_idx']
166 | next_token_idx_gt = pred['next_token_idx_gt']
167 | next_token_eval_mask = pred['next_token_eval_mask']
168 | next_token_prob = pred['next_token_prob']
169 | cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
170 | loss = cls_loss
171 | self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
172 | valid_mask=next_token_eval_mask[next_token_eval_mask])
173 | self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
174 | self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
175 |
176 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] # * (data['agent']['category'] == 3)
177 | if self.inference_token:
178 | pred = self.inference(data)
179 | pos_a = pred['pos_a']
180 | gt = pred['gt']
181 | valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]
182 | pred_traj = pred['pred_traj']
183 | # next_token_idx = pred['next_token_idx'][..., None]
184 | # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:]
185 | # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]
186 | # next_token_eval_mask[:, 1:] = False
187 | # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
188 | # valid_mask=next_token_eval_mask[next_token_eval_mask])
189 | # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
190 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]
191 |
192 | self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
193 | self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
194 | # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())
195 |
196 | self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
197 | self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
198 |
199 | def on_validation_start(self):
200 | self.gt = []
201 | self.pred = []
202 | self.scenario_rollouts = []
203 | self.batch_metric = defaultdict(list)
204 |
205 | def configure_optimizers(self):
206 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
207 |
208 | def lr_lambda(current_step):
209 | if current_step + 1 < self.warmup_steps:
210 | return float(current_step + 1) / float(max(1, self.warmup_steps))
211 | return max(
212 | 0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
213 | )
214 |
215 | lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
216 | return [optimizer], [lr_scheduler]
217 |
218 | def load_params_from_file(self, filename, logger, to_cpu=False):
219 | if not os.path.isfile(filename):
220 | raise FileNotFoundError
221 |
222 | logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
223 | loc_type = torch.device('cpu') if to_cpu else None
224 | checkpoint = torch.load(filename, map_location=loc_type)
225 | model_state_disk = checkpoint['state_dict']
226 |
227 | version = checkpoint.get("version", None)
228 | if version is not None:
229 | logger.info('==> Checkpoint trained from version: %s' % version)
230 |
231 | logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
232 | model_state = self.state_dict()
233 | model_state_disk_filter = {}
234 | for key, val in model_state_disk.items():
235 | if key in model_state and model_state_disk[key].shape == model_state[key].shape:
236 | model_state_disk_filter[key] = val
237 | else:
238 | if key not in model_state:
239 | print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
240 | else:
241 | print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')
242 |
243 | model_state_disk = model_state_disk_filter
244 |
245 | missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)
246 |
247 | logger.info(f'Missing keys: {missing_keys}')
248 | logger.info(f'The number of missing keys: {len(missing_keys)}')
249 | logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
250 | logger.info('==> Done (total keys %d)' % (len(model_state)))
251 |
252 | epoch = checkpoint.get('epoch', -1)
253 | it = checkpoint.get('it', 0.0)
254 |
255 | return it, epoch
256 |
257 | def match_token_map(self, data):
258 | traj_pos = data['map_save']['traj_pos'].to(torch.float)
259 | traj_theta = data['map_save']['traj_theta'].to(torch.float)
260 | pl_idx_list = data['map_save']['pl_idx_list']
261 | token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
262 | token_src = self.map_token['traj_src'].to(traj_pos.device)
263 | max_traj_len = self.map_token['traj_src'].shape[1]
264 | pl_num = traj_pos.shape[0]
265 |
266 | pt_token_pos = traj_pos[:, 0, :].clone()
267 | pt_token_orientation = traj_theta.clone()
268 | cos, sin = traj_theta.cos(), traj_theta.sin()
269 | rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
270 | rot_mat[..., 0, 0] = cos
271 | rot_mat[..., 0, 1] = -sin
272 | rot_mat[..., 1, 0] = sin
273 | rot_mat[..., 1, 1] = cos
274 | traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
275 | distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
276 | pt_token_id = torch.argmin(distance, dim=1)
277 |
278 | if self.noise:
279 | topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
280 | sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
281 | pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
282 |
283 | cos, sin = traj_theta.cos(), traj_theta.sin()
284 | rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
285 | rot_mat[..., 0, 0] = cos
286 | rot_mat[..., 0, 1] = sin
287 | rot_mat[..., 1, 0] = -sin
288 | rot_mat[..., 1, 1] = cos
289 | token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
290 | rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
291 | token_src_world_select = token_src_world.view(-1, token_src.shape[0], max_traj_len, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)
292 |
293 | pl_idx_full = pl_idx_list.clone()
294 | token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
295 | count_nums = []
296 | for pl in pl_idx_full.unique():
297 | pt = token2pl[0, token2pl[1, :] == pl]
298 | left_side = (data['pt_token']['side'][pt] == 0).sum()
299 | right_side = (data['pt_token']['side'][pt] == 1).sum()
300 | center_side = (data['pt_token']['side'][pt] == 2).sum()
301 | count_nums.append(torch.Tensor([left_side, right_side, center_side]))
302 | count_nums = torch.stack(count_nums, dim=0)
303 | num_polyline = int(count_nums.max().item())
304 | traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
305 | idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
306 | idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
307 | counts_num_expanded = count_nums.unsqueeze(-1)
308 | mask_update = idx_matrix < counts_num_expanded
309 | traj_mask[mask_update] = True
310 |
311 | data['pt_token']['traj_mask'] = traj_mask
312 | data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
313 | device=traj_pos.device, dtype=torch.float)], dim=-1)
314 | data['pt_token']['orientation'] = pt_token_orientation
315 | data['pt_token']['height'] = data['pt_token']['position'][:, -1]
316 | data[('pt_token', 'to', 'map_polygon')] = {}
317 | data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
318 | data['pt_token']['token_idx'] = pt_token_id
319 | return data
320 |
321 | def sample_pt_pred(self, data):
322 | traj_mask = data['pt_token']['traj_mask']
323 | raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
324 | masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]
325 | masked_pt_index = torch.sort(masked_pt_index, -1)[0]
326 | pt_valid_mask = traj_mask.clone()
327 | pt_valid_mask.scatter_(2, masked_pt_index, False)
328 | pt_pred_mask = traj_mask.clone()
329 | pt_pred_mask.scatter_(2, masked_pt_index, False)
330 | tmp_mask = pt_pred_mask.clone()
331 | tmp_mask[:, :, :] = True
332 | tmp_mask.scatter_(2, masked_pt_index-1, False)
333 | pt_pred_mask.masked_fill_(tmp_mask, False)
334 | pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
335 | pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
336 |
337 | data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
338 | data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
339 | data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]
340 |
341 | return data
342 |
--------------------------------------------------------------------------------
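A note on the matching above: match_token_map rotates every polyline into its own local frame (row-vector convention) and then picks the vocabulary entry with the smallest summed squared distance. The sketch below reproduces that step with toy shapes; nearest_token and the random template bank are illustrative, not part of the repo's API.

import torch

def nearest_token(traj_pos, traj_theta, token_sample_pt):
    # traj_pos: [P, L, 2] polyline points, traj_theta: [P] headings,
    # token_sample_pt: [K, L, 2] token templates in the local frame
    cos, sin = traj_theta.cos(), traj_theta.sin()
    rot = torch.stack([torch.stack([cos, -sin], -1),
                       torch.stack([sin, cos], -1)], -2)   # [P, 2, 2]
    local = torch.bmm(traj_pos - traj_pos[:, :1], rot)     # into each polyline's frame
    dist = ((token_sample_pt[None] - local[:, None]) ** 2).sum(dim=(-2, -1))  # [P, K]
    return dist.argmin(dim=1)                              # one token id per polyline

P, L, K = 4, 11, 16
ids = nearest_token(torch.randn(P, L, 2), torch.rand(P) * 6.28, torch.randn(K, L, 2))
print(ids.shape)  # torch.Size([4])

--------------------------------------------------------------------------------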
/smart/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from smart.modules.smart_decoder import SMARTDecoder
2 | from smart.modules.map_decoder import SMARTMapDecoder
3 | from smart.modules.agent_decoder import SMARTAgentDecoder
4 |
--------------------------------------------------------------------------------
/smart/modules/agent_decoder.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from typing import Dict, Mapping, Optional
3 | import torch
4 | import torch.nn as nn
5 | from smart.layers import MLPLayer
6 | from smart.layers.attention_layer import AttentionLayer
7 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
8 | from torch_cluster import radius, radius_graph
9 | from torch_geometric.data import Batch, HeteroData
10 | from torch_geometric.utils import dense_to_sparse, subgraph
11 | from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle
12 | import math
13 |
14 |
15 | def cal_polygon_contour(x, y, theta, width, length):
16 | left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
17 | left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
18 | left_front = (left_front_x, left_front_y)
19 |
20 | right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
21 | right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
22 | right_front = (right_front_x, right_front_y)
23 |
24 | right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
25 | right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
26 | right_back = (right_back_x, right_back_y)
27 |
28 | left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
29 | left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
30 | left_back = (left_back_x, left_back_y)
31 | polygon_contour = [left_front, right_front, right_back, left_back]
32 |
33 | return polygon_contour
34 |
35 |
36 | class SMARTAgentDecoder(nn.Module):
37 |
38 | def __init__(self,
39 | dataset: str,
40 | input_dim: int,
41 | hidden_dim: int,
42 | num_historical_steps: int,
43 | time_span: Optional[int],
44 | pl2a_radius: float,
45 | a2a_radius: float,
46 | num_freq_bands: int,
47 | num_layers: int,
48 | num_heads: int,
49 | head_dim: int,
50 | dropout: float,
51 | token_data: Dict,
52 | token_size=512) -> None:
53 | super(SMARTAgentDecoder, self).__init__()
54 | self.dataset = dataset
55 | self.input_dim = input_dim
56 | self.hidden_dim = hidden_dim
57 | self.num_historical_steps = num_historical_steps
58 | self.time_span = time_span if time_span is not None else num_historical_steps
59 | self.pl2a_radius = pl2a_radius
60 | self.a2a_radius = a2a_radius
61 | self.num_freq_bands = num_freq_bands
62 | self.num_layers = num_layers
63 | self.num_heads = num_heads
64 | self.head_dim = head_dim
65 | self.dropout = dropout
66 |
67 | input_dim_x_a = 2
68 | input_dim_r_t = 4
69 | input_dim_r_pt2a = 3
70 | input_dim_r_a2a = 3
71 | input_dim_token = 8
72 |
73 | self.type_a_emb = nn.Embedding(4, hidden_dim)
74 | self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
75 |
76 | self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
77 | self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
78 | self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
79 | num_freq_bands=num_freq_bands)
80 | self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
81 | num_freq_bands=num_freq_bands)
82 | self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
83 | self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
84 | self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
85 | self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim)
86 |
87 | self.t_attn_layers = nn.ModuleList(
88 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
89 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
90 | )
91 | self.pt2a_attn_layers = nn.ModuleList(
92 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
93 | bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
94 | )
95 | self.a2a_attn_layers = nn.ModuleList(
96 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
97 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
98 | )
99 | self.token_size = token_size
100 | self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
101 | output_dim=self.token_size)
102 | self.trajectory_token = token_data['token']
103 | self.trajectory_token_traj = token_data['traj']
104 | self.trajectory_token_all = token_data['token_all']
105 | self.apply(weight_init)
106 | self.shift = 5
107 | self.beam_size = 5
108 | self.hist_mask = True
109 |
110 | def transform_rel(self, token_traj, prev_pos, prev_heading=None):
111 | if prev_heading is None:
112 | diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
113 | prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
114 |
115 | num_agent, num_step, traj_num, traj_dim = token_traj.shape
116 | cos, sin = prev_heading.cos(), prev_heading.sin()
117 | rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
118 | rot_mat[:, :, 0, 0] = cos
119 | rot_mat[:, :, 0, 1] = -sin
120 | rot_mat[:, :, 1, 0] = sin
121 | rot_mat[:, :, 1, 1] = cos
122 | agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
123 | agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
124 | return agent_pred_rel
125 |
126 | def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):
127 | num_agent, num_step, traj_dim = pos_a.shape
128 | motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
129 | pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
130 |
131 | agent_type = data['agent']['type']
132 | veh_mask = (agent_type == 0)
133 | cyc_mask = (agent_type == 2)
134 | ped_mask = (agent_type == 1)
135 | trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
136 | self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))
137 | trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
138 | self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
139 | trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
140 | self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
141 |
142 | if inference:
143 | agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
144 | trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(
145 | torch.float)
146 | trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(
147 | torch.float)
148 | trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(
149 | torch.float)
150 | agent_token_traj_all[veh_mask] = torch.cat(
151 | [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
152 | agent_token_traj_all[ped_mask] = torch.cat(
153 | [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
154 | agent_token_traj_all[cyc_mask] = torch.cat(
155 | [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
156 |
157 | agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
158 | agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
159 | agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
160 | agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
161 |
162 | agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)
163 | agent_token_traj[veh_mask] = trajectory_token_veh
164 | agent_token_traj[ped_mask] = trajectory_token_ped
165 | agent_token_traj[cyc_mask] = trajectory_token_cyc
166 |
167 | vel = data['agent']['token_velocity']
168 |
169 | categorical_embs = [
170 | self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step,
171 | dim=0),
172 |
173 | self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(
174 | repeats=num_step,
175 | dim=0)
176 | ]
177 | feature_a = torch.stack(
178 | [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
179 | angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
180 | ], dim=-1)
181 |
182 | x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
183 | categorical_embs=categorical_embs)
184 | x_a = x_a.view(-1, num_step, self.hidden_dim)
185 |
186 | feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
187 | feat_a = self.fusion_emb(feat_a)
188 |
189 | if inference:
190 | return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
191 | else:
192 | return feat_a, agent_token_traj
193 |
194 | def agent_predict_next(self, data, agent_category, feat_a):
195 | num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
196 | agent_type = data['agent']['type']
197 | veh_mask = (agent_type == 0) # * agent_category==3
198 | cyc_mask = (agent_type == 2) # * agent_category==3
199 | ped_mask = (agent_type == 1) # * agent_category==3
200 | token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
201 | token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
202 |         token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])  # NOTE: not defined in __init__; forward() uses the shared token_predict_head
203 |         token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])  # NOTE: not defined in __init__
204 | return token_res
205 |
206 | def agent_predict_next_inf(self, data, agent_category, feat_a):
207 | num_agent, traj_dim = feat_a.shape
208 | agent_type = data['agent']['type']
209 |
210 | veh_mask = (agent_type == 0) # * agent_category==3
211 | cyc_mask = (agent_type == 2) # * agent_category==3
212 | ped_mask = (agent_type == 1) # * agent_category==3
213 |
214 | token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
215 | token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
216 |         token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])  # NOTE: undefined in __init__ (see agent_predict_next)
217 |         token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])  # NOTE: undefined in __init__
218 |
219 | return token_res
220 |
221 | def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
222 | pos_t = pos_a.reshape(-1, self.input_dim)
223 | head_t = head_a.reshape(-1)
224 | head_vector_t = head_vector_a.reshape(-1, 2)
225 | hist_mask = mask.clone()
226 |
227 | if self.hist_mask and self.training:
228 | hist_mask[
229 | torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
230 | mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
231 | elif inference_mask is not None:
232 | mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
233 | else:
234 | mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
235 |
236 | edge_index_t = dense_to_sparse(mask_t)[0]
237 | edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
238 | edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
239 | rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
240 | rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
241 | r_t = torch.stack(
242 | [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
243 | angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
244 | rel_head_t,
245 | edge_index_t[0] - edge_index_t[1]], dim=-1)
246 | r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
247 | return edge_index_t, r_t
248 |
249 | def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
250 | pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
251 | head_s = head_a.transpose(0, 1).reshape(-1)
252 | head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
253 | edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
254 | max_num_neighbors=300)
255 | edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
256 | rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
257 | rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
258 | r_a2a = torch.stack(
259 | [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
260 | angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
261 | rel_head_a2a], dim=-1)
262 | r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
263 | return edge_index_a2a, r_a2a
264 |
265 | def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask,
266 | batch_s, batch_pl):
267 | mask_pl2a = mask.clone()
268 | mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
269 | pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
270 | head_s = head_a.transpose(0, 1).reshape(-1)
271 | head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
272 | pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
273 | orient_pl = data['pt_token']['orientation'].contiguous()
274 | pos_pl = pos_pl.repeat(num_step, 1)
275 | orient_pl = orient_pl.repeat(num_step)
276 | edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
277 | batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
278 | edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
279 | rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
280 | rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
281 | r_pl2a = torch.stack(
282 | [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
283 | angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
284 | rel_orient_pl2a], dim=-1)
285 | r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
286 | return edge_index_pl2a, r_pl2a
287 |
288 | def forward(self,
289 | data: HeteroData,
290 | map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
291 | pos_a = data['agent']['token_pos']
292 | head_a = data['agent']['token_heading']
293 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
294 | num_agent, num_step, traj_dim = pos_a.shape
295 | agent_category = data['agent']['category']
296 | agent_token_index = data['agent']['token_idx']
297 | feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,
298 | pos_a, head_vector_a)
299 |
300 | agent_valid_mask = data['agent']['agent_valid_mask'].clone()
301 | # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
302 | # agent_valid_mask[~eval_mask] = False
303 | mask = agent_valid_mask
304 | edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)
305 |
306 | if isinstance(data, Batch):
307 | batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
308 | for t in range(num_step)], dim=0)
309 | batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
310 | for t in range(num_step)], dim=0)
311 | else:
312 | batch_s = torch.arange(num_step,
313 | device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
314 | batch_pl = torch.arange(num_step,
315 | device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
316 |
317 | mask_s = mask.transpose(0, 1).reshape(-1)
318 | edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)
319 | mask[agent_category != 3] = False
320 | edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
321 | head_vector_a, mask, batch_s, batch_pl)
322 |
323 | for i in range(self.num_layers):
324 | feat_a = feat_a.reshape(-1, self.hidden_dim)
325 | feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
326 | feat_a = feat_a.reshape(-1, num_step,
327 | self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
328 | feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
329 | repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
330 | -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
331 | feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
332 | feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
333 |
334 | num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape
335 | next_token_prob = self.token_predict_head(feat_a)
336 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
337 | _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
338 |
339 | next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
340 | next_token_eval_mask = mask.clone()
341 | next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
342 | next_token_eval_mask[:, -1] = False
343 |
344 | return {'x_a': feat_a,
345 | 'next_token_idx': next_token_idx,
346 | 'next_token_prob': next_token_prob,
347 | 'next_token_idx_gt': next_token_index_gt,
348 | 'next_token_eval_mask': next_token_eval_mask,
349 | }
350 |
351 | def inference(self,
352 | data: HeteroData,
353 | map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
354 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
355 | pos_a = data['agent']['token_pos'].clone()
356 | head_a = data['agent']['token_heading'].clone()
357 | num_agent, num_step, traj_dim = pos_a.shape
358 | pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
359 | head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
360 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
361 |
362 | agent_valid_mask = data['agent']['agent_valid_mask'].clone()
363 | agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
364 | agent_valid_mask[~eval_mask] = False
365 | agent_token_index = data['agent']['token_idx']
366 | agent_category = data['agent']['category']
367 | feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(
368 | data,
369 | agent_category,
370 | agent_token_index,
371 | pos_a,
372 | head_vector_a,
373 | inference=True)
374 |
375 | agent_type = data["agent"]["type"]
376 | veh_mask = (agent_type == 0) # * agent_category==3
377 | cyc_mask = (agent_type == 2) # * agent_category==3
378 | ped_mask = (agent_type == 1) # * agent_category==3
379 | av_mask = data["agent"]["av_index"]
380 |
381 | self.num_recurrent_steps_val = data["agent"]['position'].shape[1]-self.num_historical_steps
382 | pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)
383 | pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)
384 | pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device)
385 | next_token_idx_list = []
386 | mask = agent_valid_mask.clone()
387 | feat_a_t_dict = {}
388 | for t in range(self.num_recurrent_steps_val // self.shift):
389 | if t == 0:
390 | inference_mask = mask.clone()
391 | inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
392 | else:
393 | inference_mask = torch.zeros_like(mask)
394 | inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
395 | edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask)
396 | if isinstance(data, Batch):
397 | batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
398 | for t in range(num_step)], dim=0)
399 | batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
400 | for t in range(num_step)], dim=0)
401 | else:
402 | batch_s = torch.arange(num_step,
403 | device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
404 | batch_pl = torch.arange(num_step,
405 | device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
406 |             # During inference, only the current step is recomputed at each recurrent iteration
407 | edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
408 | head_vector_a,
409 | inference_mask, batch_s,
410 | batch_pl)
411 | mask_s = inference_mask.transpose(0, 1).reshape(-1)
412 | edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a,
413 | batch_s, mask_s)
414 |
415 | for i in range(self.num_layers):
416 | if i in feat_a_t_dict:
417 | feat_a = feat_a_t_dict[i]
418 | feat_a = feat_a.reshape(-1, self.hidden_dim)
419 | feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
420 | feat_a = feat_a.reshape(-1, num_step,
421 | self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
422 | feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
423 | repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
424 | -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
425 | feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
426 | feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
427 |
428 | if i+1 not in feat_a_t_dict:
429 | feat_a_t_dict[i+1] = feat_a
430 | else:
431 | feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
432 |
433 | next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
434 |
435 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
436 |
437 | topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)
438 |
439 |             expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)  # 6 == self.shift + 1 steps per token
440 | next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)
441 |
442 | theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
443 | cos, sin = theta.cos(), theta.sin()
444 | rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
445 | rot_mat[:, 0, 0] = cos
446 | rot_mat[:, 0, 1] = sin
447 | rot_mat[:, 1, 0] = -sin
448 | rot_mat[:, 1, 1] = cos
449 | agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
450 | rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
451 | -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
452 | agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]
453 |
454 | sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)
455 | agent_pred_rel = agent_pred_rel.gather(dim=1,
456 | index=sample_index[..., None, None, None].expand(-1, -1, 6, 4,
457 | 2))[:, 0, ...]
458 | pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]
459 | pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
460 | diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
461 | pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
462 |
463 | pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
464 | diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
465 | theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
466 | head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
467 | next_token_idx = next_token_idx.gather(dim=1, index=sample_index)
468 | next_token_idx = next_token_idx.squeeze(-1)
469 | next_token_idx_list.append(next_token_idx[:, None])
470 | agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
471 | next_token_idx[veh_mask]]
472 | agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
473 | next_token_idx[ped_mask]]
474 | agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
475 | next_token_idx[cyc_mask]]
476 | motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
477 | pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
478 |
479 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
480 |
481 | vel = motion_vector_a.clone() / (0.1 * self.shift)
482 | vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
483 | motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
484 | x_a = torch.stack(
485 | [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
486 | angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)
487 |
488 | x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
489 | categorical_embs=categorical_embs)
490 | x_a = x_a.view(-1, num_step, self.hidden_dim)
491 |
492 | feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
493 | feat_a = self.fusion_emb(feat_a)
494 |
495 | agent_valid_mask[agent_category != 3] = False
496 |
497 | return {
498 | 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
499 | 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
500 | 'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),
501 | 'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
502 | 'pred_traj': pred_traj,
503 | 'pred_head': pred_head,
504 | 'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
505 | 'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
506 | 'next_token_eval_mask': data['agent']['agent_valid_mask'],
507 | 'pred_prob': pred_prob,
508 | 'vel': vel
509 | }
510 |
--------------------------------------------------------------------------------
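Both forward and inference above map token trajectories from the agent's local frame into world coordinates with a rot_mat/bmm pattern. A self-contained sketch of that local-to-world step; token_to_world is an illustrative name, not a repo function.

import torch

def token_to_world(token_xy, pos, theta):
    # token_xy: [N, T, 2] token trajectory in the agent frame,
    # pos: [N, 2] current position, theta: [N] current heading
    cos, sin = theta.cos(), theta.sin()
    rot = theta.new_zeros(theta.shape[0], 2, 2)
    rot[:, 0, 0], rot[:, 0, 1] = cos, sin
    rot[:, 1, 0], rot[:, 1, 1] = -sin, cos
    # rotate by +theta (row-vector convention, as in inference), then translate
    return torch.bmm(token_xy, rot) + pos[:, None, :]

out = token_to_world(torch.randn(3, 6, 2), torch.zeros(3, 2), torch.rand(3))
print(out.shape)  # torch.Size([3, 6, 2])

--------------------------------------------------------------------------------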
/smart/modules/map_decoder.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from typing import Dict
3 | import torch
4 | import torch.nn as nn
5 | from torch_cluster import radius_graph
6 | from torch_geometric.data import Batch
7 | from torch_geometric.data import HeteroData
8 | from torch_geometric.utils import dense_to_sparse, subgraph
9 | from smart.utils.nan_checker import check_nan_inf
10 | from smart.layers.attention_layer import AttentionLayer
11 | from smart.layers import MLPLayer
12 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
13 | from smart.utils import angle_between_2d_vectors
14 | from smart.utils import merge_edges
15 | from smart.utils import weight_init
16 | from smart.utils import wrap_angle
17 | import pickle
18 |
19 |
20 | class SMARTMapDecoder(nn.Module):
21 |
22 | def __init__(self,
23 | dataset: str,
24 | input_dim: int,
25 | hidden_dim: int,
26 | num_historical_steps: int,
27 | pl2pl_radius: float,
28 | num_freq_bands: int,
29 | num_layers: int,
30 | num_heads: int,
31 | head_dim: int,
32 | dropout: float,
33 | map_token) -> None:
34 | super(SMARTMapDecoder, self).__init__()
35 | self.dataset = dataset
36 | self.input_dim = input_dim
37 | self.hidden_dim = hidden_dim
38 | self.num_historical_steps = num_historical_steps
39 | self.pl2pl_radius = pl2pl_radius
40 | self.num_freq_bands = num_freq_bands
41 | self.num_layers = num_layers
42 | self.num_heads = num_heads
43 | self.head_dim = head_dim
44 | self.dropout = dropout
45 |
46 | if input_dim == 2:
47 | input_dim_r_pt2pt = 3
48 | elif input_dim == 3:
49 | input_dim_r_pt2pt = 4
50 | else:
51 | raise ValueError('{} is not a valid dimension'.format(input_dim))
52 |
53 | self.type_pt_emb = nn.Embedding(17, hidden_dim)
54 | self.side_pt_emb = nn.Embedding(4, hidden_dim)
55 | self.polygon_type_emb = nn.Embedding(4, hidden_dim)
56 | self.light_pl_emb = nn.Embedding(4, hidden_dim)
57 |
58 | self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
59 | num_freq_bands=num_freq_bands)
60 | self.pt2pt_layers = nn.ModuleList(
61 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
62 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
63 | )
64 | self.token_size = 1024
65 | self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
66 | output_dim=self.token_size)
67 | input_dim_token = 22
68 | self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
69 | self.map_token = map_token
70 | self.apply(weight_init)
71 | self.mask_pt = False
72 |
73 | def maybe_autocast(self, dtype=torch.float32):
74 | return torch.cuda.amp.autocast(dtype=dtype)
75 |
76 | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
77 | pt_valid_mask = data['pt_token']['pt_valid_mask']
78 | pt_pred_mask = data['pt_token']['pt_pred_mask']
79 | pt_target_mask = data['pt_token']['pt_target_mask']
80 | mask_s = pt_valid_mask
81 |
82 | pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
83 | orient_pt = data['pt_token']['orientation'].contiguous()
84 | orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
85 | token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
86 | pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
87 | pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]
88 |
89 | if self.input_dim == 2:
90 | x_pt = pt_token_emb
91 | elif self.input_dim == 3:
92 |             x_pt = pt_token_emb  # identical to the 2-D branch; height is not used in the token embedding
93 | else:
94 | raise ValueError('{} is not a valid dimension'.format(self.input_dim))
95 |
96 | token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
97 | token_light_type = data['map_polygon']['light_type'][token2pl[1]]
98 | x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
99 | self.polygon_type_emb(data['pt_token']['pl_type'].long()),
100 | self.light_pl_emb(token_light_type.long()),]
101 | x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
102 | edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
103 | batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
104 | loop=False, max_num_neighbors=100)
105 | if self.mask_pt:
106 | edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
107 | rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
108 | rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
109 | if self.input_dim == 2:
110 | r_pt2pt = torch.stack(
111 | [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
112 | angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
113 | nbr_vector=rel_pos_pt2pt[:, :2]),
114 | rel_orient_pt2pt], dim=-1)
115 | elif self.input_dim == 3:
116 | r_pt2pt = torch.stack(
117 | [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
118 | angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
119 | nbr_vector=rel_pos_pt2pt[:, :2]),
120 | rel_pos_pt2pt[:, -1],
121 | rel_orient_pt2pt], dim=-1)
122 | else:
123 | raise ValueError('{} is not a valid dimension'.format(self.input_dim))
124 | r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
125 | for i in range(self.num_layers):
126 | x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
127 |
128 | next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
129 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
130 | _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
131 | next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
132 |
133 | return {
134 | 'x_pt': x_pt,
135 | 'map_next_token_idx': next_token_idx,
136 | 'map_next_token_prob': next_token_prob,
137 | 'map_next_token_idx_gt': next_token_index_gt,
138 | 'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
139 | }
140 |
--------------------------------------------------------------------------------
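The map decoder returns logits over its 1024-entry token vocabulary ('map_next_token_prob') alongside the ground-truth ids for the masked points ('map_next_token_idx_gt'). A plausible consumer is plain cross-entropy; the tensors below are random stand-ins, not the repo's actual loss code.

import torch
import torch.nn.functional as F

logits = torch.randn(32, 1024)          # stand-in for 'map_next_token_prob'
target = torch.randint(0, 1024, (32,))  # stand-in for 'map_next_token_idx_gt'
loss = F.cross_entropy(logits, target)  # next-map-token classification loss
print(float(loss))

--------------------------------------------------------------------------------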
/smart/modules/smart_decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | import torch
3 | import torch.nn as nn
4 | from torch_geometric.data import HeteroData
5 | from smart.modules.agent_decoder import SMARTAgentDecoder
6 | from smart.modules.map_decoder import SMARTMapDecoder
7 |
8 |
9 | class SMARTDecoder(nn.Module):
10 |
11 | def __init__(self,
12 | dataset: str,
13 | input_dim: int,
14 | hidden_dim: int,
15 | num_historical_steps: int,
16 | pl2pl_radius: float,
17 | time_span: Optional[int],
18 | pl2a_radius: float,
19 | a2a_radius: float,
20 | num_freq_bands: int,
21 | num_map_layers: int,
22 | num_agent_layers: int,
23 | num_heads: int,
24 | head_dim: int,
25 | dropout: float,
26 | map_token: Dict,
27 | token_data: Dict,
28 | use_intention=False,
29 | token_size=512) -> None:
30 | super(SMARTDecoder, self).__init__()
31 | self.map_encoder = SMARTMapDecoder(
32 | dataset=dataset,
33 | input_dim=input_dim,
34 | hidden_dim=hidden_dim,
35 | num_historical_steps=num_historical_steps,
36 | pl2pl_radius=pl2pl_radius,
37 | num_freq_bands=num_freq_bands,
38 | num_layers=num_map_layers,
39 | num_heads=num_heads,
40 | head_dim=head_dim,
41 | dropout=dropout,
42 | map_token=map_token
43 | )
44 | self.agent_encoder = SMARTAgentDecoder(
45 | dataset=dataset,
46 | input_dim=input_dim,
47 | hidden_dim=hidden_dim,
48 | num_historical_steps=num_historical_steps,
49 | time_span=time_span,
50 | pl2a_radius=pl2a_radius,
51 | a2a_radius=a2a_radius,
52 | num_freq_bands=num_freq_bands,
53 | num_layers=num_agent_layers,
54 | num_heads=num_heads,
55 | head_dim=head_dim,
56 | dropout=dropout,
57 | token_size=token_size,
58 | token_data=token_data
59 | )
60 | self.map_enc = None
61 |
62 | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
63 | map_enc = self.map_encoder(data)
64 | agent_enc = self.agent_encoder(data, map_enc)
65 | return {**map_enc, **agent_enc}
66 |
67 | def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
68 | map_enc = self.map_encoder(data)
69 | agent_enc = self.agent_encoder.inference(data, map_enc)
70 | return {**map_enc, **agent_enc}
71 |
72 | def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
73 | agent_enc = self.agent_encoder.inference(data, map_enc)
74 | return {**map_enc, **agent_enc}
75 |
--------------------------------------------------------------------------------
/smart/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/preprocess/__init__.py
--------------------------------------------------------------------------------
/smart/preprocess/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import torch
5 | from typing import Any, Dict, List, Optional
6 |
7 | predict_unseen_agents = False
8 | vector_repr = True
9 | _agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
10 | _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
11 | _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
12 | _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
13 | 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
14 | 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
15 | 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
16 | _point_sides = ['LEFT', 'RIGHT', 'CENTER']
17 | _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
18 | _polygon_is_intersections = [True, False, None]
19 |
20 |
21 | Lane_type_hash = {
22 | 4: "BIKE",
23 | 3: "VEHICLE",
24 | 2: "VEHICLE",
25 | 1: "BUS"
26 | }
27 |
28 | boundary_type_hash = {
29 | 5: "UNKNOWN",
30 | 6: "DASHED_WHITE",
31 | 7: "SOLID_WHITE",
32 | 8: "DOUBLE_DASH_WHITE",
33 | 9: "DASHED_YELLOW",
34 | 10: "DOUBLE_DASH_YELLOW",
35 | 11: "SOLID_YELLOW",
36 | 12: "DOUBLE_SOLID_YELLOW",
37 | 13: "DASH_SOLID_YELLOW",
38 | 14: "UNKNOWN",
39 | 15: "EDGE",
40 | 16: "EDGE"
41 | }
42 |
43 |
44 | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
45 | if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps
46 | historical_df = df[df['timestep'] == num_historical_steps-1]
47 | agent_ids = list(historical_df['track_id'].unique())
48 | df = df[df['track_id'].isin(agent_ids)]
49 | else:
50 | agent_ids = list(df['track_id'].unique())
51 |
52 | num_agents = len(agent_ids)
53 | # initialization
54 | valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
55 | current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
56 | predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
57 | agent_id: List[Optional[str]] = [None] * num_agents
58 | agent_type = torch.zeros(num_agents, dtype=torch.uint8)
59 | agent_category = torch.zeros(num_agents, dtype=torch.uint8)
60 | position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
61 | heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
62 | velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
63 | shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
64 |
65 | for track_id, track_df in df.groupby('track_id'):
66 | agent_idx = agent_ids.index(track_id)
67 | agent_steps = track_df['timestep'].values
68 |
69 | valid_mask[agent_idx, agent_steps] = True
70 | current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
71 | predict_mask[agent_idx, agent_steps] = True
72 | if vector_repr: # a time step t is valid only when both t and t-1 are valid
73 | valid_mask[agent_idx, 1: num_historical_steps] = (
74 | valid_mask[agent_idx, :num_historical_steps - 1] &
75 | valid_mask[agent_idx, 1: num_historical_steps])
76 | valid_mask[agent_idx, 0] = False
77 | predict_mask[agent_idx, :num_historical_steps] = False
78 | if not current_valid_mask[agent_idx]:
79 | predict_mask[agent_idx, num_historical_steps:] = False
80 |
81 | agent_id[agent_idx] = track_id
82 | agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
83 | agent_category[agent_idx] = track_df['object_category'].values[0]
84 | position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
85 | track_df['position_y'].values,
86 | track_df['position_z'].values],
87 | axis=-1)).float()
88 | heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
89 | velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
90 | track_df['velocity_y'].values],
91 | axis=-1)).float()
92 | shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
93 | track_df['width'].values,
94 | track_df["height"].values],
95 | axis=-1)).float()
96 | av_idx = agent_id.index(av_id)
97 |
98 | return {
99 | 'num_nodes': num_agents,
100 | 'av_index': av_idx,
101 | 'valid_mask': valid_mask,
102 | 'predict_mask': predict_mask,
103 | 'id': agent_id,
104 | 'type': agent_type,
105 | 'category': agent_category,
106 | 'position': position,
107 | 'heading': heading,
108 | 'velocity': velocity,
109 | 'shape': shape
110 | }
--------------------------------------------------------------------------------
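get_agent_features expects one dataframe row per (track_id, timestep) with the columns referenced above. A toy frame for shape-checking; all values are placeholders (real inputs come from the Waymo scenario files).

import numpy as np
import pandas as pd

num_steps = 91
df = pd.DataFrame({
    'track_id': ['AV'] * num_steps,
    'timestep': np.arange(num_steps),
    'object_type': ['vehicle'] * num_steps,
    'object_category': [3] * num_steps,
    'position_x': np.linspace(0.0, 90.0, num_steps),
    'position_y': np.zeros(num_steps),
    'position_z': np.zeros(num_steps),
    'heading': np.zeros(num_steps),
    'velocity_x': np.ones(num_steps),
    'velocity_y': np.zeros(num_steps),
    'length': np.full(num_steps, 4.5),
    'width': np.full(num_steps, 2.0),
    'height': np.full(num_steps, 1.6),
})
# features = get_agent_features(df, av_id='AV')  # -> dict with masks, positions, etc.

--------------------------------------------------------------------------------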
/smart/tokens/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/__init__.py
--------------------------------------------------------------------------------
/smart/tokens/cluster_frame_5_2048.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/cluster_frame_5_2048.pkl
--------------------------------------------------------------------------------
/smart/tokens/map_traj_token5.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/map_traj_token5.pkl
--------------------------------------------------------------------------------
/smart/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from smart.transforms.target_builder import WaymoTargetBuilder
2 |
--------------------------------------------------------------------------------
/smart/transforms/target_builder.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from torch_geometric.data import HeteroData
5 | from torch_geometric.transforms import BaseTransform
6 | from smart.utils import wrap_angle
7 | from smart.utils.log import Logging
8 |
9 |
10 | def to_16(data):
11 | if isinstance(data, dict):
12 | for key, value in data.items():
13 | new_value = to_16(value)
14 | data[key] = new_value
15 | if isinstance(data, torch.Tensor):
16 | if data.dtype == torch.float32:
17 | data = data.to(torch.float16)
18 | return data
19 |
20 |
21 | def tofloat32(data):
22 | for name in data:
23 | value = data[name]
24 | if isinstance(value, dict):
25 | value = tofloat32(value)
26 | elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:
27 | value = value.to(torch.float32)
28 | data[name] = value
29 | return data
30 |
31 |
32 | class WaymoTargetBuilder(BaseTransform):
33 |
34 | def __init__(self,
35 | num_historical_steps: int,
36 | num_future_steps: int,
37 | mode="train") -> None:
38 | self.num_historical_steps = num_historical_steps
39 | self.num_future_steps = num_future_steps
40 | self.mode = mode
41 | self.num_features = 3
42 | self.augment = False
43 | self.logger = Logging().log(level='DEBUG')
44 |
45 | def score_ego_agent(self, agent):
46 | av_index = agent['av_index']
47 | agent["category"][av_index] = 5
48 | return agent
49 |
50 | def clip(self, agent, max_num=32):
51 | av_index = agent["av_index"]
52 | valid = agent['valid_mask']
53 | ego_pos = agent["position"][av_index]
54 | obstacle_mask = agent['type'] == 3
55 |         distance = torch.norm(agent["position"][:, self.num_historical_steps-1, :2] - ego_pos[self.num_historical_steps-1, :2], dim=-1)  # keep only the max_num agents closest to the ego vehicle
56 | distance[obstacle_mask] = 10e5
57 | sort_idx = distance.sort()[1]
58 | mask = torch.zeros(valid.shape[0])
59 | mask[sort_idx[:max_num]] = 1
60 | mask = mask.to(torch.bool)
61 | mask[av_index] = True
62 | new_av_index = mask[:av_index].sum()
63 | agent["num_nodes"] = int(mask.sum())
64 | agent["av_index"] = int(new_av_index)
65 | excluded = ["num_nodes", "av_index", "ego"]
66 | for key, val in agent.items():
67 | if key in excluded:
68 | continue
69 | if key == "id":
70 | val = list(np.array(val)[mask])
71 | agent[key] = val
72 | continue
73 | if len(val.size()) > 1:
74 | agent[key] = val[mask, ...]
75 | else:
76 | agent[key] = val[mask]
77 | return agent
78 |
79 | def score_nearby_vehicle(self, agent, max_num=10):
80 | av_index = agent['av_index']
81 | agent["category"] = torch.zeros_like(agent["category"])
82 | obstacle_mask = agent['type'] == 3
83 | pos = agent["position"][av_index, self.num_historical_steps, :2]
84 | distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
85 | distance[obstacle_mask] = 10e5
86 | sort_idx = distance.sort()[1]
87 | nearby_mask = torch.zeros(distance.shape[0])
88 | nearby_mask[sort_idx[1:max_num]] = 1
89 | nearby_mask = nearby_mask.bool()
90 | agent["category"][nearby_mask] = 3
91 | agent["category"][obstacle_mask] = 0
92 |
93 | def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
94 | av_index = agent['av_index']
95 | agent["category"] = torch.zeros_like(agent["category"])
96 | pos = agent["position"][av_index, self.num_historical_steps, :2]
97 | distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
98 | distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
99 |         within_range = distance_all_time < 150  # perception beyond 150 m of the ego is considered unreliable
100 |         agent["valid_mask"] = agent["valid_mask"] * within_range
101 |         # we do not predict vehicles too far away from the ego car
102 |         close_vehicle = distance < 100
103 |         valid = agent['valid_mask']
104 |         valid_current = valid[:, (self.num_historical_steps):]
105 |         valid_counts = valid_current.sum(1)
106 |         counts_vehicle = valid_counts >= 1
107 |         no_background = agent['type'] != 3
108 |         vehicle2pred = close_vehicle & counts_vehicle & no_background
109 | if vehicle2pred.sum() > max_num:
110 |             # too many (mostly stationary) candidates: randomly keep max_num so moving vehicles are trained on as much as possible
111 | true_indices = torch.nonzero(vehicle2pred).squeeze(1)
112 | selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
113 | vehicle2pred.fill_(False)
114 | vehicle2pred[selected_indices] = True
115 | agent["category"][vehicle2pred] = 3
116 |
117 | def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
118 | origin = position[:, num_historical_steps - 1]
119 | theta = heading[:, num_historical_steps - 1]
120 | cos, sin = theta.cos(), theta.sin()
121 | rot_mat = theta.new_zeros(num_nodes, 2, 2)
122 | rot_mat[:, 0, 0] = cos
123 | rot_mat[:, 0, 1] = -sin
124 | rot_mat[:, 1, 0] = sin
125 | rot_mat[:, 1, 1] = cos
126 | target = origin.new_zeros(num_nodes, num_future_steps, 4)
127 | target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
128 | origin[:, :2].unsqueeze(1), rot_mat)
129 | his = origin.new_zeros(num_nodes, num_historical_steps, 4)
130 | his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
131 | origin[:, :2].unsqueeze(1), rot_mat)
132 | if position.size(2) == 3:
133 | target[..., 2] = (position[:, num_historical_steps:, 2] -
134 | origin[:, 2].unsqueeze(-1))
135 | his[..., 2] = (position[:, :num_historical_steps, 2] -
136 | origin[:, 2].unsqueeze(-1))
137 | target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
138 | theta.unsqueeze(-1))
139 | his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
140 | theta.unsqueeze(-1))
141 | else:
142 | target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
143 | theta.unsqueeze(-1))
144 | his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
145 | theta.unsqueeze(-1))
146 | return his, target
147 |
148 | def __call__(self, data) -> HeteroData:
149 | agent = data["agent"]
150 | self.score_ego_agent(agent)
151 | self.score_trained_vehicle(agent, max_num=32)
152 | return HeteroData(data)
153 |
--------------------------------------------------------------------------------
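rotate_agents re-expresses trajectories in each agent's frame at the last historical step: subtract the origin, then rotate by minus the heading using the same row-vector convention as above. A worked numeric check with toy tensors:

import math
import torch

theta = torch.tensor([math.pi / 2])                 # heading: 90 degrees
origin = torch.tensor([[1.0, 1.0]])
future = torch.tensor([[[1.0, 2.0], [1.0, 3.0]]])   # [N=1, T=2, 2]
cos, sin = theta.cos(), theta.sin()
rot = torch.stack([torch.stack([cos, -sin], -1),
                   torch.stack([sin, cos], -1)], -2)  # [1, 2, 2]
local = torch.bmm(future - origin[:, None], rot)
print(local)  # ~[[[1., 0.], [2., 0.]]]: the motion lies along the local +x axis

--------------------------------------------------------------------------------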
/smart/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from smart.utils.geometry import angle_between_2d_vectors
3 | from smart.utils.geometry import angle_between_3d_vectors
4 | from smart.utils.geometry import side_to_directed_lineseg
5 | from smart.utils.geometry import wrap_angle
6 | from smart.utils.graph import add_edges
7 | from smart.utils.graph import bipartite_dense_to_sparse
8 | from smart.utils.graph import complete_graph
9 | from smart.utils.graph import merge_edges
10 | from smart.utils.graph import unbatch
11 | from smart.utils.list import safe_list_index
12 | from smart.utils.weight_init import weight_init
13 |
--------------------------------------------------------------------------------
/smart/utils/cluster_reader.py:
--------------------------------------------------------------------------------
1 | import io
2 | import pickle
3 | import pandas as pd
4 | import json
5 |
6 |
7 | class LoadScenarioFromCeph:
8 | def __init__(self):
9 | from petrel_client.client import Client
10 | self.file_client = Client('~/petreloss.conf')
11 |
12 | def list(self, dir_path):
13 | return list(self.file_client.list(dir_path))
14 |
15 | def save(self, data, url):
16 | self.file_client.put(url, pickle.dumps(data))
17 |
18 | def read_correct_csv(self, scenario_path):
19 | output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
20 | return output
21 |
22 | def contains(self, url):
23 | return self.file_client.contains(url)
24 |
25 | def read_string(self, csv_url):
26 | from io import StringIO
27 |         df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep=r'\s+', low_memory=False)
28 | return df
29 |
30 | def read(self, scenario_path):
31 | with io.BytesIO(self.file_client.get(scenario_path)) as f:
32 | datas = pickle.load(f)
33 | return datas
34 |
35 | def read_json(self, path):
36 | with io.BytesIO(self.file_client.get(path)) as f:
37 | data = json.load(f)
38 | return data
39 |
40 | def read_csv(self, scenario_path):
41 | return pickle.loads(self.file_client.get(scenario_path))
42 |
43 | def read_model(self, model_path):
44 | with io.BytesIO(self.file_client.get(model_path)) as f:
45 |             pass  # placeholder: model loading from Ceph is not implemented
46 |
--------------------------------------------------------------------------------
/smart/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import easydict
4 |
5 |
6 | def load_config_act(path):
7 | """ load config file"""
8 | with open(path, 'r') as f:
9 | cfg = yaml.load(f, Loader=yaml.FullLoader)
10 | return easydict.EasyDict(cfg)
11 |
12 |
13 | def load_config_init(path):
14 | """ load config file"""
15 | path = os.path.join('init/configs', f'{path}.yaml')
16 | with open(path, 'r') as f:
17 | cfg = yaml.load(f, Loader=yaml.FullLoader)
18 | return cfg
19 |
--------------------------------------------------------------------------------
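load_config_act wraps the parsed YAML in an EasyDict, so nested keys read as attributes. The same wrapping inline (the YAML string is a placeholder; real configs live under configs/):

import easydict
import yaml

cfg = easydict.EasyDict(yaml.load("Model:\n  hidden_dim: 128\n", Loader=yaml.FullLoader))
print(cfg.Model.hidden_dim)  # 128

--------------------------------------------------------------------------------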
/smart/utils/geometry.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 |
4 | import torch
5 |
6 |
7 | def angle_between_2d_vectors(
8 | ctr_vector: torch.Tensor,
9 | nbr_vector: torch.Tensor) -> torch.Tensor:
10 | return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],
11 | (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1))
12 |
13 |
14 | def angle_between_3d_vectors(
15 | ctr_vector: torch.Tensor,
16 | nbr_vector: torch.Tensor) -> torch.Tensor:
17 | return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1),
18 | (ctr_vector * nbr_vector).sum(dim=-1))
19 |
20 |
21 | def side_to_directed_lineseg(
22 | query_point: torch.Tensor,
23 | start_point: torch.Tensor,
24 | end_point: torch.Tensor) -> str:
25 | cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -
26 | (end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))
27 | if cond > 0:
28 | return 'LEFT'
29 | elif cond < 0:
30 | return 'RIGHT'
31 | else:
32 | return 'CENTER'
33 |
34 |
35 | def wrap_angle(
36 | angle: torch.Tensor,
37 | min_val: float = -math.pi,
38 | max_val: float = math.pi) -> torch.Tensor:
39 | return min_val + (angle + max_val) % (max_val - min_val)
40 |
--------------------------------------------------------------------------------
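Quick numeric checks for the helpers above (assuming the repo root is on PYTHONPATH):

import math
import torch
from smart.utils.geometry import angle_between_2d_vectors, wrap_angle

print(wrap_angle(torch.tensor(3 * math.pi / 2)))  # ~ -pi/2: folded into [-pi, pi)
print(angle_between_2d_vectors(ctr_vector=torch.tensor([1.0, 0.0]),
                               nbr_vector=torch.tensor([0.0, 1.0])))  # ~ pi/2 (CCW positive)

--------------------------------------------------------------------------------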
/smart/utils/graph.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List, Optional, Tuple, Union
3 |
4 | import torch
5 | from torch_geometric.utils import coalesce
6 | from torch_geometric.utils import degree
7 |
8 |
9 | def add_edges(
10 | from_edge_index: torch.Tensor,
11 | to_edge_index: torch.Tensor,
12 | from_edge_attr: Optional[torch.Tensor] = None,
13 | to_edge_attr: Optional[torch.Tensor] = None,
14 | replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
15 | from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)
16 | mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) &
17 | (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)))
18 | if replace:
19 | to_mask = mask.any(dim=1)
20 | if from_edge_attr is not None and to_edge_attr is not None:
21 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
22 | to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0)
23 | to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1)
24 | else:
25 | from_mask = mask.any(dim=0)
26 | if from_edge_attr is not None and to_edge_attr is not None:
27 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
28 | to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0)
29 | to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1)
30 | return to_edge_index, to_edge_attr
31 |
32 |
33 | def merge_edges(
34 | edge_indices: List[torch.Tensor],
35 | edge_attrs: Optional[List[torch.Tensor]] = None,
36 | reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
37 | edge_index = torch.cat(edge_indices, dim=1)
38 | if edge_attrs is not None:
39 | edge_attr = torch.cat(edge_attrs, dim=0)
40 | else:
41 | edge_attr = None
42 | return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce)
43 |
44 |
45 | def complete_graph(
46 | num_nodes: Union[int, Tuple[int, int]],
47 | ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
48 | loop: bool = False,
49 | device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
50 | if ptr is None:
51 | if isinstance(num_nodes, int):
52 | num_src, num_dst = num_nodes, num_nodes
53 | else:
54 | num_src, num_dst = num_nodes
55 | edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
56 | torch.arange(num_dst, dtype=torch.long, device=device)).t()
57 | else:
58 | if isinstance(ptr, torch.Tensor):
59 | ptr_src, ptr_dst = ptr, ptr
60 | num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]
61 | else:
62 | ptr_src, ptr_dst = ptr
63 | num_src_batch = ptr_src[1:] - ptr_src[:-1]
64 | num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]
65 | edge_index = torch.cat(
66 | [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),
67 | torch.arange(num_dst, dtype=torch.long, device=device)) + p
68 | for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))],
69 | dim=0)
70 | edge_index = edge_index.t()
71 | if isinstance(num_nodes, int) and not loop:
72 | edge_index = edge_index[:, edge_index[0] != edge_index[1]]
73 | return edge_index.contiguous()
74 |
75 |
76 | def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
77 | index = adj.nonzero(as_tuple=True)
78 | if len(index) == 3:
79 | batch_src = index[0] * adj.size(1)
80 | batch_dst = index[0] * adj.size(2)
81 | index = (batch_src + index[1], batch_dst + index[2])
82 | return torch.stack(index, dim=0)
83 |
84 |
85 | def unbatch(
86 | src: torch.Tensor,
87 | batch: torch.Tensor,
88 | dim: int = 0) -> List[torch.Tensor]:
89 | sizes = degree(batch, dtype=torch.long).tolist()
90 |     return list(src.split(sizes, dim))
91 |
--------------------------------------------------------------------------------
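A minimal sketch of the most-used helpers above (tensor shapes are illustrative only):

    import torch

    from smart.utils.graph import bipartite_dense_to_sparse, complete_graph, unbatch

    # Fully connected graph over 3 nodes; self-loops are dropped by default.
    edge_index = complete_graph(num_nodes=3)
    print(edge_index.size())  # torch.Size([2, 6])

    # Dense [batch, src, dst] adjacency -> sparse indices with per-sample offsets.
    adj = torch.zeros(2, 2, 3, dtype=torch.bool)
    adj[0, 0, 1] = True
    adj[1, 1, 2] = True
    print(bipartite_dense_to_sparse(adj))  # tensor([[0, 3], [1, 5]])

    # Split a batched node-feature tensor back into per-graph chunks.
    x = torch.randn(5, 4)
    batch = torch.tensor([0, 0, 0, 1, 1])
    print([chunk.size(0) for chunk in unbatch(x, batch)])  # [3, 2]
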
/smart/utils/list.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import Any, List, Optional
3 |
4 |
5 | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
6 | try:
7 | return ls.index(elem)
8 | except ValueError:
9 | return None
10 |
--------------------------------------------------------------------------------
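For reference, the helper simply converts the ValueError raised by list.index into None (the element names below are illustrative):

    from smart.utils.list import safe_list_index

    assert safe_list_index(['veh', 'ped', 'cyc'], 'ped') == 1
    assert safe_list_index(['veh', 'ped', 'cyc'], 'bus') is None
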
/smart/utils/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import os
4 |
5 |
6 | class Logging:
7 |
8 | def make_log_dir(self, dirname='logs'):
9 | now_dir = os.path.dirname(__file__)
10 | path = os.path.join(now_dir, dirname)
11 | path = os.path.normpath(path)
12 |         # exist_ok avoids a race when several processes create the directory at once
13 |         os.makedirs(path, exist_ok=True)
14 | return path
15 |
16 | def get_log_filename(self):
17 |         filename = "{}.log".format(time.strftime("%Y-%m-%d", time.localtime()))
18 | filename = os.path.join(self.make_log_dir(), filename)
19 | filename = os.path.normpath(filename)
20 | return filename
21 |
22 | def log(self, level='DEBUG', name="simagent"):
23 | logger = logging.getLogger(name)
24 | level = getattr(logging, level)
25 | logger.setLevel(level)
26 | if not logger.handlers:
27 | sh = logging.StreamHandler()
28 |             fh = logging.FileHandler(filename=self.get_log_filename(), mode='a', encoding="utf-8")
29 | fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
30 | sh.setFormatter(fmt=fmt)
31 | fh.setFormatter(fmt=fmt)
32 | logger.addHandler(sh)
33 | logger.addHandler(fh)
34 | return logger
35 |
36 | def add_log(self, logger, level='DEBUG'):
37 | level = getattr(logging, level)
38 | logger.setLevel(level)
39 | if not logger.handlers:
40 | sh = logging.StreamHandler()
41 |             fh = logging.FileHandler(filename=self.get_log_filename(), mode='a', encoding="utf-8")
42 | fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s")
43 | sh.setFormatter(fmt=fmt)
44 | fh.setFormatter(fmt=fmt)
45 | logger.addHandler(sh)
46 | logger.addHandler(fh)
47 | return logger
48 |
49 |
50 | if __name__ == '__main__':
51 | logger = Logging().log(level='INFO')
52 |     logger.debug("debug message")  # generate a record with the logger; suppressed here since the level is INFO
53 |     logger.info("info message")
54 |     logger.error("error message")
55 |     logger.warning("warning message")
56 |     logger.critical("critical message")
57 |
--------------------------------------------------------------------------------
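Typical use, as in train.py and val.py below; records go both to stderr and to a dated file under smart/utils/logs/:

    from smart.utils.log import Logging

    logger = Logging().log(level='INFO', name='simagent')
    logger.info("loading pretrained checkpoint")
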
/smart/utils/nan_checker.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def check_nan_inf(t, s):
4 | assert not torch.isinf(t).any(), f"{s} is inf, {t}"
5 | assert not torch.isnan(t).any(), f"{s} is nan, {t}"
--------------------------------------------------------------------------------
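The checker is a debugging assertion meant to be sprinkled through the forward pass; a sketch:

    import torch

    from smart.utils.nan_checker import check_nan_inf

    check_nan_inf(torch.tensor([0.0, 1.0]), "logits")  # passes silently
    # check_nan_inf(torch.tensor([float('nan')]), "logits")  # raises AssertionError
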
/smart/utils/weight_init.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 |
4 |
5 | def weight_init(m: nn.Module) -> None:
6 | if isinstance(m, nn.Linear):
7 | nn.init.xavier_uniform_(m.weight)
8 | if m.bias is not None:
9 | nn.init.zeros_(m.bias)
10 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
11 | fan_in = m.in_channels / m.groups
12 | fan_out = m.out_channels / m.groups
13 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
14 | nn.init.uniform_(m.weight, -bound, bound)
15 | if m.bias is not None:
16 | nn.init.zeros_(m.bias)
17 | elif isinstance(m, nn.Embedding):
18 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
19 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
20 | nn.init.ones_(m.weight)
21 | nn.init.zeros_(m.bias)
22 | elif isinstance(m, nn.LayerNorm):
23 | nn.init.ones_(m.weight)
24 | nn.init.zeros_(m.bias)
25 | elif isinstance(m, nn.MultiheadAttention):
26 | if m.in_proj_weight is not None:
27 | fan_in = m.embed_dim
28 | fan_out = m.embed_dim
29 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
30 | nn.init.uniform_(m.in_proj_weight, -bound, bound)
31 | else:
32 | nn.init.xavier_uniform_(m.q_proj_weight)
33 | nn.init.xavier_uniform_(m.k_proj_weight)
34 | nn.init.xavier_uniform_(m.v_proj_weight)
35 | if m.in_proj_bias is not None:
36 | nn.init.zeros_(m.in_proj_bias)
37 | nn.init.xavier_uniform_(m.out_proj.weight)
38 | if m.out_proj.bias is not None:
39 | nn.init.zeros_(m.out_proj.bias)
40 | if m.bias_k is not None:
41 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
42 | if m.bias_v is not None:
43 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
44 | elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
45 | for name, param in m.named_parameters():
46 | if 'weight_ih' in name:
47 | for ih in param.chunk(4, 0):
48 | nn.init.xavier_uniform_(ih)
49 | elif 'weight_hh' in name:
50 | for hh in param.chunk(4, 0):
51 | nn.init.orthogonal_(hh)
52 | elif 'weight_hr' in name:
53 | nn.init.xavier_uniform_(param)
54 | elif 'bias_ih' in name:
55 | nn.init.zeros_(param)
56 | elif 'bias_hh' in name:
57 | nn.init.zeros_(param)
58 | nn.init.ones_(param.chunk(4, 0)[1])
59 | elif isinstance(m, (nn.GRU, nn.GRUCell)):
60 | for name, param in m.named_parameters():
61 | if 'weight_ih' in name:
62 | for ih in param.chunk(3, 0):
63 | nn.init.xavier_uniform_(ih)
64 | elif 'weight_hh' in name:
65 | for hh in param.chunk(3, 0):
66 | nn.init.orthogonal_(hh)
67 | elif 'bias_ih' in name:
68 | nn.init.zeros_(param)
69 | elif 'bias_hh' in name:
70 | nn.init.zeros_(param)
71 |
--------------------------------------------------------------------------------
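The initializer is written to be passed to nn.Module.apply, which visits every submodule recursively; a minimal sketch (the toy model is illustrative):

    import torch.nn as nn

    from smart.utils.weight_init import weight_init

    model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 2))
    model.apply(weight_init)  # Linear weights get Xavier-uniform, biases zeros
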
/train.py:
--------------------------------------------------------------------------------
1 |
2 | from argparse import ArgumentParser
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import LearningRateMonitor
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | from pytorch_lightning.strategies import DDPStrategy
7 | from smart.utils.config import load_config_act
8 | from smart.datamodules import MultiDataModule
9 | from smart.model import SMART
10 | from smart.utils.log import Logging
11 |
12 |
13 | if __name__ == '__main__':
14 | parser = ArgumentParser()
15 | Predictor_hash = {"smart": SMART, }
16 | parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml')
17 | parser.add_argument('--pretrain_ckpt', type=str, default="")
18 | parser.add_argument('--ckpt_path', type=str, default="")
19 | parser.add_argument('--save_ckpt_path', type=str, default="")
20 | args = parser.parse_args()
21 | config = load_config_act(args.config)
22 | Predictor = Predictor_hash[config.Model.predictor]
23 | strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
24 | Data_config = config.Dataset
25 | datamodule = MultiDataModule(**vars(Data_config))
26 |
27 | if args.pretrain_ckpt == "":
28 | model = Predictor(config.Model)
29 | else:
30 | logger = Logging().log(level='DEBUG')
31 | model = Predictor(config.Model)
32 | model.load_params_from_file(filename=args.pretrain_ckpt,
33 | logger=logger)
34 | trainer_config = config.Trainer
35 | model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
36 | filename="{epoch:02d}",
37 | monitor='val_cls_acc',
38 | every_n_epochs=1,
39 | save_top_k=5,
40 | mode='max')
41 | lr_monitor = LearningRateMonitor(logging_interval='epoch')
42 | trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices,
43 | strategy=strategy,
44 | accumulate_grad_batches=trainer_config.accumulate_grad_batches,
45 | num_nodes=trainer_config.num_nodes,
46 | callbacks=[model_checkpoint, lr_monitor],
47 | max_epochs=trainer_config.max_epochs,
48 | num_sanity_val_steps=0,
49 | gradient_clip_val=0.5)
50 | if args.ckpt_path == "":
51 | trainer.fit(model,
52 | datamodule)
53 | else:
54 | trainer.fit(model,
55 | datamodule,
56 | ckpt_path=args.ckpt_path)
57 |
--------------------------------------------------------------------------------
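For reference, a typical launch (paths illustrative). Note that the two checkpoint flags differ: --pretrain_ckpt only warm-starts model weights via load_params_from_file, while --ckpt_path resumes the full Lightning training state through trainer.fit:

    python train.py --config configs/train/train_scalable.yaml --save_ckpt_path ckpt/
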
/val.py:
--------------------------------------------------------------------------------
1 |
2 | from argparse import ArgumentParser
3 | import pytorch_lightning as pl
4 | from torch_geometric.loader import DataLoader
5 | from smart.datasets.scalable_dataset import MultiDataset
6 | from smart.model import SMART
7 | from smart.transforms import WaymoTargetBuilder
8 | from smart.utils.config import load_config_act
9 | from smart.utils.log import Logging
10 |
11 | if __name__ == '__main__':
12 | pl.seed_everything(2, workers=True)
13 | parser = ArgumentParser()
14 | parser.add_argument('--config', type=str, default="configs/validation/validation_scalable.yaml")
15 | parser.add_argument('--pretrain_ckpt', type=str, default="")
16 | parser.add_argument('--ckpt_path', type=str, default="")
17 | parser.add_argument('--save_ckpt_path', type=str, default="")
18 | args = parser.parse_args()
19 | config = load_config_act(args.config)
20 |
21 | data_config = config.Dataset
22 | val_dataset = {
23 | "scalable": MultiDataset,
24 | }[data_config.dataset](root=data_config.root, split='val',
25 | raw_dir=data_config.val_raw_dir,
26 | processed_dir=data_config.val_processed_dir,
27 | transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))
28 |     dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers,
29 |                             pin_memory=data_config.pin_memory, persistent_workers=data_config.num_workers > 0)
30 | Predictor = SMART
31 | if args.pretrain_ckpt == "":
32 | model = Predictor(config.Model)
33 | else:
34 | logger = Logging().log(level='DEBUG')
35 | model = Predictor(config.Model)
36 | model.load_params_from_file(filename=args.pretrain_ckpt,
37 | logger=logger)
38 |
39 | trainer_config = config.Trainer
40 | trainer = pl.Trainer(accelerator=trainer_config.accelerator,
41 | devices=trainer_config.devices,
42 | strategy='ddp', num_sanity_val_steps=0)
43 | trainer.validate(model, dataloader)
44 |
--------------------------------------------------------------------------------
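Validation follows the same pattern; a typical invocation against the bundled demo pickles (assuming the YAML's val_raw_dir points at data/valid_demo):

    python val.py --config configs/validation/validation_scalable.yaml --pretrain_ckpt <path/to/ckpt>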