├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── assets └── result1.png ├── configs ├── train │ └── train_scalable.yaml └── validation │ └── validation_scalable.yaml ├── data └── valid_demo │ ├── 1a146468873b7871.pkl │ ├── 1a37753b91c30e16.pkl │ ├── 1b076986122dc653.pkl │ ├── 1c83f56236e33b4.pkl │ ├── 1ce0b4bbd35a6ad1.pkl │ ├── 1d3daf744e65dd7c.pkl │ ├── 1d71990efc3c6c0e.pkl │ ├── 1d8ff1c6acd0f266.pkl │ ├── 1ea00a2a7b936853.pkl │ ├── 2a716fad0956ffcf.pkl │ └── 2ab47807431834c1.pkl ├── data_preprocess.py ├── environment.yml ├── requirements.txt ├── scripts ├── install_pyg.sh └── traj_clstering.py ├── smart ├── __init__.py ├── datamodules │ ├── __init__.py │ └── scalable_datamodule.py ├── datasets │ ├── __init__.py │ ├── preprocess.py │ └── scalable_dataset.py ├── layers │ ├── __init__.py │ ├── attention_layer.py │ ├── fourier_embedding.py │ └── mlp_layer.py ├── metrics │ ├── __init__.py │ ├── average_meter.py │ ├── min_ade.py │ ├── min_fde.py │ ├── next_token_cls.py │ └── utils.py ├── model │ ├── __init__.py │ └── smart.py ├── modules │ ├── __init__.py │ ├── agent_decoder.py │ ├── map_decoder.py │ └── smart_decoder.py ├── preprocess │ ├── __init__.py │ └── preprocess.py ├── tokens │ ├── __init__.py │ ├── cluster_frame_5_2048.pkl │ └── map_traj_token5.pkl ├── transforms │ ├── __init__.py │ └── target_builder.py └── utils │ ├── __init__.py │ ├── cluster_reader.py │ ├── config.py │ ├── geometry.py │ ├── graph.py │ ├── list.py │ ├── log.py │ ├── nan_checker.py │ └── weight_init.py ├── train.py └── val.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .github 6 | ckpt/ 7 | # assets/ 8 | # C extensions 9 | *.so 10 | # /assets 11 | /data 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | *.jpg 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | *.jpg 97 | pyg_depend/ 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # IDEs 112 | .idea 113 | .vscode 114 | 115 | # seed project 116 | av2/ 117 | lightning_logs/ 118 | lightning_logs_/ 119 | lightning_l/ 120 | .DS_Store 121 | data/argo 122 | data/res 123 | data/waymo* 124 | fig*/ 125 | data/waymo_token 126 | data/submission 127 | data/token_seq_emb_nuplan 128 | data/token_seq_emb_waymo 129 | data/nuplan* 130 | submission.tar.gz 131 | data/feat* 132 | data/scalable 133 | data/pos_data 134 | res_metrics* 135 | gathered* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction 4 | 5 | [Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/) 6 | 7 |
8 | 9 | - **Ranked 1st** on the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) leaderboard 10 | - **Champion** of the challenge at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/) 11 | 12 | ## News 13 | - **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning** 14 | - **[September 26, 2024]** SMART was accepted to **NeurIPS 2024** 15 | - **[August 31, 2024]** Code released 16 | - **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/) 17 | - **[May 24, 2024]** SMART paper released on [arXiv](https://arxiv.org/abs/2405.15677) 18 | 19 | 20 | ## Introduction 21 | This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel autonomous driving motion generation paradigm that tokenizes vectorized map and agent trajectory data into discrete sequence tokens. 22 | 23 | ## Requirements 24 | 25 | To set up the environment, use conda to create and activate a new environment with the necessary dependencies: 26 | 27 | ```bash 28 | conda env create -f environment.yml 29 | conda activate smart 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | If you encounter issues while installing the PyG dependencies, execute the following script: 34 | ```setup 35 | bash scripts/install_pyg.sh 36 | ``` 37 | 38 | Alternatively, you can configure the environment in your preferred way; installing recent versions of PyTorch, PyG, and PyTorch Lightning should suffice. 39 | 40 | ## Data Preparation 41 | 42 | **Step 1: Download the Dataset** 43 | 44 | Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows: 45 | ``` 46 | SMART 47 | ├── data 48 | │ ├── waymo 49 | │ │ ├── scenario 50 | │ │ │ ├── training 51 | │ │ │ ├── validation 52 | │ │ │ ├── testing 53 | ├── model 54 | ├── tools 55 | ``` 56 | 57 | **Step 2: Install the Waymo Open Dataset API** 58 | 59 | Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API. 60 | 61 | **Step 3: Preprocess the Dataset** 62 | 63 | Preprocess the dataset by running: 64 | ``` 65 | python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training 66 | ``` 67 | The first argument is the raw-data directory; the second is the output directory. 68 | 69 | The processed data will be saved to the `data/waymo_processed/` directory as follows: 70 | 71 | ``` 72 | SMART 73 | ├── data 74 | │ ├── waymo_processed 75 | │ │ ├── training 76 | │ │ ├── validation 77 | │ │ ├── testing 78 | ├── model 79 | ├── utils 80 | ``` 81 | 82 | ## Training 83 | 84 | To train the model, run the following command: 85 | 86 | ```train 87 | python train.py --config ${config_path} 88 | ``` 89 | 90 | The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data before training. 91 | 92 | ## Evaluation 93 | 94 | To evaluate the model, run: 95 | 96 | ```eval 97 | python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path} 98 | ``` 99 | This will evaluate the model using the configuration and checkpoint provided.
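For reference, a typical end-to-end run chains the commands above: preprocessing, training, then validation. The paths below are illustrative placeholders; substitute your own data and checkpoint locations:

```bash
# Preprocess each split of the raw Waymo scenarios (repeat per split)
python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
python data_preprocess.py --input_dir ./data/waymo/scenario/validation --output_dir ./data/waymo_processed/validation

# Train with the default config
python train.py --config configs/train/train_scalable.yaml

# Validate a trained checkpoint (checkpoint path is a placeholder)
python val.py --config configs/validation/validation_scalable.yaml --pretrain_ckpt ckpt/smart.ckpt
```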
100 | 101 | 102 | ## Pre-trained Models 103 | 104 | To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed. 105 | 106 | ## Results 107 | 108 | ### Waymo Open Motion Dataset Sim Agents Challenge 109 | 110 | Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/): 111 | 112 | | Model name | Metric Score | 113 | | :-----------: | ------------ | 114 | | SMART-tiny | 0.7591 | 115 | | SMART-large | 0.7614 | 116 | | SMART-zeroshot | 0.7210 | 117 | 118 | ### nuPlan Closed-loop Planning 119 | 120 | **SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below: 121 | 122 | ![nuPlan Closed-loop Planning](assets/result1.png) 123 | 124 | ## Citation 125 | 126 | If you find this repository useful, please consider citing our work and giving us a star: 127 | 128 | ```citation 129 | @article{wu2024smart, 130 | title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction}, 131 | author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng}, 132 | journal={arXiv preprint arXiv:2405.15677}, 133 | year={2024} 134 | } 135 | ``` 136 | 137 | ## Acknowledgements 138 | Special thanks to the [QCNet](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work. 139 | 140 | ## License 141 | All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). 142 |
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 |
-------------------------------------------------------------------------------- /assets/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/assets/result1.png
-------------------------------------------------------------------------------- /configs/train/train_scalable.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema; this YAML supports cases sourced from different datasets 2 | time_info: &time_info 3 | num_historical_steps: 11 4 | num_future_steps: 80 5 | use_intention: True 6 | token_size: 2048 7 | 8 | Dataset: 9 | root: 10 | train_batch_size: 1 11 | val_batch_size: 1 12 | test_batch_size: 1 13 | shuffle: True 14 | num_workers: 1 15 | pin_memory: True 16 | persistent_workers: True 17 | train_raw_dir: ["data/valid_demo"] 18 | val_raw_dir: ["data/valid_demo"] 19 | test_raw_dir: 20 | transform: WaymoTargetBuilder 21 | train_processed_dir: 22 | val_processed_dir: 23 | test_processed_dir: 24 | dataset: "scalable" 25 | <<: *time_info 26 | 27 | Trainer: 28 | strategy: ddp_find_unused_parameters_false 29 | accelerator: "gpu" 30 | devices: 1 31 | max_epochs: 32 32 | save_ckpt_path: 33 | num_nodes: 1 34 | mode: 35 | ckpt_path: 36 | precision: 32 37 | accumulate_grad_batches: 1 38 | 39 | Model: 40 | mode: "train" 41 | predictor: "smart" 42 | dataset: "waymo" 43 | input_dim: 2 44 | hidden_dim: 128 45 | output_dim: 2 46 | output_head: False 47 | num_heads: 8 48 | <<: *time_info 49 | head_dim: 16 50 | dropout: 0.1 51 | num_freq_bands: 64 52 | lr: 0.0005 53 | warmup_steps: 0 54 | total_steps: 32 55 | decoder: 56 | <<: *time_info 57 | num_map_layers: 3 58 | num_agent_layers: 6 59 | a2a_radius: 60 60 | pl2pl_radius: 10 61 | pl2a_radius: 30 62 | time_span: 30 63 |
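The `time_info` block above uses standard YAML anchors: `&time_info` names the mapping, and each `<<: *time_info` merge key copies its entries into `Dataset`, `Model`, and `decoder`, so the horizon and token-vocabulary settings are declared exactly once. A standalone illustration of the mechanism (editorial example, not part of the repo's configs):

```yaml
defaults: &defaults
  num_historical_steps: 11
  token_size: 2048

model:
  <<: *defaults      # inherits both keys from the anchor
  token_size: 1024   # keys declared locally override merged ones
```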
-------------------------------------------------------------------------------- /configs/validation/validation_scalable.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema; this YAML supports validation cases sourced from different datasets 2 | time_info: &time_info 3 | num_historical_steps: 11 4 | num_future_steps: 80 5 | token_size: 2048 6 | 7 | Dataset: 8 | root: 9 | batch_size: 1 10 | shuffle: True 11 | num_workers: 1 12 | pin_memory: True 13 | persistent_workers: True 14 | train_raw_dir: 15 | val_raw_dir: ["data/valid_demo"] 16 | test_raw_dir: 17 | TargetBuilder: WaymoTargetBuilder 18 | train_processed_dir: 19 | val_processed_dir: 20 | test_processed_dir: 21 | dataset: "scalable" 22 | <<: *time_info 23 | 24 | Trainer: 25 | strategy: ddp_find_unused_parameters_false 26 | accelerator: "gpu" 27 | devices: 1 28 | max_epochs: 32 29 | save_ckpt_path: 30 | num_nodes: 1 31 | mode: 32 | ckpt_path: 33 | precision: 32 34 | accumulate_grad_batches: 1 35 | 36 | Model: 37 | mode: "validation" 38 | predictor: "smart" 39 | dataset: "waymo" 40 | input_dim: 2 41 | hidden_dim: 128 42 | output_dim: 2 43 | output_head: False 44 | num_heads: 8 45 | <<: *time_info 46 | head_dim: 16 47 | dropout: 0.1 48 | num_freq_bands: 64 49 | lr: 0.0005 50 | warmup_steps: 0 51 | total_steps: 32 52 | decoder: 53 | <<: *time_info 54 | num_map_layers: 3 55 | num_agent_layers: 6 56 | a2a_radius: 60 57 | pl2pl_radius: 10 58 | pl2a_radius: 30 59 | time_span: 30 60 | 61 |
-------------------------------------------------------------------------------- /data/valid_demo/1a146468873b7871.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1a146468873b7871.pkl -------------------------------------------------------------------------------- /data/valid_demo/1a37753b91c30e16.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1a37753b91c30e16.pkl -------------------------------------------------------------------------------- /data/valid_demo/1b076986122dc653.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1b076986122dc653.pkl -------------------------------------------------------------------------------- /data/valid_demo/1c83f56236e33b4.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1c83f56236e33b4.pkl -------------------------------------------------------------------------------- /data/valid_demo/1ce0b4bbd35a6ad1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1ce0b4bbd35a6ad1.pkl -------------------------------------------------------------------------------- /data/valid_demo/1d3daf744e65dd7c.pkl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d3daf744e65dd7c.pkl -------------------------------------------------------------------------------- /data/valid_demo/1d71990efc3c6c0e.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d71990efc3c6c0e.pkl -------------------------------------------------------------------------------- /data/valid_demo/1d8ff1c6acd0f266.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1d8ff1c6acd0f266.pkl -------------------------------------------------------------------------------- /data/valid_demo/1ea00a2a7b936853.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/1ea00a2a7b936853.pkl -------------------------------------------------------------------------------- /data/valid_demo/2a716fad0956ffcf.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/2a716fad0956ffcf.pkl -------------------------------------------------------------------------------- /data/valid_demo/2ab47807431834c1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/data/valid_demo/2ab47807431834c1.pkl -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: smart 2 | channels: 3 | - pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - _openmp_mutex=5.1=1_gnu 13 | - blas=1.0=mkl 14 | - brotli-python=1.0.9=py39h6a678d5_8 15 | - bzip2=1.0.8=h5eee18b_6 16 | - ca-certificates=2024.9.24=h06a4308_0 17 | - certifi=2024.8.30=py39h06a4308_0 18 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 19 | - cudatoolkit=11.3.1=h2bc3f7f_2 20 | - ffmpeg=4.3=hf484d3e_0 21 | - freetype=2.12.1=h4a9f257_0 22 | - gmp=6.2.1=h295c915_3 23 | - gnutls=3.6.15=he1e5248_0 24 | - idna=3.7=py39h06a4308_0 25 | - intel-openmp=2023.1.0=hdb19cb5_46306 26 | - jpeg=9e=h5eee18b_3 27 | - lame=3.100=h7b6447c_0 28 | - lcms2=2.12=h3be6417_0 29 | - ld_impl_linux-64=2.40=h12ee557_0 30 | - lerc=3.0=h295c915_0 31 | - libdeflate=1.17=h5eee18b_1 32 | - libffi=3.4.4=h6a678d5_1 33 | - libgcc-ng=11.2.0=h1234567_1 34 | - libgomp=11.2.0=h1234567_1 35 | - libiconv=1.14=0 36 | - libidn2=2.3.4=h5eee18b_0 37 | - libpng=1.6.39=h5eee18b_0 38 | - libstdcxx-ng=11.2.0=h1234567_1 39 | - libtasn1=4.19.0=h5eee18b_0 40 | - libtiff=4.5.1=h6a678d5_0 41 | - libunistring=0.9.10=h27cfd23_0 42 | - libwebp-base=1.3.2=h5eee18b_1 43 | - lz4-c=1.9.4=h6a678d5_1 44 | - 
mkl=2023.1.0=h213fc3f_46344 45 | - mkl-service=2.4.0=py39h5eee18b_1 46 | - mkl_fft=1.3.10=py39h5eee18b_0 47 | - mkl_random=1.2.7=py39h1128e8f_0 48 | - ncurses=6.4=h6a678d5_0 49 | - nettle=3.7.3=hbbd107a_1 50 | - openh264=2.1.1=h4ff587b_0 51 | - openjpeg=2.5.2=he7f1fd0_0 52 | - openssl=3.0.15=h5eee18b_0 53 | - pillow=10.4.0=py39h5eee18b_0 54 | - pip=24.2=py39h06a4308_0 55 | - pysocks=1.7.1=py39h06a4308_0 56 | - python=3.9.19=h955ad1f_1 57 | - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 58 | - pytorch-mutex=1.0=cuda 59 | - readline=8.2=h5eee18b_0 60 | - requests=2.32.3=py39h06a4308_0 61 | - setuptools=75.1.0=py39h06a4308_0 62 | - sqlite=3.45.3=h5eee18b_0 63 | - tbb=2021.8.0=hdb19cb5_0 64 | - tk=8.6.14=h39e8969_0 65 | - torchvision=0.13.1=py39_cu113 66 | - typing_extensions=4.11.0=py39h06a4308_0 67 | - urllib3=2.2.3=py39h06a4308_0 68 | - wheel=0.44.0=py39h06a4308_0 69 | - xz=5.4.6=h5eee18b_1 70 | - zlib=1.2.13=h5eee18b_1 71 | - zstd=1.5.6=hc292b87_0 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.4.3 2 | aiohttp==3.10.10 3 | aiosignal==1.3.1 4 | async-timeout==4.0.3 5 | attrs==24.2.0 6 | contourpy==1.3.0 7 | cycler==0.12.1 8 | easydict==1.13 9 | fonttools==4.54.1 10 | frozenlist==1.4.1 11 | fsspec==2024.10.0 12 | importlib-resources==6.4.5 13 | jinja2==3.1.4 14 | kiwisolver==1.4.7 15 | lightning-utilities==0.11.8 16 | markupsafe==3.0.2 17 | matplotlib==3.9.2 18 | multidict==6.1.0 19 | numpy==1.26.4 20 | packaging==24.1 21 | pandas==2.0.3 22 | propcache==0.2.0 23 | psutil==6.1.0 24 | pyparsing==3.2.0 25 | python-dateutil==2.9.0.post0 26 | pytorch-lightning==2.0.3 27 | pytz==2024.2 28 | pyyaml==6.0.1 29 | scipy==1.10.1 30 | shapely==2.0.6 31 | six==1.16.0 32 | torch-cluster==1.6.0+pt112cu113 33 | torch-geometric==2.6.1 34 | torch-scatter==2.1.0+pt112cu113 35 | torch-sparse==0.6.16+pt112cu113 36 | torch-spline-conv==1.2.1+pt112cu113 37 | torchmetrics==1.5.0 38 | tqdm==4.66.5 39 | tzdata==2024.2 40 | yarl==1.16.0 41 | zipp==3.20.2 42 | waymo-open-dataset-tf-2-12-0==1.6.4 43 | -------------------------------------------------------------------------------- /scripts/install_pyg.sh: -------------------------------------------------------------------------------- 1 | mkdir pyg_depend && cd pyg_depend 2 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl 3 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl 4 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl 5 | wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl 6 | python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl 7 | python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl 8 | python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl 9 | python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl 10 | python3 -m pip install torch_geometric 11 | -------------------------------------------------------------------------------- /scripts/traj_clstering.py: -------------------------------------------------------------------------------- 1 | from smart.utils.geometry import wrap_angle 2 | import numpy as np 3 | 4 | 5 | def 
average_distance_vectorized(point_set1, centroids): 6 | dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1)) 7 | return np.mean(dists, axis=2) 8 | 9 | 10 | def assign_clusters(sub_X, centroids): 11 | distances = average_distance_vectorized(sub_X, centroids) 12 | return np.argmin(distances, axis=1) 13 | 14 | 15 | def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None): 16 | S = [] 17 | ret_traj_list = [] 18 | while len(S) < N: 19 | num_all = X.shape[0] 20 | # Randomly pick the first cluster center 21 | choice_index = np.random.choice(num_all) 22 | x0 = X[choice_index] 23 | if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10: 24 | continue 25 | res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2) 26 | del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2) 27 | if cal_mean_heading:  # module-level flag, set in __main__ below 28 | del_contour = X[del_mask] 29 | diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :] 30 | del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean() 31 | x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length) 32 | del_traj = a_pos[del_mask] 33 | ret_traj = del_traj.mean(0)[None, ...] 34 | if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0: 35 | print(ret_traj) 36 | print('1') 37 | else: 38 | x0 = x0[None, ...] 39 | ret_traj = a_pos[choice_index][None, ...] 40 | X = X[res_mask] 41 | a_pos = a_pos[res_mask] 42 | S.append(x0) 43 | ret_traj_list.append(ret_traj) 44 | centroids = np.concatenate(S, axis=0) 45 | ret_traj = np.concatenate(ret_traj_list, axis=0) 46 | 47 | # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2)) 48 | 49 | # for k in range(1, K): 50 | # new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2)) 51 | # closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq) 52 | # probabilities = closest_dist_sq / np.sum(closest_dist_sq) 53 | # centroids[k] = X[np.random.choice(N, p=probabilities)] 54 | 55 | return centroids, ret_traj 56 | 57 | 58 | def cal_polygon_contour(x, y, theta, width, length): 59 | 60 | left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) 61 | left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) 62 | left_front = np.column_stack((left_front_x, left_front_y)) 63 | 64 | right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) 65 | right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) 66 | right_front = np.column_stack((right_front_x, right_front_y)) 67 | 68 | right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) 69 | right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) 70 | right_back = np.column_stack((right_back_x, right_back_y)) 71 | 72 | left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) 73 | left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) 74 | left_back = np.column_stack((left_back_x, left_back_y)) 75 | 76 | polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1) 77 | 78 | return polygon_contour 79 | 80 |
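# Editorial sanity check for cal_polygon_contour (sketch, not part of the
# original script): a single pose at the origin with theta=0 and the vehicle
# box used below for shift > 2 (width=2.0, length=4.8) yields an array of
# shape (1, 4, 2), with corners ordered left-front, right-front, right-back,
# left-back:
#   cal_polygon_contour(np.array([0.]), np.array([0.]), np.array([0.]), 2.0, 4.8)
#   -> [[[ 2.4,  1.0], [ 2.4, -1.0], [-2.4, -1.0], [-2.4,  1.0]]]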
81 | if __name__ == '__main__': 82 | shift = 5 # motion token time dimension 83 | num_cluster = 6 # vocabulary size 84 | cal_mean_heading = True 85 | data = { 86 | "veh": np.random.rand(1000, 6, 3), 87 | "cyc": np.random.rand(1000, 6, 3), 88 | "ped": np.random.rand(1000, 6, 3) 89 | } 90 | # Collect the trajectories of all traffic participants from the raw data; shape [NumAgent, shift+1, [relative_x, relative_y, relative_theta]] 91 | nms_res = {} 92 | res = {'token': {}, 'traj': {}, 'token_all': {}} 93 | for k, v in data.items(): 94 | # if k != 'veh': 95 | # continue 96 | a_pos = v 97 | print(a_pos.shape) 98 | # a_pos = a_pos[:, shift:1+shift, :] 99 | cal_num = min(int(1e6), a_pos.shape[0]) 100 | a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)] 101 | a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1]) 102 | print(a_pos.shape) 103 | if shift <= 2: 104 | if k == 'veh': 105 | width = 1.0 106 | length = 2.4 107 | elif k == 'cyc': 108 | width = 0.5 109 | length = 1.5 110 | else: 111 | width = 0.5 112 | length = 0.5 113 | else: 114 | if k == 'veh': 115 | width = 2.0 116 | length = 4.8 117 | elif k == 'cyc': 118 | width = 1.0 119 | length = 2.0 120 | else: 121 | width = 1.0 122 | length = 1.0 123 | contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length) 124 | 125 | # plt.figure(figsize=(10, 10)) 126 | # for rect in contour: 127 | # rect_closed = np.vstack([rect, rect[0]]) 128 | # plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1) 129 | 130 | # plt.title("Plot of 256 Rectangles") 131 | # plt.xlabel("x") 132 | # plt.ylabel("y") 133 | # plt.axis('equal') 134 | # plt.savefig(f'src_{k}_new.jpg', dpi=300) 135 | 136 | if k == 'veh': 137 | tol = 0.05 138 | elif k == 'cyc': 139 | tol = 0.004 140 | else: 141 | tol = 0.004 142 | centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1]) 143 | # plt.figure(figsize=(10, 10)) 144 | contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)), 145 | ret_traj[:, :, 1].reshape(num_cluster*(shift+1)), 146 | ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length) 147 | 148 | res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2) 149 | res['token'][k] = centroids 150 | res['traj'][k] = ret_traj 151 |
-------------------------------------------------------------------------------- /smart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/__init__.py
-------------------------------------------------------------------------------- /smart/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from smart.datamodules.scalable_datamodule import MultiDataModule 2 |
-------------------------------------------------------------------------------- /smart/datamodules/scalable_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytorch_lightning as pl 4 | from torch_geometric.loader import DataLoader 5 | from smart.datasets.scalable_dataset import MultiDataset 6 | from smart.transforms import WaymoTargetBuilder 7 | 8 | 9 | class MultiDataModule(pl.LightningDataModule): 10 | transforms = { 11 | "WaymoTargetBuilder": WaymoTargetBuilder, 12 | } 13 | 14 | dataset = { 15 | "scalable": MultiDataset, 16 | } 17 | 18 | def __init__(self, 19 | root: str, 20 | train_batch_size: int, 21 | val_batch_size: int, 22 | test_batch_size: int, 23 | shuffle: bool = False, 24 | num_workers: int = 0, 25 | pin_memory: bool = True, 26 | persistent_workers: bool = True, 27 | train_raw_dir: Optional[str] = None, 28 | val_raw_dir: Optional[str] = None, 29 | test_raw_dir: Optional[str] = None, 30 | train_processed_dir: Optional[str] = None, 31 |
val_processed_dir: Optional[str] = None, 32 | test_processed_dir: Optional[str] = None, 33 | transform: Optional[str] = None, 34 | dataset: Optional[str] = None, 35 | num_historical_steps: int = 50, 36 | num_future_steps: int = 60, 37 | processor='ntp', 38 | use_intention=False, 39 | token_size=512, 40 | **kwargs) -> None: 41 | super(MultiDataModule, self).__init__() 42 | self.root = root 43 | self.dataset_class = dataset 44 | self.train_batch_size = train_batch_size 45 | self.val_batch_size = val_batch_size 46 | self.test_batch_size = test_batch_size 47 | self.shuffle = shuffle 48 | self.num_workers = num_workers 49 | self.pin_memory = pin_memory 50 | self.persistent_workers = persistent_workers and num_workers > 0  # persistence only makes sense with live workers 51 | self.train_raw_dir = train_raw_dir 52 | self.val_raw_dir = val_raw_dir 53 | self.test_raw_dir = test_raw_dir 54 | self.train_processed_dir = train_processed_dir 55 | self.val_processed_dir = val_processed_dir 56 | self.test_processed_dir = test_processed_dir 57 | self.processor = processor 58 | self.use_intention = use_intention 59 | self.token_size = token_size 60 | 61 | train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train") 62 | val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val") 63 | test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps) 64 | 65 | self.train_transform = train_transform 66 | self.val_transform = val_transform 67 | self.test_transform = test_transform 68 | 69 | def setup(self, stage: Optional[str] = None) -> None:  # stage is unused; all three splits are built eagerly 70 | self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir, 71 | raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size) 72 | self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir, 73 | raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size) 74 | self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir, 75 | raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size) 76 | 77 | def train_dataloader(self): 78 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, 79 | num_workers=self.num_workers, pin_memory=self.pin_memory, 80 | persistent_workers=self.persistent_workers) 81 | 82 | def val_dataloader(self): 83 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, 84 | num_workers=self.num_workers, pin_memory=self.pin_memory, 85 | persistent_workers=self.persistent_workers) 86 | 87 | def test_dataloader(self): 88 | return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False, 89 | num_workers=self.num_workers, pin_memory=self.pin_memory, 90 | persistent_workers=self.persistent_workers) 91 |
-------------------------------------------------------------------------------- /smart/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from smart.datasets.scalable_dataset import MultiDataset 2 |
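A minimal sketch of how the `Dataset` block of `configs/train/train_scalable.yaml` maps onto `MultiDataModule` above (editorial illustration, assuming standard PyYAML; `train.py` remains the authoritative entry point):

```python
import yaml

from smart.datamodules import MultiDataModule

# Load the training config; the Dataset section's keys line up one-to-one with
# MultiDataModule's keyword arguments (the merged time_info anchor supplies
# num_historical_steps, num_future_steps, use_intention, and token_size).
with open('configs/train/train_scalable.yaml') as f:
    cfg = yaml.safe_load(f)

# dataset="scalable" and transform="WaymoTargetBuilder" select entries from
# the class-level `dataset` and `transforms` registries.
datamodule = MultiDataModule(**cfg['Dataset'])
datamodule.setup()
loader = datamodule.train_dataloader()
```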
-------------------------------------------------------------------------------- /smart/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.interpolate import interp1d 4 | from scipy.spatial.distance import euclidean 5 | import math 6 | import pickle 7 | from smart.utils import wrap_angle 8 | import os 9 | 10 | def cal_polygon_contour(x, y, theta, width, length): 11 | left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) 12 | left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) 13 | left_front = np.column_stack((left_front_x, left_front_y)) 14 | 15 | right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) 16 | right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) 17 | right_front = np.column_stack((right_front_x, right_front_y)) 18 | 19 | right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta) 20 | right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta) 21 | right_back = np.column_stack((right_back_x, right_back_y)) 22 | 23 | left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta) 24 | left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta) 25 | left_back = np.column_stack((left_back_x, left_back_y)) 26 | 27 | polygon_contour = np.concatenate( 28 | (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1) 29 | 30 | return polygon_contour 31 | 32 | 33 | def interplating_polyline(polylines, heading, distance=0.5, split_distace=5): 34 | # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter 35 | dist_along_path_list = [[0]] 36 | polylines_list = [[polylines[0]]] 37 | for i in range(1, polylines.shape[0]): 38 | euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2]) 39 | heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])), 40 | abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))  # heading change between consecutive points, folded by pi 41 | if heading_diff > math.pi / 4 and euclidean_dist > 3: 42 | dist_along_path_list.append([0]) 43 | polylines_list.append([polylines[i]]) 44 | elif heading_diff > math.pi / 8 and euclidean_dist > 3: 45 | dist_along_path_list.append([0]) 46 | polylines_list.append([polylines[i]]) 47 | elif heading_diff > 0.1 and euclidean_dist > 3: 48 | dist_along_path_list.append([0]) 49 | polylines_list.append([polylines[i]]) 50 | elif euclidean_dist > 10: 51 | dist_along_path_list.append([0]) 52 | polylines_list.append([polylines[i]]) 53 | else: 54 | dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist) 55 | polylines_list[-1].append(polylines[i]) 56 | # plt.plot(polylines[:, 0], polylines[:, 1]) 57 | # plt.savefig('tmp.jpg') 58 | new_x_list = [] 59 | new_y_list = [] 60 | multi_polylines_list = [] 61 | for idx in range(len(dist_along_path_list)): 62 | if len(dist_along_path_list[idx]) < 2: 63 | continue 64 | dist_along_path = np.array(dist_along_path_list[idx]) 65 | polylines_cur = np.array(polylines_list[idx]) 66 | # Create interpolation functions for x and y coordinates 67 | fx = interp1d(dist_along_path, polylines_cur[:, 0]) 68 | fy = interp1d(dist_along_path, polylines_cur[:, 1]) 69 | # fyaw = interp1d(dist_along_path, heading) 70 | 71 | # Create an array of distances at which to interpolate 72 | new_dist_along_path = np.arange(0, dist_along_path[-1], distance) 73 | new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]]) 74 | # Use the interpolation functions to generate new x and y coordinates 75 | new_x = fx(new_dist_along_path) 76 | new_y =
fy(new_dist_along_path) 77 | # new_yaw = fyaw(new_dist_along_path) 78 | new_x_list.append(new_x) 79 | new_y_list.append(new_y) 80 | 81 | # Combine the new x and y coordinates into a single array 82 | new_polylines = np.vstack((new_x, new_y)).T 83 | polyline_size = int(split_distace / distance) 84 | if new_polylines.shape[0] >= (polyline_size + 1): 85 | padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size 86 | final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1 87 | else: 88 | padding_size = new_polylines.shape[0] 89 | final_index = 0 90 | multi_polylines = None 91 | new_polylines = torch.from_numpy(new_polylines) 92 | new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1], 93 | new_polylines[1:, 0] - new_polylines[:-1, 0]) 94 | new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None] 95 | new_polylines = torch.cat([new_polylines, new_heading], -1) 96 | if new_polylines.shape[0] >= (polyline_size + 1): 97 | multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size) 98 | multi_polylines = multi_polylines.transpose(1, 2) 99 | multi_polylines = multi_polylines[:, ::5, :] 100 | if padding_size >= 3: 101 | last_polyline = new_polylines[final_index * polyline_size:] 102 | last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()] 103 | if multi_polylines is not None: 104 | multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0) 105 | else: 106 | multi_polylines = last_polyline.unsqueeze(0) 107 | if multi_polylines is None: 108 | continue 109 | multi_polylines_list.append(multi_polylines) 110 | if len(multi_polylines_list) > 0: 111 | multi_polylines_list = torch.cat(multi_polylines_list, dim=0) 112 | else: 113 | multi_polylines_list = None 114 | return multi_polylines_list 115 | 116 | 117 | def average_distance_vectorized(point_set1, centroids): 118 | dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1)) 119 | return np.mean(dists, axis=2) 120 | 121 | 122 | def assign_clusters(sub_X, centroids): 123 | distances = average_distance_vectorized(sub_X, centroids) 124 | return np.argmin(distances, axis=1) 125 | 126 | 127 | class TokenProcessor: 128 | 129 | def __init__(self, token_size): 130 | module_dir = os.path.dirname(os.path.dirname(__file__)) 131 | self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl') 132 | self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') 133 | self.noise = False 134 | self.disturb = False 135 | self.shift = 5 136 | self.get_trajectory_token() 137 | self.training = False 138 | self.current_step = 10 139 | 140 | def preprocess(self, data): 141 | data = self.tokenize_agent(data) 142 | data = self.tokenize_map(data) 143 | del data['city'] 144 | if 'polygon_is_intersection' in data['map_polygon']: 145 | del data['map_polygon']['polygon_is_intersection'] 146 | if 'route_type' in data['map_polygon']: 147 | del data['map_polygon']['route_type'] 148 | return data 149 | 150 | def get_trajectory_token(self): 151 | agent_token_data = pickle.load(open(self.agent_token_path, 'rb')) 152 | map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb')) 153 | self.trajectory_token = agent_token_data['token'] 154 | self.trajectory_token_all = agent_token_data['token_all'] 155 | self.map_token = {'traj_src': map_token_traj['traj_src'], } 156 | self.token_last = {} 157 | for k, v in 
self.trajectory_token_all.items(): 158 | token_last = torch.from_numpy(v[:, -2:]).to(torch.float) 159 | diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3] 160 | theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) 161 | cos, sin = theta.cos(), theta.sin() 162 | rot_mat = theta.new_zeros(token_last.shape[0], 2, 2) 163 | rot_mat[:, 0, 0] = cos 164 | rot_mat[:, 0, 1] = -sin 165 | rot_mat[:, 1, 0] = sin 166 | rot_mat[:, 1, 1] = cos 167 | agent_token = torch.bmm(token_last[:, 1], rot_mat) 168 | agent_token -= token_last[:, 0].mean(1)[:, None, :] 169 | self.token_last[k] = agent_token.numpy() 170 | 171 | def clean_heading(self, data): 172 | heading = data['agent']['heading'] 173 | valid = data['agent']['valid_mask'] 174 | pi = torch.tensor(torch.pi) 175 | n_vehicles, n_frames = heading.shape 176 | 177 | heading_diff_raw = heading[:, :-1] - heading[:, 1:] 178 | heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi 179 | heading_diff[heading_diff > pi] -= 2 * pi 180 | heading_diff[heading_diff < -pi] += 2 * pi 181 | 182 | valid_pairs = valid[:, :-1] & valid[:, 1:] 183 | 184 | for i in range(n_frames - 1): 185 | change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1] 186 | 187 | heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()] 188 | 189 | if i < n_frames - 2: 190 | heading_diff_raw = heading[:, i + 1] - heading[:, i + 2] 191 | heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi 192 | heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi 193 | heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi 194 | 195 | def tokenize_agent(self, data): 196 | if data['agent']["velocity"].shape[1] == 90: 197 | print(data['scenario_id'], data['agent']["velocity"].shape) 198 | interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * ( 199 | data['agent']['position'][:, self.current_step, 0] != 0) 200 | if data['agent']["velocity"].shape[-1] == 2: 201 | data['agent']["velocity"] = torch.cat([data['agent']["velocity"], 202 | torch.zeros(data['agent']["velocity"].shape[0], 203 | data['agent']["velocity"].shape[1], 1)], dim=-1) 204 | vel = data['agent']["velocity"][interplote_mask, self.current_step] 205 | data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][ 206 | interplote_mask, self.current_step, 207 | :3] - vel * 0.1 208 | data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True 209 | data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][ 210 | interplote_mask, self.current_step] 211 | data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][ 212 | interplote_mask, self.current_step] 213 | 214 | data['agent']['type'] = data['agent']['type'].to(torch.uint8) 215 | 216 | self.clean_heading(data) 217 | matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * ( 218 | data['agent']['valid_mask'][:, self.current_step - 5] == False) 219 | 220 | interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0) 221 | data['agent']['valid_mask'][interplote_mask_first, 0] = True 222 | 223 | agent_pos = data['agent']['position'][:, :, :2] 224 | valid_mask = data['agent']['valid_mask'] 225 | 226 | valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift) 227 | token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1] 228 | agent_type = 
data['agent']['type'] 229 | agent_category = data['agent']['category'] 230 | agent_heading = data['agent']['heading'] 231 | vehicle_mask = agent_type == 0 232 | cyclist_mask = agent_type == 2 233 | ped_mask = agent_type == 1 234 | 235 | veh_pos = agent_pos[vehicle_mask, :, :] 236 | veh_valid_mask = valid_mask[vehicle_mask, :] 237 | cyc_pos = agent_pos[cyclist_mask, :, :] 238 | cyc_valid_mask = valid_mask[cyclist_mask, :] 239 | ped_pos = agent_pos[ped_mask, :, :] 240 | ped_valid_mask = valid_mask[ped_mask, :] 241 | 242 | veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask], 243 | 'veh', agent_category[vehicle_mask], 244 | matching_extra_mask[vehicle_mask]) 245 | ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped', 246 | agent_category[ped_mask], matching_extra_mask[ped_mask]) 247 | cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask], 248 | 'cyc', agent_category[cyclist_mask], 249 | matching_extra_mask[cyclist_mask]) 250 | 251 | token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64) 252 | token_index[vehicle_mask] = veh_token_index 253 | token_index[ped_mask] = ped_token_index 254 | token_index[cyclist_mask] = cyc_token_index 255 | 256 | token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1], 257 | veh_token_contour.shape[2], veh_token_contour.shape[3])) 258 | token_contour[vehicle_mask] = veh_token_contour 259 | token_contour[ped_mask] = ped_token_contour 260 | token_contour[cyclist_mask] = cyc_token_contour 261 | 262 | trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float) 263 | trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float) 264 | trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float) 265 | 266 | agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2)) 267 | agent_token_traj[vehicle_mask] = trajectory_token_veh 268 | agent_token_traj[ped_mask] = trajectory_token_ped 269 | agent_token_traj[cyclist_mask] = trajectory_token_cyc 270 | 271 | if not self.training: 272 | token_valid_mask[matching_extra_mask, 1] = True 273 | 274 | data['agent']['token_idx'] = token_index 275 | data['agent']['token_contour'] = token_contour 276 | token_pos = token_contour.mean(dim=2) 277 | data['agent']['token_pos'] = token_pos 278 | diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :] 279 | data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) 280 | data['agent']['agent_valid_mask'] = token_valid_mask 281 | 282 | vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2), 283 | ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1) 284 | vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool), 285 | (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1) 286 | vel[~vel_valid_mask] = 0 287 | vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][ 288 | data['agent']['valid_mask'][:, self.current_step], 289 | self.current_step, :2] 290 | 291 | data['agent']['token_velocity'] = vel 292 | 293 | return data 294 | 295 | def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask): 296 | agent_token_src = self.trajectory_token[category] 297 | token_last = 
self.token_last[category] 298 | if self.shift <= 2: 299 | if category == 'veh': 300 | width = 1.0 301 | length = 2.4 302 | elif category == 'cyc': 303 | width = 0.5 304 | length = 1.5 305 | else: 306 | width = 0.5 307 | length = 0.5 308 | else: 309 | if category == 'veh': 310 | width = 2.0 311 | length = 4.8 312 | elif category == 'cyc': 313 | width = 1.0 314 | length = 2.0 315 | else: 316 | width = 1.0 317 | length = 1.0 318 | 319 | prev_heading = heading[:, 0] 320 | prev_pos = pos[:, 0] 321 | agent_num, num_step, feat_dim = pos.shape 322 | token_num, token_contour_dim, feat_dim = agent_token_src.shape 323 | agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0) 324 | token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0) 325 | token_index_list = [] 326 | token_contour_list = [] 327 | prev_token_idx = None 328 | 329 | for i in range(self.shift, pos.shape[1], self.shift): 330 | theta = prev_heading 331 | cur_heading = heading[:, i] 332 | cur_pos = pos[:, i] 333 | cos, sin = theta.cos(), theta.sin() 334 | rot_mat = theta.new_zeros(agent_num, 2, 2) 335 | rot_mat[:, 0, 0] = cos 336 | rot_mat[:, 0, 1] = sin 337 | rot_mat[:, 1, 0] = -sin 338 | rot_mat[:, 1, 1] = cos 339 | agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num, 340 | token_num, 341 | token_contour_dim, 342 | feat_dim) 343 | agent_token_world += prev_pos[:, None, None, :] 344 | 345 | cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length) 346 | agent_token_index = torch.from_numpy(np.argmin( 347 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2), 348 | axis=-1)) 349 | if prev_token_idx is not None and self.noise: 350 | same_idx = prev_token_idx == agent_token_index 351 | same_idx[:] = True 352 | topk_indices = np.argsort( 353 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), 354 | axis=2), axis=-1)[:, :5] 355 | sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0]) 356 | agent_token_index[same_idx] = \ 357 | torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx] 358 | 359 | token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index] 360 | 361 | diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :] 362 | 363 | prev_heading = heading[:, i].clone() 364 | prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[ 365 | valid_mask[:, i - self.shift]] 366 | 367 | prev_pos = pos[:, i].clone() 368 | prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]] 369 | prev_token_idx = agent_token_index 370 | token_index_list.append(agent_token_index[:, None]) 371 | token_contour_list.append(token_contour_select[:, None, ...]) 372 | 373 | token_index = torch.cat(token_index_list, dim=1) 374 | token_contour = torch.cat(token_contour_list, dim=1) 375 | 376 | # extra matching 377 | if not self.training: 378 | theta = heading[extra_mask, self.current_step - 1] 379 | prev_pos = pos[extra_mask, self.current_step - 1] 380 | cur_pos = pos[extra_mask, self.current_step] 381 | cur_heading = heading[extra_mask, self.current_step] 382 | cos, sin = theta.cos(), theta.sin() 383 | rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2) 384 | rot_mat[:, 0, 0] = cos 385 | rot_mat[:, 0, 1] = sin 386 | rot_mat[:, 1, 0] = -sin 387 | rot_mat[:, 1, 1] = cos 388 | agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape( 389 | extra_mask.sum(), token_num, token_contour_dim, feat_dim) 390 | agent_token_world += prev_pos[:, None, None, :] 391 | 392 | cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length) 393 | agent_token_index = torch.from_numpy(np.argmin( 394 | np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] 
- agent_token_world.numpy()) ** 2, axis=-1)), axis=2), 395 | axis=-1)) 396 | token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index] 397 | 398 | token_index[extra_mask, 1] = agent_token_index 399 | token_contour[extra_mask, 1] = token_contour_select 400 | 401 | return token_index, token_contour 402 | 403 | def tokenize_map(self, data): 404 | data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8) 405 | data['map_point']['type'] = data['map_point']['type'].to(torch.uint8) 406 | pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index'] 407 | pt_type = data['map_point']['type'].to(torch.uint8) 408 | pt_side = torch.zeros_like(pt_type) 409 | pt_pos = data['map_point']['position'][:, :2] 410 | data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation']) 411 | pt_heading = data['map_point']['orientation'] 412 | split_polyline_type = [] 413 | split_polyline_pos = [] 414 | split_polyline_theta = [] 415 | split_polyline_side = [] 416 | pl_idx_list = [] 417 | split_polygon_type = [] 418 | data['map_point']['type'].unique() 419 | 420 | for i in sorted(np.unique(pt2pl[1])): 421 | index = pt2pl[0, pt2pl[1] == i] 422 | polygon_type = data['map_polygon']["type"][i] 423 | cur_side = pt_side[index] 424 | cur_type = pt_type[index] 425 | cur_pos = pt_pos[index] 426 | cur_heading = pt_heading[index] 427 | 428 | for side_val in np.unique(cur_side): 429 | for type_val in np.unique(cur_type): 430 | if type_val == 13: 431 | continue 432 | indices = np.where((cur_side == side_val) & (cur_type == type_val))[0] 433 | if len(indices) <= 2: 434 | continue 435 | split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy()) 436 | if split_polyline is None: 437 | continue 438 | new_cur_type = cur_type[indices][0] 439 | new_cur_side = cur_side[indices][0] 440 | map_polygon_type = polygon_type.repeat(split_polyline.shape[0]) 441 | new_cur_type = new_cur_type.repeat(split_polyline.shape[0]) 442 | new_cur_side = new_cur_side.repeat(split_polyline.shape[0]) 443 | cur_pl_idx = torch.Tensor([i]) 444 | new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0]) 445 | split_polyline_pos.append(split_polyline[..., :2]) 446 | split_polyline_theta.append(split_polyline[..., 2]) 447 | split_polyline_type.append(new_cur_type) 448 | split_polyline_side.append(new_cur_side) 449 | pl_idx_list.append(new_cur_pl_idx) 450 | split_polygon_type.append(map_polygon_type) 451 | 452 | split_polyline_pos = torch.cat(split_polyline_pos, dim=0) 453 | split_polyline_theta = torch.cat(split_polyline_theta, dim=0) 454 | split_polyline_type = torch.cat(split_polyline_type, dim=0) 455 | split_polyline_side = torch.cat(split_polyline_side, dim=0) 456 | split_polygon_type = torch.cat(split_polygon_type, dim=0) 457 | pl_idx_list = torch.cat(pl_idx_list, dim=0) 458 | vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :] 459 | data['map_save'] = {} 460 | data['pt_token'] = {} 461 | data['map_save']['traj_pos'] = split_polyline_pos 462 | data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0]) 463 | data['map_save']['pl_idx_list'] = pl_idx_list 464 | data['pt_token']['type'] = split_polyline_type 465 | data['pt_token']['side'] = split_polyline_side 466 | data['pt_token']['pl_type'] = split_polygon_type 467 | data['pt_token']['num_nodes'] = split_polyline_pos.shape[0] 468 | return data -------------------------------------------------------------------------------- /smart/datasets/scalable_dataset.py: 
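The nearest-contour matching inside match_token above is the heart of agent tokenization: each candidate token box is rotated into the agent's previous heading frame, translated to its previous position, and the token whose four corners best fit the observed box is selected. Below is a minimal, self-contained sketch of that selection step; the toy vocabulary and box sizes are illustrative assumptions, not the repository's actual token sets (those are loaded from smart/tokens/cluster_frame_5_2048.pkl).

import torch

def nearest_token(cur_contour: torch.Tensor, token_contours: torch.Tensor) -> int:
    # cur_contour: (4, 2) observed box corners; token_contours: (K, 4, 2) candidates
    # mean corner-wise L2 distance between the observed box and every candidate
    dist = (cur_contour[None] - token_contours).norm(dim=-1).mean(dim=-1)  # (K,)
    return int(dist.argmin())

# toy vocabulary: K unit boxes advancing along x, stand-ins for motion tokens
K = 8
base = torch.tensor([[0.5, 0.25], [0.5, -0.25], [-0.5, -0.25], [-0.5, 0.25]])
offsets = torch.stack([torch.linspace(0.0, 3.5, K), torch.zeros(K)], dim=-1)  # (K, 2)
vocab = base[None] + offsets[:, None, :]       # (K, 4, 2)
obs = base + torch.tensor([2.1, 0.0])          # observed box ~2.1 m ahead
print(nearest_token(obs, vocab))               # -> 4, the token displaced by 2.0 m

Note that during training, when self.noise is set, match_token does not commit to the single nearest token but samples among the five nearest candidates, presumably to make the autoregressive decoder more robust to its own compounding drift at rollout time. The MultiDataset wrapper that follows simply loads raw pickles and runs them through TokenProcessor.preprocess.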
-------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from typing import Callable, List, Optional, Tuple, Union 4 | import pandas as pd 5 | from torch_geometric.data import Dataset 6 | from smart.utils.log import Logging 7 | import numpy as np 8 | from .preprocess import TokenProcessor 9 | 10 | 11 | def distance(point1, point2): 12 | return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2) 13 | 14 | 15 | class MultiDataset(Dataset): 16 | def __init__(self, 17 | root: str, 18 | split: str, 19 | raw_dir: List[str] = None, 20 | processed_dir: List[str] = None, 21 | transform: Optional[Callable] = None, 22 | dim: int = 3, 23 | num_historical_steps: int = 50, 24 | num_future_steps: int = 60, 25 | predict_unseen_agents: bool = False, 26 | vector_repr: bool = True, 27 | cluster: bool = False, 28 | processor=None, 29 | use_intention=False, 30 | token_size=512) -> None: 31 | self.logger = Logging().log(level='DEBUG') 32 | self.root = root 33 | self.well_done = [0] 34 | if split not in ('train', 'val', 'test'): 35 | raise ValueError(f'{split} is not a valid split') 36 | self.split = split 37 | self.training = split == 'train' 38 | self.logger.debug("Starting loading dataset") 39 | self._raw_file_names = [] 40 | self._raw_paths = [] 41 | self._raw_file_dataset = [] 42 | if raw_dir is not None: 43 | self._raw_dir = raw_dir 44 | for raw_dir in self._raw_dir: 45 | raw_dir = os.path.expanduser(os.path.normpath(raw_dir)) 46 | dataset = "waymo" 47 | file_list = os.listdir(raw_dir) 48 | self._raw_file_names.extend(file_list) 49 | self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list]) 50 | self._raw_file_dataset.extend([dataset for _ in range(len(file_list))]) 51 | if self.root is not None: 52 | split_datainfo = os.path.join(root, "split_datainfo.pkl") 53 | with open(split_datainfo, 'rb+') as f: 54 | split_datainfo = pickle.load(f) 55 | if split == "test": 56 | split = "val" 57 | self._processed_file_names = split_datainfo[split] 58 | self.dim = dim 59 | self.num_historical_steps = num_historical_steps 60 | self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names) 61 | self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples)) 62 | self.token_processor = TokenProcessor(2048) 63 | super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None) 64 | 65 | @property 66 | def raw_dir(self) -> str: 67 | return self._raw_dir 68 | 69 | @property 70 | def raw_paths(self) -> List[str]: 71 | return self._raw_paths 72 | 73 | @property 74 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 75 | return self._raw_file_names 76 | 77 | @property 78 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 79 | return self._processed_file_names 80 | 81 | def len(self) -> int: 82 | return self._num_samples 83 | 84 | def generate_ref_token(self): 85 | pass 86 | 87 | def get(self, idx: int): 88 | with open(self.raw_paths[idx], 'rb') as handle: 89 | data = pickle.load(handle) 90 | data = self.token_processor.preprocess(data) 91 | return data 92 | -------------------------------------------------------------------------------- /smart/layers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from smart.layers.attention_layer import AttentionLayer 3 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding 4 | from 
smart.layers.mlp_layer import MLPLayer 5 | -------------------------------------------------------------------------------- /smart/layers/attention_layer.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch_geometric.nn.conv import MessagePassing 7 | from torch_geometric.utils import softmax 8 | 9 | from smart.utils import weight_init 10 | 11 | 12 | class AttentionLayer(MessagePassing): 13 | 14 | def __init__(self, 15 | hidden_dim: int, 16 | num_heads: int, 17 | head_dim: int, 18 | dropout: float, 19 | bipartite: bool, 20 | has_pos_emb: bool, 21 | **kwargs) -> None: 22 | super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs) 23 | self.num_heads = num_heads 24 | self.head_dim = head_dim 25 | self.has_pos_emb = has_pos_emb 26 | self.scale = head_dim ** -0.5 27 | 28 | self.to_q = nn.Linear(hidden_dim, head_dim * num_heads) 29 | self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) 30 | self.to_v = nn.Linear(hidden_dim, head_dim * num_heads) 31 | if has_pos_emb: 32 | self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False) 33 | self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads) 34 | self.to_s = nn.Linear(hidden_dim, head_dim * num_heads) 35 | self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads) 36 | self.to_out = nn.Linear(head_dim * num_heads, hidden_dim) 37 | self.attn_drop = nn.Dropout(dropout) 38 | self.ff_mlp = nn.Sequential( 39 | nn.Linear(hidden_dim, hidden_dim * 4), 40 | nn.ReLU(inplace=True), 41 | nn.Dropout(dropout), 42 | nn.Linear(hidden_dim * 4, hidden_dim), 43 | ) 44 | if bipartite: 45 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) 46 | self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim) 47 | else: 48 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim) 49 | self.attn_prenorm_x_dst = self.attn_prenorm_x_src 50 | if has_pos_emb: 51 | self.attn_prenorm_r = nn.LayerNorm(hidden_dim) 52 | self.attn_postnorm = nn.LayerNorm(hidden_dim) 53 | self.ff_prenorm = nn.LayerNorm(hidden_dim) 54 | self.ff_postnorm = nn.LayerNorm(hidden_dim) 55 | self.apply(weight_init) 56 | 57 | def forward(self, 58 | x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], 59 | r: Optional[torch.Tensor], 60 | edge_index: torch.Tensor) -> torch.Tensor: 61 | if isinstance(x, torch.Tensor): 62 | x_src = x_dst = self.attn_prenorm_x_src(x) 63 | else: 64 | x_src, x_dst = x 65 | x_src = self.attn_prenorm_x_src(x_src) 66 | x_dst = self.attn_prenorm_x_dst(x_dst) 67 | x = x[1] 68 | if self.has_pos_emb and r is not None: 69 | r = self.attn_prenorm_r(r) 70 | x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index)) 71 | x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x))) 72 | return x 73 | 74 | def message(self, 75 | q_i: torch.Tensor, 76 | k_j: torch.Tensor, 77 | v_j: torch.Tensor, 78 | r: Optional[torch.Tensor], 79 | index: torch.Tensor, 80 | ptr: Optional[torch.Tensor]) -> torch.Tensor: 81 | if self.has_pos_emb and r is not None: 82 | k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim) 83 | v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim) 84 | sim = (q_i * k_j).sum(dim=-1) * self.scale 85 | attn = softmax(sim, index, ptr) 86 | self.attention_weight = attn.sum(-1).detach() 87 | attn = self.attn_drop(attn) 88 | return v_j * attn.unsqueeze(-1) 89 | 90 | def update(self, 91 | inputs: torch.Tensor, 92 | x_dst: torch.Tensor) -> torch.Tensor: 93 | 
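# aggregated messages arrive as (num_dst_nodes, num_heads, head_dim); flatten the heads before the gated residual update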
inputs = inputs.view(-1, self.num_heads * self.head_dim) 94 | g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1))) 95 | return inputs + g * (self.to_s(x_dst) - inputs) 96 | 97 | def _attn_block(self, 98 | x_src: torch.Tensor, 99 | x_dst: torch.Tensor, 100 | r: Optional[torch.Tensor], 101 | edge_index: torch.Tensor) -> torch.Tensor: 102 | q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim) 103 | k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim) 104 | v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim) 105 | agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r) 106 | return self.to_out(agg) 107 | 108 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor: 109 | return self.ff_mlp(x) 110 | -------------------------------------------------------------------------------- /smart/layers/fourier_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional 3 | import torch 4 | import torch.nn as nn 5 | 6 | from smart.utils import weight_init 7 | 8 | 9 | class FourierEmbedding(nn.Module): 10 | 11 | def __init__(self, 12 | input_dim: int, 13 | hidden_dim: int, 14 | num_freq_bands: int) -> None: 15 | super(FourierEmbedding, self).__init__() 16 | self.input_dim = input_dim 17 | self.hidden_dim = hidden_dim 18 | 19 | self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None 20 | self.mlps = nn.ModuleList( 21 | [nn.Sequential( 22 | nn.Linear(num_freq_bands * 2 + 1, hidden_dim), 23 | nn.LayerNorm(hidden_dim), 24 | nn.ReLU(inplace=True), 25 | nn.Linear(hidden_dim, hidden_dim), 26 | ) 27 | for _ in range(input_dim)]) 28 | self.to_out = nn.Sequential( 29 | nn.LayerNorm(hidden_dim), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(hidden_dim, hidden_dim), 32 | ) 33 | self.apply(weight_init) 34 | 35 | def forward(self, 36 | continuous_inputs: Optional[torch.Tensor] = None, 37 | categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: 38 | if continuous_inputs is None: 39 | if categorical_embs is not None: 40 | x = torch.stack(categorical_embs).sum(dim=0) 41 | else: 42 | raise ValueError('Both continuous_inputs and categorical_embs are None') 43 | else: 44 | x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi 45 | # Warning: if your data are noisy, don't use learnable sinusoidal embedding 46 | x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) 47 | continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim 48 | for i in range(self.input_dim): 49 | continuous_embs[i] = self.mlps[i](x[:, i]) 50 | x = torch.stack(continuous_embs).sum(dim=0) 51 | if categorical_embs is not None: 52 | x = x + torch.stack(categorical_embs).sum(dim=0) 53 | return self.to_out(x) 54 | 55 | 56 | class MLPEmbedding(nn.Module): 57 | def __init__(self, 58 | input_dim: int, 59 | hidden_dim: int) -> None: 60 | super(MLPEmbedding, self).__init__() 61 | self.input_dim = input_dim 62 | self.hidden_dim = hidden_dim 63 | self.mlp = nn.Sequential( 64 | nn.Linear(input_dim, 128), 65 | nn.LayerNorm(128), 66 | nn.ReLU(inplace=True), 67 | nn.Linear(128, hidden_dim), 68 | nn.LayerNorm(hidden_dim), 69 | nn.ReLU(inplace=True), 70 | nn.Linear(hidden_dim, hidden_dim)) 71 | self.apply(weight_init) 72 | 73 | def forward(self, 74 | continuous_inputs: Optional[torch.Tensor] = None, 75 | categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: 76 | if continuous_inputs is None: 77 | if categorical_embs is not 
None: 78 | x = torch.stack(categorical_embs).sum(dim=0) 79 | else: 80 | raise ValueError('Both continuous_inputs and categorical_embs are None') 81 | else: 82 | x = self.mlp(continuous_inputs) 83 | if categorical_embs is not None: 84 | x = x + torch.stack(categorical_embs).sum(dim=0) 85 | return x 86 | -------------------------------------------------------------------------------- /smart/layers/mlp_layer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from smart.utils import weight_init 6 | 7 | 8 | class MLPLayer(nn.Module): 9 | 10 | def __init__(self, 11 | input_dim: int, 12 | hidden_dim: int, 13 | output_dim: int) -> None: 14 | super(MLPLayer, self).__init__() 15 | self.mlp = nn.Sequential( 16 | nn.Linear(input_dim, hidden_dim), 17 | nn.LayerNorm(hidden_dim), 18 | nn.ReLU(inplace=True), 19 | nn.Linear(hidden_dim, output_dim), 20 | ) 21 | self.apply(weight_init) 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | return self.mlp(x) 25 | -------------------------------------------------------------------------------- /smart/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from smart.metrics.average_meter import AverageMeter 3 | from smart.metrics.min_ade import minADE 4 | from smart.metrics.min_fde import minFDE 5 | from smart.metrics.next_token_cls import TokenCls 6 | -------------------------------------------------------------------------------- /smart/metrics/average_meter.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torchmetrics import Metric 4 | 5 | 6 | class AverageMeter(Metric): 7 | 8 | def __init__(self, **kwargs) -> None: 9 | super(AverageMeter, self).__init__(**kwargs) 10 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 11 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 12 | 13 | def update(self, val: torch.Tensor) -> None: 14 | self.sum += val.sum() 15 | self.count += val.numel() 16 | 17 | def compute(self) -> torch.Tensor: 18 | return self.sum / self.count 19 | -------------------------------------------------------------------------------- /smart/metrics/min_ade.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional 3 | 4 | import torch 5 | from torchmetrics import Metric 6 | 7 | from smart.metrics.utils import topk 8 | from smart.metrics.utils import valid_filter 9 | 10 | 11 | class minMultiADE(Metric): 12 | 13 | def __init__(self, 14 | max_guesses: int = 6, 15 | **kwargs) -> None: 16 | super(minMultiADE, self).__init__(**kwargs) 17 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 18 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 19 | self.max_guesses = max_guesses 20 | 21 | def update(self, 22 | pred: torch.Tensor, 23 | target: torch.Tensor, 24 | prob: Optional[torch.Tensor] = None, 25 | valid_mask: Optional[torch.Tensor] = None, 26 | keep_invalid_final_step: bool = True, 27 | min_criterion: str = 'FDE') -> None: 28 | pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) 29 | pred_topk, _ = topk(self.max_guesses, pred, prob) 30 | if min_criterion == 'FDE': 31 | inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) 32 | inds_best = torch.norm( 33 | 
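# FDE criterion: per agent, pick the mode whose endpoint at the last valid step is closest to the target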
pred_topk[torch.arange(pred.size(0)), :, inds_last] - 34 | target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) 35 | self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * 36 | valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() 37 | elif min_criterion == 'ADE': 38 | self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) * 39 | valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() 40 | else: 41 | raise ValueError('{} is not a valid criterion'.format(min_criterion)) 42 | self.count += pred.size(0) 43 | 44 | def compute(self) -> torch.Tensor: 45 | return self.sum / self.count 46 | 47 | 48 | class minADE(Metric): 49 | 50 | def __init__(self, 51 | max_guesses: int = 6, 52 | **kwargs) -> None: 53 | super(minADE, self).__init__(**kwargs) 54 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 55 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 56 | self.max_guesses = max_guesses 57 | self.eval_timestep = 70 58 | 59 | def update(self, 60 | pred: torch.Tensor, 61 | target: torch.Tensor, 62 | prob: Optional[torch.Tensor] = None, 63 | valid_mask: Optional[torch.Tensor] = None, 64 | keep_invalid_final_step: bool = True, 65 | min_criterion: str = 'ADE') -> None: 66 | # pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) 67 | # pred_topk, _ = topk(self.max_guesses, pred, prob) 68 | # if min_criterion == 'FDE': 69 | # inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) 70 | # inds_best = torch.norm( 71 | # pred[torch.arange(pred.size(0)), :, inds_last] - 72 | # target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) 73 | # self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * 74 | # valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() 75 | # elif min_criterion == 'ADE': 76 | # self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) * 77 | # valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() 78 | # else: 79 | # raise ValueError('{} is not a valid criterion'.format(min_criterion)) 80 | eval_timestep = min(self.eval_timestep, pred.shape[1]) 81 | self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum() 82 | self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum() 83 | 84 | def compute(self) -> torch.Tensor: 85 | return self.sum / self.count 86 | -------------------------------------------------------------------------------- /smart/metrics/min_fde.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from smart.metrics.utils import topk 7 | from smart.metrics.utils import valid_filter 8 | 9 | 10 | class minMultiFDE(Metric): 11 | 12 | def __init__(self, 13 | max_guesses: int = 6, 14 | **kwargs) -> None: 15 | super(minMultiFDE, self).__init__(**kwargs) 16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 18 | self.max_guesses = max_guesses 19 | 20 | def update(self, 21 | pred: torch.Tensor, 22 | target: torch.Tensor, 23 | prob: Optional[torch.Tensor] = None, 24 | valid_mask: 
Optional[torch.Tensor] = None, 25 | keep_invalid_final_step: bool = True) -> None: 26 | pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) 27 | pred_topk, _ = topk(self.max_guesses, pred, prob) 28 | inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) 29 | self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] - 30 | target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), 31 | p=2, dim=-1).min(dim=-1)[0].sum() 32 | self.count += pred.size(0) 33 | 34 | def compute(self) -> torch.Tensor: 35 | return self.sum / self.count 36 | 37 | 38 | class minFDE(Metric): 39 | 40 | def __init__(self, 41 | max_guesses: int = 6, 42 | **kwargs) -> None: 43 | super(minFDE, self).__init__(**kwargs) 44 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 45 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 46 | self.max_guesses = max_guesses 47 | self.eval_timestep = 70 48 | 49 | def update(self, 50 | pred: torch.Tensor, 51 | target: torch.Tensor, 52 | prob: Optional[torch.Tensor] = None, 53 | valid_mask: Optional[torch.Tensor] = None, 54 | keep_invalid_final_step: bool = True) -> None: 55 | eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1 56 | self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) * 57 | valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum() 58 | self.count += valid_mask[:, eval_timestep-1].sum() 59 | 60 | def compute(self) -> torch.Tensor: 61 | return self.sum / self.count 62 | -------------------------------------------------------------------------------- /smart/metrics/next_token_cls.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from smart.metrics.utils import topk 7 | from smart.metrics.utils import valid_filter 8 | 9 | 10 | class TokenCls(Metric): 11 | 12 | def __init__(self, 13 | max_guesses: int = 6, 14 | **kwargs) -> None: 15 | super(TokenCls, self).__init__(**kwargs) 16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 18 | self.max_guesses = max_guesses 19 | 20 | def update(self, 21 | pred: torch.Tensor, 22 | target: torch.Tensor, 23 | valid_mask: Optional[torch.Tensor] = None) -> None: 24 | target = target[..., None] 25 | acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask 26 | self.sum += acc.sum() 27 | self.count += valid_mask.sum() 28 | 29 | def compute(self) -> torch.Tensor: 30 | return self.sum / self.count 31 | -------------------------------------------------------------------------------- /smart/metrics/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch_scatter import gather_csr 5 | from torch_scatter import segment_csr 6 | 7 | 8 | def topk( 9 | max_guesses: int, 10 | pred: torch.Tensor, 11 | prob: Optional[torch.Tensor] = None, 12 | ptr: Optional[torch.Tensor] = None, 13 | joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: 14 | max_guesses = min(max_guesses, pred.size(1)) 15 | if max_guesses == pred.size(1): 16 | if prob is not None: 17 | prob = prob / prob.sum(dim=-1, keepdim=True) 18 | else: 19 | prob = pred.new_ones((pred.size(0), max_guesses)) 
/ max_guesses 20 | return pred, prob 21 | else: 22 | if prob is not None: 23 | if joint: 24 | if ptr is None: 25 | inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), 26 | k=max_guesses, dim=-1, largest=True, sorted=True)[1] 27 | inds_topk = inds_topk.repeat(pred.size(0), 1) 28 | else: 29 | inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, 30 | reduce='mean'), 31 | k=max_guesses, dim=-1, largest=True, sorted=True)[1] 32 | inds_topk = gather_csr(src=inds_topk, indptr=ptr) 33 | else: 34 | inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] 35 | pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] 36 | prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] 37 | prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) 38 | else: 39 | pred_topk = pred[:, :max_guesses] 40 | prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses 41 | return pred_topk, prob_topk 42 | 43 | 44 | def topkind( 45 | max_guesses: int, 46 | pred: torch.Tensor, 47 | prob: Optional[torch.Tensor] = None, 48 | ptr: Optional[torch.Tensor] = None, 49 | joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 50 | max_guesses = min(max_guesses, pred.size(1)) 51 | if max_guesses == pred.size(1): 52 | if prob is not None: 53 | prob = prob / prob.sum(dim=-1, keepdim=True) 54 | else: 55 | prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses 56 | return pred, prob, None 57 | else: 58 | if prob is not None: 59 | if joint: 60 | if ptr is None: 61 | inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), 62 | k=max_guesses, dim=-1, largest=True, sorted=True)[1] 63 | inds_topk = inds_topk.repeat(pred.size(0), 1) 64 | else: 65 | inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, 66 | reduce='mean'), 67 | k=max_guesses, dim=-1, largest=True, sorted=True)[1] 68 | inds_topk = gather_csr(src=inds_topk, indptr=ptr) 69 | else: 70 | inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] 71 | pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] 72 | prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] 73 | prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) 74 | else: 75 | pred_topk = pred[:, :max_guesses] 76 | prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses 77 | return pred_topk, prob_topk, inds_topk 78 | 79 | 80 | def valid_filter( 81 | pred: torch.Tensor, 82 | target: torch.Tensor, 83 | prob: Optional[torch.Tensor] = None, 84 | valid_mask: Optional[torch.Tensor] = None, 85 | ptr: Optional[torch.Tensor] = None, 86 | keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], 87 | torch.Tensor, torch.Tensor]: 88 | if valid_mask is None: 89 | valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool) 90 | if keep_invalid_final_step: 91 | filter_mask = valid_mask.any(dim=-1) 92 | else: 93 | filter_mask = valid_mask[:, -1] 94 | pred = pred[filter_mask] 95 | target = target[filter_mask] 96 | if prob is not None: 97 | prob = prob[filter_mask] 98 | valid_mask = valid_mask[filter_mask] 99 | if ptr is not None: 100 | num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum') 101 | ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,)) 102 | 
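# rebuild the CSR ptr over the kept nodes: ptr[i + 1] accumulates the node counts of batches 0..i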
torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:]) 103 | else: 104 | ptr = target.new_tensor([0, target.size(0)]) 105 | return pred, target, prob, valid_mask, ptr 106 | 107 | 108 | def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6): 109 | """ 110 | 111 | Args: 112 | pred_trajs (batch_size, num_modes, num_timestamps, 7) 113 | pred_scores (batch_size, num_modes): 114 | dist_thresh (float): 115 | num_ret_modes (int, optional): Defaults to 6. 116 | 117 | Returns: 118 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) 119 | ret_scores (batch_size, num_ret_modes) 120 | ret_idxs (batch_size, num_ret_modes) 121 | """ 122 | batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape 123 | pred_goals = pred_trajs[:, :, -1, :] 124 | dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1) 125 | nearby_neighbor = dist < dist_thresh 126 | pred_scores = nearby_neighbor.sum(dim=-1) / num_modes 127 | 128 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True) 129 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) 130 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] 131 | sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) 132 | sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) 133 | 134 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) 135 | point_cover_mask = (dist < dist_thresh) 136 | 137 | point_val = sorted_pred_scores.clone() # (batch_size, N) 138 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N) 139 | 140 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() 141 | ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) 142 | ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) 143 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs) 144 | 145 | for k in range(num_ret_modes): 146 | cur_idx = point_val.argmax(dim=-1) # (batch_size) 147 | ret_idxs[:, k] = cur_idx 148 | 149 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) 150 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N) 151 | point_val_selected[bs_idxs, cur_idx] = -1 152 | point_val += point_val_selected 153 | 154 | ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] 155 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] 156 | 157 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) 158 | 159 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs] 160 | return ret_trajs, ret_scores, ret_idxs 161 | 162 | 163 | def batch_nms(pred_trajs, pred_scores, 164 | dist_thresh, num_ret_modes=6, 165 | mode='static', speed=None): 166 | """ 167 | 168 | Args: 169 | pred_trajs (batch_size, num_modes, num_timestamps, 7) 170 | pred_scores (batch_size, num_modes): 171 | dist_thresh (float): 172 | num_ret_modes (int, optional): Defaults to 6. 
173 | 174 | Returns: 175 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) 176 | ret_scores (batch_size, num_ret_modes) 177 | ret_idxs (batch_size, num_ret_modes) 178 | """ 179 | batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape 180 | 181 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True) 182 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) 183 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] 184 | sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) 185 | sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) 186 | 187 | if mode == "speed": 188 | scale = torch.ones(batch_size).to(sorted_pred_goals.device) 189 | lon_dist_thresh = 4 * scale 190 | lat_dist_thresh = 0.5 * scale 191 | lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1) 192 | lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1) 193 | point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None]) 194 | else: 195 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) 196 | point_cover_mask = (dist < dist_thresh) 197 | 198 | point_val = sorted_pred_scores.clone() # (batch_size, N) 199 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N) 200 | 201 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() 202 | ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) 203 | ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) 204 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs) 205 | 206 | for k in range(num_ret_modes): 207 | cur_idx = point_val.argmax(dim=-1) # (batch_size) 208 | ret_idxs[:, k] = cur_idx 209 | 210 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) 211 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N) 212 | point_val_selected[bs_idxs, cur_idx] = -1 213 | point_val += point_val_selected 214 | 215 | ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] 216 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] 217 | 218 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) 219 | 220 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs] 221 | return ret_trajs, ret_scores, ret_idxs 222 | 223 | 224 | def batch_nms_token(pred_trajs, pred_scores, 225 | dist_thresh, num_ret_modes=6, 226 | mode='static', speed=None): 227 | """ 228 | Args: 229 | pred_trajs (batch_size, num_modes, num_timestamps, 7) 230 | pred_scores (batch_size, num_modes): 231 | dist_thresh (float): 232 | num_ret_modes (int, optional): Defaults to 6. 
233 | 234 | Returns: 235 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 5) 236 | ret_scores (batch_size, num_ret_modes) 237 | ret_idxs (batch_size, num_ret_modes) 238 | """ 239 | batch_size, num_modes, num_feat_dim = pred_trajs.shape 240 | 241 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True) 242 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) 243 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] 244 | sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) 245 | 246 | if mode == "nearby": 247 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) 248 | values, indices = torch.topk(dist, 5, dim=-1, largest=False) 249 | thresh_hold = values[..., -1] 250 | point_cover_mask = dist < thresh_hold[..., None] 251 | else: 252 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) 253 | point_cover_mask = (dist < dist_thresh) 254 | 255 | point_val = sorted_pred_scores.clone() # (batch_size, N) 256 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N) 257 | 258 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() 259 | ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim) 260 | ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes) 261 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs) 262 | 263 | for k in range(num_ret_modes): 264 | cur_idx = point_val.argmax(dim=-1) # (batch_size) 265 | ret_idxs[:, k] = cur_idx 266 | 267 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) 268 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N) 269 | point_val_selected[bs_idxs, cur_idx] = -1 270 | point_val += point_val_selected 271 | 272 | ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx] 273 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] 274 | 275 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes) 276 | 277 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs] 278 | return ret_goals, ret_scores, ret_idxs 279 | -------------------------------------------------------------------------------- /smart/model/__init__.py: -------------------------------------------------------------------------------- 1 | from smart.model.smart import SMART 2 | -------------------------------------------------------------------------------- /smart/model/smart.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import pytorch_lightning as pl 3 | import torch 4 | import torch.nn as nn 5 | from torch_geometric.data import Batch 6 | from torch_geometric.data import HeteroData 7 | from smart.metrics import minADE 8 | from smart.metrics import minFDE 9 | from smart.metrics import TokenCls 10 | from smart.modules import SMARTDecoder 11 | from torch.optim.lr_scheduler import LambdaLR 12 | import math 13 | import numpy as np 14 | import pickle 15 | from collections import defaultdict 16 | import os 17 | from waymo_open_dataset.protos import sim_agents_submission_pb2 18 | 19 | 20 | def cal_polygon_contour(x, y, theta, width, length): 21 | left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) 22 | left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) 23 | left_front = (left_front_x, left_front_y) 24 | 25 | right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * 
math.sin(theta) 26 | right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) 27 | right_front = (right_front_x, right_front_y) 28 | 29 | right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) 30 | right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) 31 | right_back = (right_back_x, right_back_y) 32 | 33 | left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) 34 | left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) 35 | left_back = (left_back_x, left_back_y) 36 | polygon_contour = [left_front, right_front, right_back, left_back] 37 | 38 | return polygon_contour 39 | 40 | 41 | def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene: 42 | states = states.numpy() 43 | simulated_trajectories = [] 44 | for i_object in range(len(object_ids)): 45 | simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory( 46 | center_x=states[i_object, :, 0], center_y=states[i_object, :, 1], 47 | center_z=states[i_object, :, 2], heading=states[i_object, :, 3], 48 | object_id=object_ids[i_object].item() 49 | )) 50 | return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories) 51 | 52 | 53 | class SMART(pl.LightningModule): 54 | 55 | def __init__(self, model_config) -> None: 56 | super(SMART, self).__init__() 57 | self.save_hyperparameters() 58 | self.model_config = model_config 59 | self.warmup_steps = model_config.warmup_steps 60 | self.lr = model_config.lr 61 | self.total_steps = model_config.total_steps 62 | self.dataset = model_config.dataset 63 | self.input_dim = model_config.input_dim 64 | self.hidden_dim = model_config.hidden_dim 65 | self.output_dim = model_config.output_dim 66 | self.output_head = model_config.output_head 67 | self.num_historical_steps = model_config.num_historical_steps 68 | self.num_future_steps = model_config.decoder.num_future_steps 69 | self.num_freq_bands = model_config.num_freq_bands 70 | self.vis_map = False 71 | self.noise = True 72 | module_dir = os.path.dirname(os.path.dirname(__file__)) 73 | self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl') 74 | self.init_map_token() 75 | self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl') 76 | token_data = self.get_trajectory_token() 77 | self.encoder = SMARTDecoder( 78 | dataset=model_config.dataset, 79 | input_dim=model_config.input_dim, 80 | hidden_dim=model_config.hidden_dim, 81 | num_historical_steps=model_config.num_historical_steps, 82 | num_freq_bands=model_config.num_freq_bands, 83 | num_heads=model_config.num_heads, 84 | head_dim=model_config.head_dim, 85 | dropout=model_config.dropout, 86 | num_map_layers=model_config.decoder.num_map_layers, 87 | num_agent_layers=model_config.decoder.num_agent_layers, 88 | pl2pl_radius=model_config.decoder.pl2pl_radius, 89 | pl2a_radius=model_config.decoder.pl2a_radius, 90 | a2a_radius=model_config.decoder.a2a_radius, 91 | time_span=model_config.decoder.time_span, 92 | map_token={'traj_src': self.map_token['traj_src']}, 93 | token_data=token_data, 94 | token_size=model_config.decoder.token_size 95 | ) 96 | self.minADE = minADE(max_guesses=1) 97 | self.minFDE = minFDE(max_guesses=1) 98 | self.TokenCls = TokenCls(max_guesses=1) 99 | 100 | self.test_predictions = dict() 101 | self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1) 102 | self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1) 103 | self.inference_token = 
False 104 | self.rollout_num = 1 105 | 106 | def get_trajectory_token(self): 107 | token_data = pickle.load(open(self.token_path, 'rb')) 108 | self.trajectory_token = token_data['token'] 109 | self.trajectory_token_traj = token_data['traj'] 110 | self.trajectory_token_all = token_data['token_all'] 111 | return token_data 112 | 113 | def init_map_token(self): 114 | self.argmin_sample_len = 3 115 | map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb')) 116 | self.map_token = {'traj_src': map_token_traj['traj_src'], } 117 | traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1], 118 | self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0]) 119 | indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long() 120 | self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float) 121 | self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float) 122 | self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float) 123 | 124 | def forward(self, data: HeteroData): 125 | res = self.encoder(data) 126 | return res 127 | 128 | def inference(self, data: HeteroData): 129 | res = self.encoder.inference(data) 130 | return res 131 | 132 | def maybe_autocast(self, dtype=torch.float16): 133 | enable_autocast = self.device != torch.device("cpu") 134 | 135 | if enable_autocast: 136 | return torch.cuda.amp.autocast(dtype=dtype) 137 | else: 138 | return contextlib.nullcontext() 139 | 140 | def training_step(self, 141 | data, 142 | batch_idx): 143 | data = self.match_token_map(data) 144 | data = self.sample_pt_pred(data) 145 | if isinstance(data, Batch): 146 | data['agent']['av_index'] += data['agent']['ptr'][:-1] 147 | pred = self(data) 148 | next_token_prob = pred['next_token_prob'] 149 | next_token_idx_gt = pred['next_token_idx_gt'] 150 | next_token_eval_mask = pred['next_token_eval_mask'] 151 | cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask]) 152 | loss = cls_loss 153 | self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) 154 | self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1) 155 | return loss 156 | 157 | def validation_step(self, 158 | data, 159 | batch_idx): 160 | data = self.match_token_map(data) 161 | data = self.sample_pt_pred(data) 162 | if isinstance(data, Batch): 163 | data['agent']['av_index'] += data['agent']['ptr'][:-1] 164 | pred = self(data) 165 | next_token_idx = pred['next_token_idx'] 166 | next_token_idx_gt = pred['next_token_idx_gt'] 167 | next_token_eval_mask = pred['next_token_eval_mask'] 168 | next_token_prob = pred['next_token_prob'] 169 | cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask]) 170 | loss = cls_loss 171 | self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask], 172 | valid_mask=next_token_eval_mask[next_token_eval_mask]) 173 | self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True) 174 | self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True) 175 | 176 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] # * (data['agent']['category'] == 3) 177 | if self.inference_token: 178 | pred = self.inference(data) 179 | pos_a = 
pred['pos_a'] 180 | gt = pred['gt'] 181 | valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:] 182 | pred_traj = pred['pred_traj'] 183 | # next_token_idx = pred['next_token_idx'][..., None] 184 | # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:] 185 | # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:] 186 | # next_token_eval_mask[:, 1:] = False 187 | # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask], 188 | # valid_mask=next_token_eval_mask[next_token_eval_mask]) 189 | # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True) 190 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] 191 | 192 | self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask]) 193 | self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask]) 194 | # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute()) 195 | 196 | self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1) 197 | self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1) 198 | 199 | def on_validation_start(self): 200 | self.gt = [] 201 | self.pred = [] 202 | self.scenario_rollouts = [] 203 | self.batch_metric = defaultdict(list) 204 | 205 | def configure_optimizers(self): 206 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 207 | 208 | def lr_lambda(current_step): 209 | if current_step + 1 < self.warmup_steps: 210 | return float(current_step + 1) / float(max(1, self.warmup_steps)) 211 | return max( 212 | 0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps)))) 213 | ) 214 | 215 | lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda) 216 | return [optimizer], [lr_scheduler] 217 | 218 | def load_params_from_file(self, filename, logger, to_cpu=False): 219 | if not os.path.isfile(filename): 220 | raise FileNotFoundError 221 | 222 | logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU')) 223 | loc_type = torch.device('cpu') if to_cpu else None 224 | checkpoint = torch.load(filename, map_location=loc_type) 225 | model_state_disk = checkpoint['state_dict'] 226 | 227 | version = checkpoint.get("version", None) 228 | if version is not None: 229 | logger.info('==> Checkpoint trained from version: %s' % version) 230 | 231 | logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}') 232 | model_state = self.state_dict() 233 | model_state_disk_filter = {} 234 | for key, val in model_state_disk.items(): 235 | if key in model_state and model_state_disk[key].shape == model_state[key].shape: 236 | model_state_disk_filter[key] = val 237 | else: 238 | if key not in model_state: 239 | print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}') 240 | else: 241 | print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}') 242 | 243 | model_state_disk = model_state_disk_filter 244 | 245 | missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False) 246 | 247 | logger.info(f'Missing keys: {missing_keys}') 248 | logger.info(f'The number of missing keys: {len(missing_keys)}') 249 | logger.info(f'The number of unexpected keys: {len(unexpected_keys)}') 250 | logger.info('==> Done 
(total keys %d)' % (len(model_state))) 251 | 252 | epoch = checkpoint.get('epoch', -1) 253 | it = checkpoint.get('it', 0.0) 254 | 255 | return it, epoch 256 | 257 | def match_token_map(self, data): 258 | traj_pos = data['map_save']['traj_pos'].to(torch.float) 259 | traj_theta = data['map_save']['traj_theta'].to(torch.float) 260 | pl_idx_list = data['map_save']['pl_idx_list'] 261 | token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device) 262 | token_src = self.map_token['traj_src'].to(traj_pos.device) 263 | max_traj_len = self.map_token['traj_src'].shape[1] 264 | pl_num = traj_pos.shape[0] 265 | 266 | pt_token_pos = traj_pos[:, 0, :].clone() 267 | pt_token_orientation = traj_theta.clone() 268 | cos, sin = traj_theta.cos(), traj_theta.sin() 269 | rot_mat = traj_theta.new_zeros(pl_num, 2, 2) 270 | rot_mat[..., 0, 0] = cos 271 | rot_mat[..., 0, 1] = -sin 272 | rot_mat[..., 1, 0] = sin 273 | rot_mat[..., 1, 1] = cos 274 | traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2)) 275 | distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)) 276 | pt_token_id = torch.argmin(distance, dim=1) 277 | 278 | if self.noise: 279 | topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8] 280 | sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device) 281 | pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1) 282 | 283 | cos, sin = traj_theta.cos(), traj_theta.sin() 284 | rot_mat = traj_theta.new_zeros(pl_num, 2, 2) 285 | rot_mat[..., 0, 0] = cos 286 | rot_mat[..., 0, 1] = sin 287 | rot_mat[..., 1, 0] = -sin 288 | rot_mat[..., 1, 1] = cos 289 | token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2), 290 | rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :] 291 | token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2) 292 | 293 | pl_idx_full = pl_idx_list.clone() 294 | token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()]) 295 | count_nums = [] 296 | for pl in pl_idx_full.unique(): 297 | pt = token2pl[0, token2pl[1, :] == pl] 298 | left_side = (data['pt_token']['side'][pt] == 0).sum() 299 | right_side = (data['pt_token']['side'][pt] == 1).sum() 300 | center_side = (data['pt_token']['side'][pt] == 2).sum() 301 | count_nums.append(torch.Tensor([left_side, right_side, center_side])) 302 | count_nums = torch.stack(count_nums, dim=0) 303 | num_polyline = int(count_nums.max().item()) 304 | traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool) 305 | idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0) 306 | idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1) # 307 | counts_num_expanded = count_nums.unsqueeze(-1) 308 | mask_update = idx_matrix < counts_num_expanded 309 | traj_mask[mask_update] = True 310 | 311 | data['pt_token']['traj_mask'] = traj_mask 312 | data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1), 313 | device=traj_pos.device, dtype=torch.float)], dim=-1) 314 | data['pt_token']['orientation'] = pt_token_orientation 315 | data['pt_token']['height'] = data['pt_token']['position'][:, -1] 316 | data[('pt_token', 'to', 'map_polygon')] = 
{} 317 | data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl 318 | data['pt_token']['token_idx'] = pt_token_id 319 | return data 320 | 321 | def sample_pt_pred(self, data): 322 | traj_mask = data['pt_token']['traj_mask'] 323 | raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1) 324 | masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)] 325 | masked_pt_index = torch.sort(masked_pt_index, -1)[0] 326 | pt_valid_mask = traj_mask.clone() 327 | pt_valid_mask.scatter_(2, masked_pt_index, False) 328 | pt_pred_mask = traj_mask.clone() 329 | pt_pred_mask.scatter_(2, masked_pt_index, False) 330 | tmp_mask = pt_pred_mask.clone() 331 | tmp_mask[:, :, :] = True 332 | tmp_mask.scatter_(2, masked_pt_index-1, False) 333 | pt_pred_mask.masked_fill_(tmp_mask, False) 334 | pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2) 335 | pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2) 336 | 337 | data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask] 338 | data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask] 339 | data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask] 340 | 341 | return data 342 | -------------------------------------------------------------------------------- /smart/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from smart.modules.smart_decoder import SMARTDecoder 2 | from smart.modules.map_decoder import SMARTMapDecoder 3 | from smart.modules.agent_decoder import SMARTAgentDecoder 4 | -------------------------------------------------------------------------------- /smart/modules/agent_decoder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Dict, Mapping, Optional 3 | import torch 4 | import torch.nn as nn 5 | from smart.layers import MLPLayer 6 | from smart.layers.attention_layer import AttentionLayer 7 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding 8 | from torch_cluster import radius, radius_graph 9 | from torch_geometric.data import Batch, HeteroData 10 | from torch_geometric.utils import dense_to_sparse, subgraph 11 | from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle 12 | import math 13 | 14 | 15 | def cal_polygon_contour(x, y, theta, width, length): 16 | left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) 17 | left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) 18 | left_front = (left_front_x, left_front_y) 19 | 20 | right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) 21 | right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) 22 | right_front = (right_front_x, right_front_y) 23 | 24 | right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) 25 | right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) 26 | right_back = (right_back_x, right_back_y) 27 | 28 | left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) 29 | left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) 30 | left_back = (left_back_x, left_back_y) 31 | polygon_contour = [left_front, right_front, right_back, left_back] 32 | 33 | return 
polygon_contour 34 | 35 | 36 | class SMARTAgentDecoder(nn.Module): 37 | 38 | def __init__(self, 39 | dataset: str, 40 | input_dim: int, 41 | hidden_dim: int, 42 | num_historical_steps: int, 43 | time_span: Optional[int], 44 | pl2a_radius: float, 45 | a2a_radius: float, 46 | num_freq_bands: int, 47 | num_layers: int, 48 | num_heads: int, 49 | head_dim: int, 50 | dropout: float, 51 | token_data: Dict, 52 | token_size=512) -> None: 53 | super(SMARTAgentDecoder, self).__init__() 54 | self.dataset = dataset 55 | self.input_dim = input_dim 56 | self.hidden_dim = hidden_dim 57 | self.num_historical_steps = num_historical_steps 58 | self.time_span = time_span if time_span is not None else num_historical_steps 59 | self.pl2a_radius = pl2a_radius 60 | self.a2a_radius = a2a_radius 61 | self.num_freq_bands = num_freq_bands 62 | self.num_layers = num_layers 63 | self.num_heads = num_heads 64 | self.head_dim = head_dim 65 | self.dropout = dropout 66 | 67 | input_dim_x_a = 2 68 | input_dim_r_t = 4 69 | input_dim_r_pt2a = 3 70 | input_dim_r_a2a = 3 71 | input_dim_token = 8 72 | 73 | self.type_a_emb = nn.Embedding(4, hidden_dim) 74 | self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) 75 | 76 | self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) 77 | self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) 78 | self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, 79 | num_freq_bands=num_freq_bands) 80 | self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, 81 | num_freq_bands=num_freq_bands) 82 | self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) 83 | self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) 84 | self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) 85 | self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim) 86 | 87 | self.t_attn_layers = nn.ModuleList( 88 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, 89 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)] 90 | ) 91 | self.pt2a_attn_layers = nn.ModuleList( 92 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, 93 | bipartite=True, has_pos_emb=True) for _ in range(num_layers)] 94 | ) 95 | self.a2a_attn_layers = nn.ModuleList( 96 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, 97 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)] 98 | ) 99 | self.token_size = token_size 100 | self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, 101 | output_dim=self.token_size) 102 | self.trajectory_token = token_data['token'] 103 | self.trajectory_token_traj = token_data['traj'] 104 | self.trajectory_token_all = token_data['token_all'] 105 | self.apply(weight_init) 106 | self.shift = 5 107 | self.beam_size = 5 108 | self.hist_mask = True 109 | 110 | def transform_rel(self, token_traj, prev_pos, prev_heading=None): 111 | if prev_heading is None: 112 | diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :] 113 | prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) 114 | 115 | num_agent, num_step, traj_num, traj_dim = token_traj.shape 116 | cos, sin = prev_heading.cos(), prev_heading.sin() 117 | rot_mat = torch.zeros((num_agent, num_step, 2, 2), 
device=prev_heading.device) 118 | rot_mat[:, :, 0, 0] = cos 119 | rot_mat[:, :, 0, 1] = -sin 120 | rot_mat[:, :, 1, 0] = sin 121 | rot_mat[:, :, 1, 1] = cos 122 | agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim) 123 | agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :] 124 | return agent_pred_rel 125 | 126 | def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False): 127 | num_agent, num_step, traj_dim = pos_a.shape 128 | motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim), 129 | pos_a[:, 1:] - pos_a[:, :-1]], dim=1) 130 | 131 | agent_type = data['agent']['type'] 132 | veh_mask = (agent_type == 0) 133 | cyc_mask = (agent_type == 2) 134 | ped_mask = (agent_type == 1) 135 | trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float) 136 | self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1)) 137 | trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float) 138 | self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1)) 139 | trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float) 140 | self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1)) 141 | 142 | if inference: 143 | agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device) 144 | trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to( 145 | torch.float) 146 | trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to( 147 | torch.float) 148 | trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to( 149 | torch.float) 150 | agent_token_traj_all[veh_mask] = torch.cat( 151 | [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1) 152 | agent_token_traj_all[ped_mask] = torch.cat( 153 | [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1) 154 | agent_token_traj_all[cyc_mask] = torch.cat( 155 | [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1) 156 | 157 | agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device) 158 | agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]] 159 | agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]] 160 | agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]] 161 | 162 | agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device) 163 | agent_token_traj[veh_mask] = trajectory_token_veh 164 | agent_token_traj[ped_mask] = trajectory_token_ped 165 | agent_token_traj[cyc_mask] = trajectory_token_cyc 166 | 167 | vel = data['agent']['token_velocity'] 168 | 169 | categorical_embs = [ 170 | self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step, 171 | dim=0), 172 | 173 | self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave( 174 | repeats=num_step, 175 | dim=0) 176 | ] 177 | feature_a = torch.stack( 178 | 
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
179 | angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
180 | ], dim=-1)
181 | 
182 | x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
183 | categorical_embs=categorical_embs)
184 | x_a = x_a.view(-1, num_step, self.hidden_dim)
185 | 
186 | feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
187 | feat_a = self.fusion_emb(feat_a)
188 | 
189 | if inference:
190 | return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
191 | else:
192 | return feat_a, agent_token_traj
193 | 
194 | def agent_predict_next(self, data, agent_category, feat_a):
195 | num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
196 | agent_type = data['agent']['type']
197 | veh_mask = (agent_type == 0) # * agent_category==3
198 | cyc_mask = (agent_type == 2) # * agent_category==3
199 | ped_mask = (agent_type == 1) # * agent_category==3
200 | token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
201 | token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
202 | token_res[cyc_mask] = self.token_predict_head(feat_a[cyc_mask]) # __init__ defines only the shared token_predict_head; no cyclist-specific head exists
203 | token_res[ped_mask] = self.token_predict_head(feat_a[ped_mask]) # likewise for pedestrians
204 | return token_res
205 | 
206 | def agent_predict_next_inf(self, data, agent_category, feat_a):
207 | num_agent, traj_dim = feat_a.shape
208 | agent_type = data['agent']['type']
209 | 
210 | veh_mask = (agent_type == 0) # * agent_category==3
211 | cyc_mask = (agent_type == 2) # * agent_category==3
212 | ped_mask = (agent_type == 1) # * agent_category==3
213 | 
214 | token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
215 | token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
216 | token_res[cyc_mask] = self.token_predict_head(feat_a[cyc_mask]) # shared head, as above
217 | token_res[ped_mask] = self.token_predict_head(feat_a[ped_mask])
218 | 
219 | return token_res
220 | 
221 | def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
222 | pos_t = pos_a.reshape(-1, self.input_dim)
223 | head_t = head_a.reshape(-1)
224 | head_vector_t = head_vector_a.reshape(-1, 2)
225 | hist_mask = mask.clone()
226 | 
227 | if self.hist_mask and self.training:
228 | hist_mask[
229 | torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
230 | mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
231 | elif inference_mask is not None:
232 | mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
233 | else:
234 | mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
235 | 
236 | edge_index_t = dense_to_sparse(mask_t)[0]
237 | edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
238 | edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
239 | rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
240 | rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
241 | r_t = torch.stack(
242 | [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
243 | angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
244 | rel_head_t,
245 | edge_index_t[0] - edge_index_t[1]], dim=-1)
246 | r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
247 | return edge_index_t, r_t
248 | 
249 | def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
250 | pos_s = pos_a.transpose(0, 
1).reshape(-1, self.input_dim) 251 | head_s = head_a.transpose(0, 1).reshape(-1) 252 | head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) 253 | edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False, 254 | max_num_neighbors=300) 255 | edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0] 256 | rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]] 257 | rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]]) 258 | r_a2a = torch.stack( 259 | [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1), 260 | angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]), 261 | rel_head_a2a], dim=-1) 262 | r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None) 263 | return edge_index_a2a, r_a2a 264 | 265 | def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask, 266 | batch_s, batch_pl): 267 | mask_pl2a = mask.clone() 268 | mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1) 269 | pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim) 270 | head_s = head_a.transpose(0, 1).reshape(-1) 271 | head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2) 272 | pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous() 273 | orient_pl = data['pt_token']['orientation'].contiguous() 274 | pos_pl = pos_pl.repeat(num_step, 1) 275 | orient_pl = orient_pl.repeat(num_step) 276 | edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius, 277 | batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300) 278 | edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]] 279 | rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]] 280 | rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]]) 281 | r_pl2a = torch.stack( 282 | [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1), 283 | angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]), 284 | rel_orient_pl2a], dim=-1) 285 | r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None) 286 | return edge_index_pl2a, r_pl2a 287 | 288 | def forward(self, 289 | data: HeteroData, 290 | map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 291 | pos_a = data['agent']['token_pos'] 292 | head_a = data['agent']['token_heading'] 293 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) 294 | num_agent, num_step, traj_dim = pos_a.shape 295 | agent_category = data['agent']['category'] 296 | agent_token_index = data['agent']['token_idx'] 297 | feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index, 298 | pos_a, head_vector_a) 299 | 300 | agent_valid_mask = data['agent']['agent_valid_mask'].clone() 301 | # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1] 302 | # agent_valid_mask[~eval_mask] = False 303 | mask = agent_valid_mask 304 | edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask) 305 | 306 | if isinstance(data, Batch): 307 | batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t 308 | for t in range(num_step)], dim=0) 309 | batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t 310 | for t in range(num_step)], dim=0) 311 | else: 312 | batch_s = torch.arange(num_step, 313 | device=pos_a.device).repeat_interleave(data['agent']['num_nodes']) 314 | batch_pl = torch.arange(num_step, 315 | 
device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes']) 316 | 317 | mask_s = mask.transpose(0, 1).reshape(-1) 318 | edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s) 319 | mask[agent_category != 3] = False 320 | edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a, 321 | head_vector_a, mask, batch_s, batch_pl) 322 | 323 | for i in range(self.num_layers): 324 | feat_a = feat_a.reshape(-1, self.hidden_dim) 325 | feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) 326 | feat_a = feat_a.reshape(-1, num_step, 327 | self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) 328 | feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave( 329 | repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( 330 | -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) 331 | feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) 332 | feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) 333 | 334 | num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape 335 | next_token_prob = self.token_predict_head(feat_a) 336 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) 337 | _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) 338 | 339 | next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1) 340 | next_token_eval_mask = mask.clone() 341 | next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1) 342 | next_token_eval_mask[:, -1] = False 343 | 344 | return {'x_a': feat_a, 345 | 'next_token_idx': next_token_idx, 346 | 'next_token_prob': next_token_prob, 347 | 'next_token_idx_gt': next_token_index_gt, 348 | 'next_token_eval_mask': next_token_eval_mask, 349 | } 350 | 351 | def inference(self, 352 | data: HeteroData, 353 | map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 354 | eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1] 355 | pos_a = data['agent']['token_pos'].clone() 356 | head_a = data['agent']['token_heading'].clone() 357 | num_agent, num_step, traj_dim = pos_a.shape 358 | pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 359 | head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0 360 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) 361 | 362 | agent_valid_mask = data['agent']['agent_valid_mask'].clone() 363 | agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True 364 | agent_valid_mask[~eval_mask] = False 365 | agent_token_index = data['agent']['token_idx'] 366 | agent_category = data['agent']['category'] 367 | feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding( 368 | data, 369 | agent_category, 370 | agent_token_index, 371 | pos_a, 372 | head_vector_a, 373 | inference=True) 374 | 375 | agent_type = data["agent"]["type"] 376 | veh_mask = (agent_type == 0) # * agent_category==3 377 | cyc_mask = (agent_type == 2) # * agent_category==3 378 | ped_mask = (agent_type == 1) # * agent_category==3 379 | av_mask = data["agent"]["av_index"] 380 | 381 | self.num_recurrent_steps_val = data["agent"]['position'].shape[1]-self.num_historical_steps 382 | pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device) 383 | pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 
device=feat_a.device) 384 | pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device) 385 | next_token_idx_list = [] 386 | mask = agent_valid_mask.clone() 387 | feat_a_t_dict = {} 388 | for t in range(self.num_recurrent_steps_val // self.shift): 389 | if t == 0: 390 | inference_mask = mask.clone() 391 | inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False 392 | else: 393 | inference_mask = torch.zeros_like(mask) 394 | inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True 395 | edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask) 396 | if isinstance(data, Batch): 397 | batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t 398 | for t in range(num_step)], dim=0) 399 | batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t 400 | for t in range(num_step)], dim=0) 401 | else: 402 | batch_s = torch.arange(num_step, 403 | device=pos_a.device).repeat_interleave(data['agent']['num_nodes']) 404 | batch_pl = torch.arange(num_step, 405 | device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes']) 406 | # In the inference stage, we only infer the current stage for recurrent 407 | edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a, 408 | head_vector_a, 409 | inference_mask, batch_s, 410 | batch_pl) 411 | mask_s = inference_mask.transpose(0, 1).reshape(-1) 412 | edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, 413 | batch_s, mask_s) 414 | 415 | for i in range(self.num_layers): 416 | if i in feat_a_t_dict: 417 | feat_a = feat_a_t_dict[i] 418 | feat_a = feat_a.reshape(-1, self.hidden_dim) 419 | feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t) 420 | feat_a = feat_a.reshape(-1, num_step, 421 | self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) 422 | feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave( 423 | repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape( 424 | -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a) 425 | feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a) 426 | feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1) 427 | 428 | if i+1 not in feat_a_t_dict: 429 | feat_a_t_dict[i+1] = feat_a 430 | else: 431 | feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] 432 | 433 | next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]) 434 | 435 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) 436 | 437 | topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1) 438 | 439 | expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2) 440 | next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index) 441 | 442 | theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t] 443 | cos, sin = theta.cos(), theta.sin() 444 | rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device) 445 | rot_mat[:, 0, 0] = cos 446 | rot_mat[:, 0, 1] = sin 447 | rot_mat[:, 1, 0] = -sin 448 | rot_mat[:, 1, 1] = cos 449 | agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2), 450 | rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view( 451 | -1, 2, 2)).view(num_agent, self.beam_size, 
self.shift + 1, 4, 2) 452 | agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...] 453 | 454 | sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device) 455 | agent_pred_rel = agent_pred_rel.gather(dim=1, 456 | index=sample_index[..., None, None, None].expand(-1, -1, 6, 4, 457 | 2))[:, 0, ...] 458 | pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0] 459 | pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2) 460 | diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :] 461 | pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0]) 462 | 463 | pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1) 464 | diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :] 465 | theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0]) 466 | head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta 467 | next_token_idx = next_token_idx.gather(dim=1, index=sample_index) 468 | next_token_idx = next_token_idx.squeeze(-1) 469 | next_token_idx_list.append(next_token_idx[:, None]) 470 | agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[ 471 | next_token_idx[veh_mask]] 472 | agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[ 473 | next_token_idx[ped_mask]] 474 | agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[ 475 | next_token_idx[cyc_mask]] 476 | motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim), 477 | pos_a[:, 1:] - pos_a[:, :-1]], dim=1) 478 | 479 | head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1) 480 | 481 | vel = motion_vector_a.clone() / (0.1 * self.shift) 482 | vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0 483 | motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0 484 | x_a = torch.stack( 485 | [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1), 486 | angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1) 487 | 488 | x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)), 489 | categorical_embs=categorical_embs) 490 | x_a = x_a.view(-1, num_step, self.hidden_dim) 491 | 492 | feat_a = torch.cat((agent_token_emb, x_a), dim=-1) 493 | feat_a = self.fusion_emb(feat_a) 494 | 495 | agent_valid_mask[agent_category != 3] = False 496 | 497 | return { 498 | 'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:], 499 | 'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:], 500 | 'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(), 501 | 'valid_mask': agent_valid_mask[:, self.num_historical_steps:], 502 | 'pred_traj': pred_traj, 503 | 'pred_head': pred_head, 504 | 'next_token_idx': torch.cat(next_token_idx_list, dim=-1), 505 | 'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1), 506 | 'next_token_eval_mask': data['agent']['agent_valid_mask'], 507 | 'pred_prob': pred_prob, 508 | 'vel': vel 509 | } 510 | -------------------------------------------------------------------------------- /smart/modules/map_decoder.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Dict 3 | import torch 4 | import torch.nn as nn 5 | 
from torch_cluster import radius_graph 6 | from torch_geometric.data import Batch 7 | from torch_geometric.data import HeteroData 8 | from torch_geometric.utils import dense_to_sparse, subgraph 9 | from smart.utils.nan_checker import check_nan_inf 10 | from smart.layers.attention_layer import AttentionLayer 11 | from smart.layers import MLPLayer 12 | from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding 13 | from smart.utils import angle_between_2d_vectors 14 | from smart.utils import merge_edges 15 | from smart.utils import weight_init 16 | from smart.utils import wrap_angle 17 | import pickle 18 | 19 | 20 | class SMARTMapDecoder(nn.Module): 21 | 22 | def __init__(self, 23 | dataset: str, 24 | input_dim: int, 25 | hidden_dim: int, 26 | num_historical_steps: int, 27 | pl2pl_radius: float, 28 | num_freq_bands: int, 29 | num_layers: int, 30 | num_heads: int, 31 | head_dim: int, 32 | dropout: float, 33 | map_token) -> None: 34 | super(SMARTMapDecoder, self).__init__() 35 | self.dataset = dataset 36 | self.input_dim = input_dim 37 | self.hidden_dim = hidden_dim 38 | self.num_historical_steps = num_historical_steps 39 | self.pl2pl_radius = pl2pl_radius 40 | self.num_freq_bands = num_freq_bands 41 | self.num_layers = num_layers 42 | self.num_heads = num_heads 43 | self.head_dim = head_dim 44 | self.dropout = dropout 45 | 46 | if input_dim == 2: 47 | input_dim_r_pt2pt = 3 48 | elif input_dim == 3: 49 | input_dim_r_pt2pt = 4 50 | else: 51 | raise ValueError('{} is not a valid dimension'.format(input_dim)) 52 | 53 | self.type_pt_emb = nn.Embedding(17, hidden_dim) 54 | self.side_pt_emb = nn.Embedding(4, hidden_dim) 55 | self.polygon_type_emb = nn.Embedding(4, hidden_dim) 56 | self.light_pl_emb = nn.Embedding(4, hidden_dim) 57 | 58 | self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim, 59 | num_freq_bands=num_freq_bands) 60 | self.pt2pt_layers = nn.ModuleList( 61 | [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, 62 | bipartite=False, has_pos_emb=True) for _ in range(num_layers)] 63 | ) 64 | self.token_size = 1024 65 | self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, 66 | output_dim=self.token_size) 67 | input_dim_token = 22 68 | self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) 69 | self.map_token = map_token 70 | self.apply(weight_init) 71 | self.mask_pt = False 72 | 73 | def maybe_autocast(self, dtype=torch.float32): 74 | return torch.cuda.amp.autocast(dtype=dtype) 75 | 76 | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]: 77 | pt_valid_mask = data['pt_token']['pt_valid_mask'] 78 | pt_pred_mask = data['pt_token']['pt_pred_mask'] 79 | pt_target_mask = data['pt_token']['pt_target_mask'] 80 | mask_s = pt_valid_mask 81 | 82 | pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous() 83 | orient_pt = data['pt_token']['orientation'].contiguous() 84 | orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1) 85 | token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float) 86 | pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1)) 87 | pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']] 88 | 89 | if self.input_dim == 2: 90 | x_pt = pt_token_emb 91 | elif self.input_dim == 3: 92 | x_pt = pt_token_emb 93 | else: 94 | raise ValueError('{} is not a valid dimension'.format(self.input_dim)) 95 | 96 | token2pl = data[('pt_token', 'to', 
'map_polygon')]['edge_index'] 97 | token_light_type = data['map_polygon']['light_type'][token2pl[1]] 98 | x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()), 99 | self.polygon_type_emb(data['pt_token']['pl_type'].long()), 100 | self.light_pl_emb(token_light_type.long()),] 101 | x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0) 102 | edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius, 103 | batch=data['pt_token']['batch'] if isinstance(data, Batch) else None, 104 | loop=False, max_num_neighbors=100) 105 | if self.mask_pt: 106 | edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0] 107 | rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]] 108 | rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]]) 109 | if self.input_dim == 2: 110 | r_pt2pt = torch.stack( 111 | [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), 112 | angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], 113 | nbr_vector=rel_pos_pt2pt[:, :2]), 114 | rel_orient_pt2pt], dim=-1) 115 | elif self.input_dim == 3: 116 | r_pt2pt = torch.stack( 117 | [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1), 118 | angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]], 119 | nbr_vector=rel_pos_pt2pt[:, :2]), 120 | rel_pos_pt2pt[:, -1], 121 | rel_orient_pt2pt], dim=-1) 122 | else: 123 | raise ValueError('{} is not a valid dimension'.format(self.input_dim)) 124 | r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None) 125 | for i in range(self.num_layers): 126 | x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt) 127 | 128 | next_token_prob = self.token_predict_head(x_pt[pt_pred_mask]) 129 | next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1) 130 | _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1) 131 | next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask] 132 | 133 | return { 134 | 'x_pt': x_pt, 135 | 'map_next_token_idx': next_token_idx, 136 | 'map_next_token_prob': next_token_prob, 137 | 'map_next_token_idx_gt': next_token_index_gt, 138 | 'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask] 139 | } 140 | -------------------------------------------------------------------------------- /smart/modules/smart_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | from torch_geometric.data import HeteroData 5 | from smart.modules.agent_decoder import SMARTAgentDecoder 6 | from smart.modules.map_decoder import SMARTMapDecoder 7 | 8 | 9 | class SMARTDecoder(nn.Module): 10 | 11 | def __init__(self, 12 | dataset: str, 13 | input_dim: int, 14 | hidden_dim: int, 15 | num_historical_steps: int, 16 | pl2pl_radius: float, 17 | time_span: Optional[int], 18 | pl2a_radius: float, 19 | a2a_radius: float, 20 | num_freq_bands: int, 21 | num_map_layers: int, 22 | num_agent_layers: int, 23 | num_heads: int, 24 | head_dim: int, 25 | dropout: float, 26 | map_token: Dict, 27 | token_data: Dict, 28 | use_intention=False, 29 | token_size=512) -> None: 30 | super(SMARTDecoder, self).__init__() 31 | self.map_encoder = SMARTMapDecoder( 32 | dataset=dataset, 33 | input_dim=input_dim, 34 | hidden_dim=hidden_dim, 35 | num_historical_steps=num_historical_steps, 36 | pl2pl_radius=pl2pl_radius, 37 | num_freq_bands=num_freq_bands, 38 | num_layers=num_map_layers, 39 | num_heads=num_heads, 40 | head_dim=head_dim, 41 
| dropout=dropout, 42 | map_token=map_token 43 | ) 44 | self.agent_encoder = SMARTAgentDecoder( 45 | dataset=dataset, 46 | input_dim=input_dim, 47 | hidden_dim=hidden_dim, 48 | num_historical_steps=num_historical_steps, 49 | time_span=time_span, 50 | pl2a_radius=pl2a_radius, 51 | a2a_radius=a2a_radius, 52 | num_freq_bands=num_freq_bands, 53 | num_layers=num_agent_layers, 54 | num_heads=num_heads, 55 | head_dim=head_dim, 56 | dropout=dropout, 57 | token_size=token_size, 58 | token_data=token_data 59 | ) 60 | self.map_enc = None 61 | 62 | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]: 63 | map_enc = self.map_encoder(data) 64 | agent_enc = self.agent_encoder(data, map_enc) 65 | return {**map_enc, **agent_enc} 66 | 67 | def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]: 68 | map_enc = self.map_encoder(data) 69 | agent_enc = self.agent_encoder.inference(data, map_enc) 70 | return {**map_enc, **agent_enc} 71 | 72 | def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]: 73 | agent_enc = self.agent_encoder.inference(data, map_enc) 74 | return {**map_enc, **agent_enc} 75 | -------------------------------------------------------------------------------- /smart/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/preprocess/__init__.py -------------------------------------------------------------------------------- /smart/preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import torch 5 | from typing import Any, Dict, List, Optional 6 | 7 | predict_unseen_agents = False 8 | vector_repr = True 9 | _agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background'] 10 | _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] 11 | _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] 12 | _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', 13 | 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', 14 | 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', 15 | 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] 16 | _point_sides = ['LEFT', 'RIGHT', 'CENTER'] 17 | _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] 18 | _polygon_is_intersections = [True, False, None] 19 | 20 | 21 | Lane_type_hash = { 22 | 4: "BIKE", 23 | 3: "VEHICLE", 24 | 2: "VEHICLE", 25 | 1: "BUS" 26 | } 27 | 28 | boundary_type_hash = { 29 | 5: "UNKNOWN", 30 | 6: "DASHED_WHITE", 31 | 7: "SOLID_WHITE", 32 | 8: "DOUBLE_DASH_WHITE", 33 | 9: "DASHED_YELLOW", 34 | 10: "DOUBLE_DASH_YELLOW", 35 | 11: "SOLID_YELLOW", 36 | 12: "DOUBLE_SOLID_YELLOW", 37 | 13: "DASH_SOLID_YELLOW", 38 | 14: "UNKNOWN", 39 | 15: "EDGE", 40 | 16: "EDGE" 41 | } 42 | 43 | 44 | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]: 45 | if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps 46 | historical_df = df[df['timestep'] == num_historical_steps-1] 47 | agent_ids = list(historical_df['track_id'].unique()) 48 | df = df[df['track_id'].isin(agent_ids)] 49 | else: 50 | agent_ids = list(df['track_id'].unique()) 51 | 52 | num_agents = len(agent_ids) 53 | # initialization 54 | valid_mask 
= torch.zeros(num_agents, num_steps, dtype=torch.bool) 55 | current_valid_mask = torch.zeros(num_agents, dtype=torch.bool) 56 | predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) 57 | agent_id: List[Optional[str]] = [None] * num_agents 58 | agent_type = torch.zeros(num_agents, dtype=torch.uint8) 59 | agent_category = torch.zeros(num_agents, dtype=torch.uint8) 60 | position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) 61 | heading = torch.zeros(num_agents, num_steps, dtype=torch.float) 62 | velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) 63 | shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) 64 | 65 | for track_id, track_df in df.groupby('track_id'): 66 | agent_idx = agent_ids.index(track_id) 67 | agent_steps = track_df['timestep'].values 68 | 69 | valid_mask[agent_idx, agent_steps] = True 70 | current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] 71 | predict_mask[agent_idx, agent_steps] = True 72 | if vector_repr: # a time step t is valid only when both t and t-1 are valid 73 | valid_mask[agent_idx, 1: num_historical_steps] = ( 74 | valid_mask[agent_idx, :num_historical_steps - 1] & 75 | valid_mask[agent_idx, 1: num_historical_steps]) 76 | valid_mask[agent_idx, 0] = False 77 | predict_mask[agent_idx, :num_historical_steps] = False 78 | if not current_valid_mask[agent_idx]: 79 | predict_mask[agent_idx, num_historical_steps:] = False 80 | 81 | agent_id[agent_idx] = track_id 82 | agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0]) 83 | agent_category[agent_idx] = track_df['object_category'].values[0] 84 | position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values, 85 | track_df['position_y'].values, 86 | track_df['position_z'].values], 87 | axis=-1)).float() 88 | heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float() 89 | velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values, 90 | track_df['velocity_y'].values], 91 | axis=-1)).float() 92 | shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values, 93 | track_df['width'].values, 94 | track_df["height"].values], 95 | axis=-1)).float() 96 | av_idx = agent_id.index(av_id) 97 | 98 | return { 99 | 'num_nodes': num_agents, 100 | 'av_index': av_idx, 101 | 'valid_mask': valid_mask, 102 | 'predict_mask': predict_mask, 103 | 'id': agent_id, 104 | 'type': agent_type, 105 | 'category': agent_category, 106 | 'position': position, 107 | 'heading': heading, 108 | 'velocity': velocity, 109 | 'shape': shape 110 | } -------------------------------------------------------------------------------- /smart/tokens/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/__init__.py -------------------------------------------------------------------------------- /smart/tokens/cluster_frame_5_2048.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/cluster_frame_5_2048.pkl -------------------------------------------------------------------------------- /smart/tokens/map_traj_token5.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rainmaker22/SMART/aaf1213ebabd50bb9e280c82cbd78912650d5d0f/smart/tokens/map_traj_token5.pkl -------------------------------------------------------------------------------- /smart/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from smart.transforms.target_builder import WaymoTargetBuilder 2 | -------------------------------------------------------------------------------- /smart/transforms/target_builder.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch_geometric.data import HeteroData 5 | from torch_geometric.transforms import BaseTransform 6 | from smart.utils import wrap_angle 7 | from smart.utils.log import Logging 8 | 9 | 10 | def to_16(data): 11 | if isinstance(data, dict): 12 | for key, value in data.items(): 13 | new_value = to_16(value) 14 | data[key] = new_value 15 | if isinstance(data, torch.Tensor): 16 | if data.dtype == torch.float32: 17 | data = data.to(torch.float16) 18 | return data 19 | 20 | 21 | def tofloat32(data): 22 | for name in data: 23 | value = data[name] 24 | if isinstance(value, dict): 25 | value = tofloat32(value) 26 | elif isinstance(value, torch.Tensor) and value.dtype == torch.float64: 27 | value = value.to(torch.float32) 28 | data[name] = value 29 | return data 30 | 31 | 32 | class WaymoTargetBuilder(BaseTransform): 33 | 34 | def __init__(self, 35 | num_historical_steps: int, 36 | num_future_steps: int, 37 | mode="train") -> None: 38 | self.num_historical_steps = num_historical_steps 39 | self.num_future_steps = num_future_steps 40 | self.mode = mode 41 | self.num_features = 3 42 | self.augment = False 43 | self.logger = Logging().log(level='DEBUG') 44 | 45 | def score_ego_agent(self, agent): 46 | av_index = agent['av_index'] 47 | agent["category"][av_index] = 5 48 | return agent 49 | 50 | def clip(self, agent, max_num=32): 51 | av_index = agent["av_index"] 52 | valid = agent['valid_mask'] 53 | ego_pos = agent["position"][av_index] 54 | obstacle_mask = agent['type'] == 3 55 | distance = torch.norm(agent["position"][:, self.num_historical_steps-1, :2] - ego_pos[self.num_historical_steps-1, :2], dim=-1) # keep the closest 100 vehicles near the ego car 56 | distance[obstacle_mask] = 10e5 57 | sort_idx = distance.sort()[1] 58 | mask = torch.zeros(valid.shape[0]) 59 | mask[sort_idx[:max_num]] = 1 60 | mask = mask.to(torch.bool) 61 | mask[av_index] = True 62 | new_av_index = mask[:av_index].sum() 63 | agent["num_nodes"] = int(mask.sum()) 64 | agent["av_index"] = int(new_av_index) 65 | excluded = ["num_nodes", "av_index", "ego"] 66 | for key, val in agent.items(): 67 | if key in excluded: 68 | continue 69 | if key == "id": 70 | val = list(np.array(val)[mask]) 71 | agent[key] = val 72 | continue 73 | if len(val.size()) > 1: 74 | agent[key] = val[mask, ...] 
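# 1-D attributes (one entry per agent) fall through to plain boolean indexing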
75 | else:
76 | agent[key] = val[mask]
77 | return agent
78 | 
79 | def score_nearby_vehicle(self, agent, max_num=10):
80 | av_index = agent['av_index']
81 | agent["category"] = torch.zeros_like(agent["category"])
82 | obstacle_mask = agent['type'] == 3
83 | pos = agent["position"][av_index, self.num_historical_steps, :2]
84 | distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
85 | distance[obstacle_mask] = 10e5
86 | sort_idx = distance.sort()[1]
87 | nearby_mask = torch.zeros(distance.shape[0])
88 | nearby_mask[sort_idx[1:max_num]] = 1 # index 0 is the ego vehicle itself, so it is skipped
89 | nearby_mask = nearby_mask.bool()
90 | agent["category"][nearby_mask] = 3
91 | agent["category"][obstacle_mask] = 0
92 | 
93 | def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
94 | av_index = agent['av_index']
95 | agent["category"] = torch.zeros_like(agent["category"])
96 | pos = agent["position"][av_index, self.num_historical_steps, :2]
97 | distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
98 | distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
99 | within_range_mask = distance_all_time < 150 # perception is considered unreliable beyond 150 meters of the ego vehicle
100 | agent["valid_mask"] = agent["valid_mask"] * within_range_mask
101 | # do not predict vehicles too far away from the ego vehicle
102 | closest_vehicle = distance < 100
103 | valid = agent['valid_mask']
104 | valid_current = valid[:, (self.num_historical_steps):]
105 | valid_counts = valid_current.sum(1)
106 | counts_vehicle = valid_counts >= 1
107 | no_background = agent['type'] != 3
108 | vehicle2pred = closest_vehicle & counts_vehicle & no_background
109 | if vehicle2pred.sum() > max_num:
110 | # too many candidates (many of them stationary); randomly subsample down to max_num so moving vehicles are trained on as often as possible
111 | true_indices = torch.nonzero(vehicle2pred).squeeze(1)
112 | selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
113 | vehicle2pred.fill_(False)
114 | vehicle2pred[selected_indices] = True
115 | agent["category"][vehicle2pred] = 3
116 | 
117 | def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
118 | origin = position[:, num_historical_steps - 1]
119 | theta = heading[:, num_historical_steps - 1]
120 | cos, sin = theta.cos(), theta.sin()
121 | rot_mat = theta.new_zeros(num_nodes, 2, 2)
122 | rot_mat[:, 0, 0] = cos
123 | rot_mat[:, 0, 1] = -sin
124 | rot_mat[:, 1, 0] = sin
125 | rot_mat[:, 1, 1] = cos
126 | target = origin.new_zeros(num_nodes, num_future_steps, 4)
127 | target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
128 | origin[:, :2].unsqueeze(1), rot_mat)
129 | his = origin.new_zeros(num_nodes, num_historical_steps, 4)
130 | his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
131 | origin[:, :2].unsqueeze(1), rot_mat)
132 | if position.size(2) == 3:
133 | target[..., 2] = (position[:, num_historical_steps:, 2] -
134 | origin[:, 2].unsqueeze(-1))
135 | his[..., 2] = (position[:, :num_historical_steps, 2] -
136 | origin[:, 2].unsqueeze(-1))
137 | target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
138 | theta.unsqueeze(-1))
139 | his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
140 | theta.unsqueeze(-1))
141 | else:
142 | target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
143 | theta.unsqueeze(-1))
144 | his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
145 | theta.unsqueeze(-1))
146 | return his, target
147 | 
148 | def 
__call__(self, data) -> HeteroData: 149 | agent = data["agent"] 150 | self.score_ego_agent(agent) 151 | self.score_trained_vehicle(agent, max_num=32) 152 | return HeteroData(data) 153 | -------------------------------------------------------------------------------- /smart/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from smart.utils.geometry import angle_between_2d_vectors 3 | from smart.utils.geometry import angle_between_3d_vectors 4 | from smart.utils.geometry import side_to_directed_lineseg 5 | from smart.utils.geometry import wrap_angle 6 | from smart.utils.graph import add_edges 7 | from smart.utils.graph import bipartite_dense_to_sparse 8 | from smart.utils.graph import complete_graph 9 | from smart.utils.graph import merge_edges 10 | from smart.utils.graph import unbatch 11 | from smart.utils.list import safe_list_index 12 | from smart.utils.weight_init import weight_init 13 | -------------------------------------------------------------------------------- /smart/utils/cluster_reader.py: -------------------------------------------------------------------------------- 1 | import io 2 | import pickle 3 | import pandas as pd 4 | import json 5 | 6 | 7 | class LoadScenarioFromCeph: 8 | def __init__(self): 9 | from petrel_client.client import Client 10 | self.file_client = Client('~/petreloss.conf') 11 | 12 | def list(self, dir_path): 13 | return list(self.file_client.list(dir_path)) 14 | 15 | def save(self, data, url): 16 | self.file_client.put(url, pickle.dumps(data)) 17 | 18 | def read_correct_csv(self, scenario_path): 19 | output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python") 20 | return output 21 | 22 | def contains(self, url): 23 | return self.file_client.contains(url) 24 | 25 | def read_string(self, csv_url): 26 | from io import StringIO 27 | df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep='\s+', low_memory=False) 28 | return df 29 | 30 | def read(self, scenario_path): 31 | with io.BytesIO(self.file_client.get(scenario_path)) as f: 32 | datas = pickle.load(f) 33 | return datas 34 | 35 | def read_json(self, path): 36 | with io.BytesIO(self.file_client.get(path)) as f: 37 | data = json.load(f) 38 | return data 39 | 40 | def read_csv(self, scenario_path): 41 | return pickle.loads(self.file_client.get(scenario_path)) 42 | 43 | def read_model(self, model_path): 44 | with io.BytesIO(self.file_client.get(model_path)) as f: 45 | pass 46 | -------------------------------------------------------------------------------- /smart/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import easydict 4 | 5 | 6 | def load_config_act(path): 7 | """ load config file""" 8 | with open(path, 'r') as f: 9 | cfg = yaml.load(f, Loader=yaml.FullLoader) 10 | return easydict.EasyDict(cfg) 11 | 12 | 13 | def load_config_init(path): 14 | """ load config file""" 15 | path = os.path.join('init/configs', f'{path}.yaml') 16 | with open(path, 'r') as f: 17 | cfg = yaml.load(f, Loader=yaml.FullLoader) 18 | return cfg 19 | -------------------------------------------------------------------------------- /smart/utils/geometry.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | 4 | import torch 5 | 6 | 7 | def angle_between_2d_vectors( 8 | ctr_vector: torch.Tensor, 9 | nbr_vector: torch.Tensor) -> torch.Tensor: 10 | return 
torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0], 11 | (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1)) 12 | 13 | 14 | def angle_between_3d_vectors( 15 | ctr_vector: torch.Tensor, 16 | nbr_vector: torch.Tensor) -> torch.Tensor: 17 | return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1), 18 | (ctr_vector * nbr_vector).sum(dim=-1)) 19 | 20 | 21 | def side_to_directed_lineseg( 22 | query_point: torch.Tensor, 23 | start_point: torch.Tensor, 24 | end_point: torch.Tensor) -> str: 25 | cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) - 26 | (end_point[1] - start_point[1]) * (query_point[0] - start_point[0])) 27 | if cond > 0: 28 | return 'LEFT' 29 | elif cond < 0: 30 | return 'RIGHT' 31 | else: 32 | return 'CENTER' 33 | 34 | 35 | def wrap_angle( 36 | angle: torch.Tensor, 37 | min_val: float = -math.pi, 38 | max_val: float = math.pi) -> torch.Tensor: 39 | return min_val + (angle + max_val) % (max_val - min_val) 40 | -------------------------------------------------------------------------------- /smart/utils/graph.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import torch 5 | from torch_geometric.utils import coalesce 6 | from torch_geometric.utils import degree 7 | 8 | 9 | def add_edges( 10 | from_edge_index: torch.Tensor, 11 | to_edge_index: torch.Tensor, 12 | from_edge_attr: Optional[torch.Tensor] = None, 13 | to_edge_attr: Optional[torch.Tensor] = None, 14 | replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 15 | from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype) 16 | mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) & 17 | (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0))) 18 | if replace: 19 | to_mask = mask.any(dim=1) 20 | if from_edge_attr is not None and to_edge_attr is not None: 21 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) 22 | to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0) 23 | to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1) 24 | else: 25 | from_mask = mask.any(dim=0) 26 | if from_edge_attr is not None and to_edge_attr is not None: 27 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) 28 | to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0) 29 | to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1) 30 | return to_edge_index, to_edge_attr 31 | 32 | 33 | def merge_edges( 34 | edge_indices: List[torch.Tensor], 35 | edge_attrs: Optional[List[torch.Tensor]] = None, 36 | reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 37 | edge_index = torch.cat(edge_indices, dim=1) 38 | if edge_attrs is not None: 39 | edge_attr = torch.cat(edge_attrs, dim=0) 40 | else: 41 | edge_attr = None 42 | return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce) 43 | 44 | 45 | def complete_graph( 46 | num_nodes: Union[int, Tuple[int, int]], 47 | ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, 48 | loop: bool = False, 49 | device: Optional[Union[torch.device, str]] = None) -> torch.Tensor: 50 | if ptr is None: 51 | if isinstance(num_nodes, int): 52 | num_src, num_dst = num_nodes, num_nodes 53 | else: 54 | num_src, num_dst = num_nodes 55 | 
edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), 56 | torch.arange(num_dst, dtype=torch.long, device=device)).t() 57 | else: 58 | if isinstance(ptr, torch.Tensor): 59 | ptr_src, ptr_dst = ptr, ptr 60 | num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1] 61 | else: 62 | ptr_src, ptr_dst = ptr 63 | num_src_batch = ptr_src[1:] - ptr_src[:-1] 64 | num_dst_batch = ptr_dst[1:] - ptr_dst[:-1] 65 | edge_index = torch.cat( 66 | [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), 67 | torch.arange(num_dst, dtype=torch.long, device=device)) + p 68 | for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))], 69 | dim=0) 70 | edge_index = edge_index.t() 71 | if isinstance(num_nodes, int) and not loop: 72 | edge_index = edge_index[:, edge_index[0] != edge_index[1]] 73 | return edge_index.contiguous() 74 | 75 | 76 | def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor: 77 | index = adj.nonzero(as_tuple=True) 78 | if len(index) == 3: 79 | batch_src = index[0] * adj.size(1) 80 | batch_dst = index[0] * adj.size(2) 81 | index = (batch_src + index[1], batch_dst + index[2]) 82 | return torch.stack(index, dim=0) 83 | 84 | 85 | def unbatch( 86 | src: torch.Tensor, 87 | batch: torch.Tensor, 88 | dim: int = 0) -> List[torch.Tensor]: 89 | sizes = degree(batch, dtype=torch.long).tolist() 90 | return src.split(sizes, dim) 91 | -------------------------------------------------------------------------------- /smart/utils/list.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, List, Optional 3 | 4 | 5 | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]: 6 | try: 7 | return ls.index(elem) 8 | except ValueError: 9 | return None 10 | -------------------------------------------------------------------------------- /smart/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import os 4 | 5 | 6 | class Logging: 7 | 8 | def make_log_dir(self, dirname='logs'): 9 | now_dir = os.path.dirname(__file__) 10 | path = os.path.join(now_dir, dirname) 11 | path = os.path.normpath(path) 12 | if not os.path.exists(path): 13 | os.mkdir(path) 14 | return path 15 | 16 | def get_log_filename(self): 17 | filename = "{}.log".format(time.strftime("%Y-%m-%d",time.localtime())) 18 | filename = os.path.join(self.make_log_dir(), filename) 19 | filename = os.path.normpath(filename) 20 | return filename 21 | 22 | def log(self, level='DEBUG', name="simagent"): 23 | logger = logging.getLogger(name) 24 | level = getattr(logging, level) 25 | logger.setLevel(level) 26 | if not logger.handlers: 27 | sh = logging.StreamHandler() 28 | fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") 29 | fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") 30 | sh.setFormatter(fmt=fmt) 31 | fh.setFormatter(fmt=fmt) 32 | logger.addHandler(sh) 33 | logger.addHandler(fh) 34 | return logger 35 | 36 | def add_log(self, logger, level='DEBUG'): 37 | level = getattr(logging, level) 38 | logger.setLevel(level) 39 | if not logger.handlers: 40 | sh = logging.StreamHandler() 41 | fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") 42 | fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") 43 | sh.setFormatter(fmt=fmt) 44 | fh.setFormatter(fmt=fmt) 
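# as in log(), handlers are attached only when the logger has none yet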
45 | logger.addHandler(sh)
46 | logger.addHandler(fh)
47 | return logger
48 | 
49 | 
50 | if __name__ == '__main__':
51 | logger = Logging().log(level='INFO')
52 | logger.debug("debug message") # emit sample records through the logger
53 | logger.info("info message")
54 | logger.error("error message")
55 | logger.warning("warning message")
56 | logger.critical("critical message")
57 | 
--------------------------------------------------------------------------------
/smart/utils/nan_checker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def check_nan_inf(t, s):
4 | assert not torch.isinf(t).any(), f"{s} is inf, {t}"
5 | assert not torch.isnan(t).any(), f"{s} is nan, {t}"
--------------------------------------------------------------------------------
/smart/utils/weight_init.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch.nn as nn
3 | 
4 | 
5 | def weight_init(m: nn.Module) -> None:
6 | if isinstance(m, nn.Linear):
7 | nn.init.xavier_uniform_(m.weight)
8 | if m.bias is not None:
9 | nn.init.zeros_(m.bias)
10 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
11 | fan_in = m.in_channels / m.groups
12 | fan_out = m.out_channels / m.groups
13 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
14 | nn.init.uniform_(m.weight, -bound, bound)
15 | if m.bias is not None:
16 | nn.init.zeros_(m.bias)
17 | elif isinstance(m, nn.Embedding):
18 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
19 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
20 | nn.init.ones_(m.weight)
21 | nn.init.zeros_(m.bias)
22 | elif isinstance(m, nn.LayerNorm):
23 | nn.init.ones_(m.weight)
24 | nn.init.zeros_(m.bias)
25 | elif isinstance(m, nn.MultiheadAttention):
26 | if m.in_proj_weight is not None:
27 | fan_in = m.embed_dim
28 | fan_out = m.embed_dim
29 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
30 | nn.init.uniform_(m.in_proj_weight, -bound, bound)
31 | else:
32 | nn.init.xavier_uniform_(m.q_proj_weight)
33 | nn.init.xavier_uniform_(m.k_proj_weight)
34 | nn.init.xavier_uniform_(m.v_proj_weight)
35 | if m.in_proj_bias is not None:
36 | nn.init.zeros_(m.in_proj_bias)
37 | nn.init.xavier_uniform_(m.out_proj.weight)
38 | if m.out_proj.bias is not None:
39 | nn.init.zeros_(m.out_proj.bias)
40 | if m.bias_k is not None:
41 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
42 | if m.bias_v is not None:
43 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
44 | elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
45 | for name, param in m.named_parameters():
46 | if 'weight_ih' in name:
47 | for ih in param.chunk(4, 0):
48 | nn.init.xavier_uniform_(ih)
49 | elif 'weight_hh' in name:
50 | for hh in param.chunk(4, 0):
51 | nn.init.orthogonal_(hh)
52 | elif 'weight_hr' in name:
53 | nn.init.xavier_uniform_(param)
54 | elif 'bias_ih' in name:
55 | nn.init.zeros_(param)
56 | elif 'bias_hh' in name:
57 | nn.init.zeros_(param)
58 | nn.init.ones_(param.chunk(4, 0)[1]) # forget-gate bias is initialized to 1
59 | elif isinstance(m, (nn.GRU, nn.GRUCell)):
60 | for name, param in m.named_parameters():
61 | if 'weight_ih' in name:
62 | for ih in param.chunk(3, 0):
63 | nn.init.xavier_uniform_(ih)
64 | elif 'weight_hh' in name:
65 | for hh in param.chunk(3, 0):
66 | nn.init.orthogonal_(hh)
67 | elif 'bias_ih' in name:
68 | nn.init.zeros_(param)
69 | elif 'bias_hh' in name:
70 | nn.init.zeros_(param)
71 | 
-------------------------------------------------------------------------------- /train.py: 
/train.py:
--------------------------------------------------------------------------------
1 | 
2 | from argparse import ArgumentParser
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import LearningRateMonitor
5 | from pytorch_lightning.callbacks import ModelCheckpoint
6 | from pytorch_lightning.strategies import DDPStrategy
7 | from smart.utils.config import load_config_act
8 | from smart.datamodules import MultiDataModule
9 | from smart.model import SMART
10 | from smart.utils.log import Logging
11 | 
12 | 
13 | if __name__ == '__main__':
14 |     parser = ArgumentParser()
15 |     Predictor_hash = {"smart": SMART}
16 |     parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml')
17 |     parser.add_argument('--pretrain_ckpt', type=str, default="")
18 |     parser.add_argument('--ckpt_path', type=str, default="")
19 |     parser.add_argument('--save_ckpt_path', type=str, default="")
20 |     args = parser.parse_args()
21 |     config = load_config_act(args.config)
22 |     Predictor = Predictor_hash[config.Model.predictor]
23 |     strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
24 |     Data_config = config.Dataset
25 |     datamodule = MultiDataModule(**vars(Data_config))
26 | 
27 |     if args.pretrain_ckpt == "":
28 |         model = Predictor(config.Model)
29 |     else:
30 |         logger = Logging().log(level='DEBUG')
31 |         model = Predictor(config.Model)
32 |         model.load_params_from_file(filename=args.pretrain_ckpt,
33 |                                     logger=logger)
34 |     trainer_config = config.Trainer
35 |     model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,
36 |                                        filename="{epoch:02d}",
37 |                                        monitor='val_cls_acc',
38 |                                        every_n_epochs=1,
39 |                                        save_top_k=5,
40 |                                        mode='max')
41 |     lr_monitor = LearningRateMonitor(logging_interval='epoch')
42 |     trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices,
43 |                          strategy=strategy,
44 |                          accumulate_grad_batches=trainer_config.accumulate_grad_batches,
45 |                          num_nodes=trainer_config.num_nodes,
46 |                          callbacks=[model_checkpoint, lr_monitor],
47 |                          max_epochs=trainer_config.max_epochs,
48 |                          num_sanity_val_steps=0,
49 |                          gradient_clip_val=0.5)
50 |     if args.ckpt_path == "":
51 |         trainer.fit(model,
52 |                     datamodule)
53 |     else:
54 |         trainer.fit(model,
55 |                     datamodule,
56 |                     ckpt_path=args.ckpt_path)
57 | 
--------------------------------------------------------------------------------
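train.py never spells out the config schema; everything comes from load_config_act(args.config). From the attribute accesses above, the YAML must expose at least Model, Dataset, and Trainer sections. A rough Python mock of that shape (field names inferred from the usages above; values are placeholders, not the repo's real defaults):

```python
from types import SimpleNamespace

# Hypothetical stand-in for the object load_config_act returns; only the
# fields train.py actually touches are mocked here.
config = SimpleNamespace(
    Model=SimpleNamespace(predictor="smart"),  # key into Predictor_hash
    Dataset=SimpleNamespace(root="data/scalable", batch_size=4,
                            num_workers=4, pin_memory=True),  # unpacked via MultiDataModule(**vars(...))
    Trainer=SimpleNamespace(accelerator="gpu", devices=8, num_nodes=1,
                            accumulate_grad_batches=1, max_epochs=32),
)
assert config.Model.predictor == "smart"
```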
/val.py:
--------------------------------------------------------------------------------
1 | 
2 | from argparse import ArgumentParser
3 | import pytorch_lightning as pl
4 | from torch_geometric.loader import DataLoader
5 | from smart.datasets.scalable_dataset import MultiDataset
6 | from smart.model import SMART
7 | from smart.transforms import WaymoTargetBuilder
8 | from smart.utils.config import load_config_act
9 | from smart.utils.log import Logging
10 | 
11 | if __name__ == '__main__':
12 |     pl.seed_everything(2, workers=True)
13 |     parser = ArgumentParser()
14 |     parser.add_argument('--config', type=str, default="configs/validation/validation_scalable.yaml")
15 |     parser.add_argument('--pretrain_ckpt', type=str, default="")
16 |     parser.add_argument('--ckpt_path', type=str, default="")
17 |     parser.add_argument('--save_ckpt_path', type=str, default="")
18 |     args = parser.parse_args()
19 |     config = load_config_act(args.config)
20 | 
21 |     data_config = config.Dataset
22 |     val_dataset = {
23 |         "scalable": MultiDataset,
24 |     }[data_config.dataset](root=data_config.root, split='val',
25 |                            raw_dir=data_config.val_raw_dir,
26 |                            processed_dir=data_config.val_processed_dir,
27 |                            transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))
28 |     dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers,
29 |                             pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False)
30 |     Predictor = SMART
31 |     if args.pretrain_ckpt == "":
32 |         model = Predictor(config.Model)
33 |     else:
34 |         logger = Logging().log(level='DEBUG')
35 |         model = Predictor(config.Model)
36 |         model.load_params_from_file(filename=args.pretrain_ckpt,
37 |                                     logger=logger)
38 | 
39 |     trainer_config = config.Trainer
40 |     trainer = pl.Trainer(accelerator=trainer_config.accelerator,
41 |                          devices=trainer_config.devices,
42 |                          strategy='ddp', num_sanity_val_steps=0)
43 |     trainer.validate(model, dataloader)
44 | 
--------------------------------------------------------------------------------
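With a pretrained checkpoint in hand, validation reduces to a single invocation such as `python val.py --config configs/validation/validation_scalable.yaml --pretrain_ckpt <path-to-ckpt>` (checkpoint path left as a placeholder). Note that val.py parses --ckpt_path and --save_ckpt_path for symmetry with train.py but never reads them; only --config and --pretrain_ckpt affect the run.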