├── .gitignore ├── Evaluation.py ├── LICENSE ├── README.md ├── config ├── Config.py ├── __init__.py ├── config_utils.py └── demo.py ├── docs ├── assets │ ├── Figure1.png │ ├── Figure4.png │ ├── Figure5.png │ ├── TP.png │ ├── bootstrap.min.css │ ├── deciforce.png │ ├── font.css │ ├── region-based-training.png │ ├── style.css │ ├── teasar.png │ ├── tnt.png │ ├── vec.png │ └── vec2.png └── index.html ├── figs ├── format1.png ├── format2.png └── model.png ├── lib ├── __init__.py ├── dataset │ ├── __init__.py │ ├── argoverse_convertor.py │ ├── collate.py │ ├── dataset_for_argoverse.py │ ├── preprocess_utils │ │ ├── __init__.py │ │ ├── agent_utils.py │ │ ├── feature_utils.py │ │ ├── lane_utils.py │ │ ├── map_utils_vec.py │ │ └── object_utils.py │ ├── utils.py │ └── vectorization.py ├── models │ ├── TF_utils.py │ ├── TF_version │ │ ├── __init__.py │ │ └── stacked_transformer.py │ ├── __init__.py │ └── mmTransformer.py └── utils │ ├── __init__.py │ ├── evaluation_utils.py │ ├── parallel │ ├── __init__.py │ ├── data_container.py │ ├── dataparallel.py │ ├── scatter.py │ └── scatter_utils.py │ └── utilities.py └── requirement.txt /.gitignore: -------------------------------------------------------------------------------- 1 | /vis 2 | /.DS_Store 3 | /interm_data 4 | /sample_data 5 | /.vscode 6 | /__pycache__ 7 | /data 8 | /gpu_mem.log 9 | /logs 10 | /submission.csv 11 | /vis_res 12 | /result_h5 13 | *.pkl 14 | *.zarr 15 | *.pyc 16 | *.pt 17 | *.h5 18 | *.model 19 | -------------------------------------------------------------------------------- /Evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | from functools import partial 11 | 12 | from config.Config import Config 13 | # ============= Dataset ===================== 14 | from lib.dataset.collate import collate_single_cpu 15 | from lib.dataset.dataset_for_argoverse import STFDataset as ArgoverseDataset 16 | # ============= Models ====================== 17 | from lib.models.mmTransformer import mmTrans 18 | from lib.utils.evaluation_utils import compute_forecasting_metrics, FormatData 19 | 20 | # from lib.utils.traj_nms import traj_nms 21 | from lib.utils.utilities import load_checkpoint, load_model_class 22 | 23 | 24 | def parse_args(): 25 | 26 | parser = argparse.ArgumentParser(description='Evaluate the mmTransformer') 27 | parser.add_argument('config', help='config file path') 28 | parser.add_argument('--model-name', type=str, default='demo') 29 | parser.add_argument('--model-save-path', type=str, default='./models/') 30 | 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | start_time = time.time() 38 | gpu_num = torch.cuda.device_count() 39 | print("gpu number:{}".format(gpu_num)) 40 | 41 | args = parse_args() 42 | cfg = Config.fromfile(args.config) 43 | 44 | # ================================== INIT DATASET ========================================================== 45 | validation_cfg = cfg.get('val_dataset') 46 | val_dataset = ArgoverseDataset(validation_cfg) 47 | val_dataloader = DataLoader(val_dataset, 48 | shuffle=validation_cfg["shuffle"], 49 | batch_size=validation_cfg["batch_size"], 50 | num_workers=validation_cfg["workers_per_gpu"], 51 | collate_fn=collate_single_cpu) 52 | # =================================== Metric Initial 
======================================================= 53 | format_results = FormatData() 54 | evaluate = partial(compute_forecasting_metrics, 55 | max_n_guesses=6, 56 | horizon=30, 57 | miss_threshold=2.0) 58 | # =================================== INIT MODEL =========================================================== 59 | model_cfg = cfg.get('model') 60 | stacked_transformer = load_model_class(model_cfg['type']) 61 | model = mmTrans(stacked_transformer, model_cfg).cuda() 62 | model_name = os.path.join(args.model_save_path, 63 | '{}.pt'.format(args.model_name)) 64 | model = load_checkpoint(model_name, model) 65 | print('Successfully loaded model: {}'.format(model_name)) 66 | print('Finished initialization in {:.3f}s'.format( 67 | time.time()-start_time)) 68 | # ==================================== EVALUATION LOOP ===================================================== 69 | model.eval() 70 | progress_bar = tqdm(val_dataloader) 71 | with torch.no_grad(): 72 | for j, data in enumerate(progress_bar): 73 | for key in data.keys(): 74 | if isinstance(data[key], torch.Tensor): 75 | data[key] = data[key].cuda() 76 | 77 | out = model(data) 78 | format_results(data, out) 79 | 80 | print(evaluate(**format_results.results)) 81 | print('Validation process finished.') 82 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mmTransformer 2 | 3 | ## Introduction 4 | 5 | - This repo is the official PyTorch implementation of [mmTransformer](https://github.com/decisionforce/mmTransformer). Since the core code of mmTransformer is currently implemented in a commercial project, we provide the **inference code** of the model with six trajectory proposals for your reference.
6 | 7 | - For other information, please refer to our paper **Multimodal Motion Prediction with Stacked Transformers**. (CVPR 2021) [[Paper](https://arxiv.org/pdf/2103.11624.pdf)] [[Webpage](https://decisionforce.github.io/mmTransformer/)] 8 | 9 | ![img](./figs/model.png) 10 | 11 | ## Set up your virtual environment 12 | 13 | - Initialize the virtual environment: 14 | 15 | conda create -n mmTrans python=3.7 16 | 17 | - Install the Argoverse API. Please refer to [this page](https://github.com/argoai/argoverse-api). 18 | 19 | - Install [PyTorch](https://pytorch.org/). The latest code is tested on Ubuntu 16.04, CUDA 11.1, PyTorch 1.8, and Python 3.7: 20 | (Note that torch >= 1.5.0 is required for testing with the pretrained model.) 21 | 22 | pip install torch==1.8.0+cu111\ 23 | torchvision==0.9.0+cu111\ 24 | torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 25 | 26 | - For the other requirements, please install them with the following command: 27 | 28 | pip install -r requirement.txt 29 | 30 | 31 | ## Preparation 32 | 33 | ### Download the code, model and data 34 | 35 | 1. Clone this repo from GitHub. 36 | 37 | git clone https://github.com/decisionforce/mmTransformer.git 38 | 39 | 2. Download the pretrained model and data [[here](https://drive.google.com/file/d/10koDID95zoOnU3pb6AkHAqJInupMScJd/view?usp=sharing)] (map.pkl for Python 3.7 is available [[here](https://drive.google.com/file/d/1HbsgutM1PKjPj-3IIA5kG3mJEMhHGhS0/view?usp=sharing)]) and save them to `./models` and `./interm_data`. 40 | 41 | cd mmTransformer 42 | mkdir models 43 | mkdir interm_data 44 | 45 | 3. Finally, your directory structure should look something like this: 46 | 47 | mmTransformer 48 | └── models 49 | └── demo.pt 50 | └── interm_data 51 | └── argoverse_info_val.pkl 52 | └── map.pkl 53 | 54 | ### Preprocess the dataset 55 | 56 | Alternatively, you can process the data from scratch using the following commands. 57 | 58 | 1. Download the Argoverse dataset and create a symbolic link to the `./data` folder, or use the following commands. 59 | 60 | cd path/to/mmtransformer/root 61 | mkdir data 62 | cd data 63 | wget https://s3.amazonaws.com/argoai-argoverse/forecasting_val_v1.1.tar.gz 64 | tar -zxvf forecasting_val_v1.1.tar.gz 65 | 66 | 2. Then extract the agent and map information from the raw data via the Argoverse API: 67 | 68 | python -m lib.dataset.argoverse_convertor ./config/demo.py 69 | 70 | 3. Finally, your directory structure should look like the one illustrated above. 71 | 72 | 73 | Format of processed data in ‘argoverse_info_val.pkl’: 74 | 75 | ![img](./figs/format1.png) 76 | 77 | Format of map information in ‘map.pkl’: 78 | 79 | ![img](./figs/format2.png) 80 | 81 | 82 | ## Run the mmTransformer 83 | 84 | For testing: 85 | 86 | python Evaluation.py ./config/demo.py --model-name demo 87 |
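To embed the evaluation in your own script, the model can be constructed and restored with the same calls `Evaluation.py` makes. The sketch below only repeats what `Evaluation.py` already does; the paths assume you downloaded `demo.pt` as described in the Preparation section:

```python
from config.Config import Config
from lib.models.mmTransformer import mmTrans
from lib.utils.utilities import load_checkpoint, load_model_class

cfg = Config.fromfile('./config/demo.py')
model_cfg = cfg.get('model')

# build the backbone class named in the config and wrap it in mmTrans
model = mmTrans(load_model_class(model_cfg['type']), model_cfg).cuda()
model = load_checkpoint('./models/demo.pt', model)
model.eval()
```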
88 | ## Results 89 | 90 | Here we showcase the expected results on the validation set: 91 | 92 | | Metric | Expected results | Results in paper | 93 | |--|--|--| 94 | | minADE | 0.709 | 0.713 | 95 | | minFDE | 1.081 | 1.153 | 96 | | MR (K=6) | 10.2 | 10.6 | 97 | 98 | ## TODO 99 | 100 | - We are going to open source our visualization tools and a demo result. (TBD) 101 | 102 | ## Contact us 103 | If you have any issues with the code, please contact us via this email: 104 | 105 | ## Citation 106 | If you find our work useful for your research, please consider citing the paper: 107 | ``` 108 | @article{liu2021multimodal, 109 | title={Multimodal Motion Prediction with Stacked Transformers}, 110 | author={Liu, Yicheng and Zhang, Jinghuai and Fang, Liangji and Jiang, Qinhong and Zhou, Bolei}, 111 | journal={Computer Vision and Pattern Recognition}, 112 | year={2021} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /config/Config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | import ast 3 | import copy 4 | import os 5 | import os.path as osp 6 | import platform 7 | import shutil 8 | import sys 9 | import tempfile 10 | import uuid 11 | import warnings 12 | from argparse import Action, ArgumentParser 13 | from collections import abc 14 | from importlib import import_module 15 | 16 | from addict import Dict 17 | from yapf.yapflib.yapf_api import FormatCode 18 | 19 | 20 | from .config_utils import check_file_exist, import_modules_from_strings 21 | 22 | if platform.system() == 'Windows': 23 | import regex as re 24 | else: 25 | import re 26 | 27 | BASE_KEY = '_base_' 28 | DELETE_KEY = '_delete_' 29 | RESERVED_KEYS = ['filename', 'text', 'pretty_text'] 30 | 31 | 32 | # borrowed from mmcv 33 | class ConfigDict(Dict): 34 | 35 | def __missing__(self, name): 36 | raise KeyError(name) 37 | 38 | def __getattr__(self, name): 39 | try: 40 | value = super(ConfigDict, self).__getattr__(name) 41 | except KeyError: 42 | ex = AttributeError(f"'{self.__class__.__name__}' object has no " 43 | f"attribute '{name}'") 44 | except Exception as e: 45 | ex = e 46 | else: 47 | return value 48 | raise ex 49 | 50 | 51 | def add_args(parser, cfg, prefix=''): 52 | for k, v in cfg.items(): 53 | if isinstance(v, str): 54 | parser.add_argument('--' + prefix + k) 55 | elif isinstance(v, int): 56 | parser.add_argument('--' + prefix + k, type=int) 57 | elif isinstance(v, float): 58 | parser.add_argument('--' + prefix + k, type=float) 59 | elif isinstance(v, bool): 60 | parser.add_argument('--' + prefix + k, action='store_true') 61 | elif isinstance(v, dict): 62 | add_args(parser, v, prefix + k + '.') 63 | elif isinstance(v, abc.Iterable): 64 | parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') 65 | else: 66 | print(f'cannot parse key {prefix + k} of type {type(v)}') 67 | return parser 68 | 69 | 70 | class Config: 71 | """A facility for config and config files. 72 | 73 | It supports common file formats as configs: python/json/yaml. The interface 74 | is the same as a dict object and also allows accessing config values as 75 | attributes. 
76 | 77 | Example: 78 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) 79 | >>> cfg.a 80 | 1 81 | >>> cfg.b 82 | {'b1': [0, 1]} 83 | >>> cfg.b.b1 84 | [0, 1] 85 | >>> cfg = Config.fromfile('tests/data/config/a.py') 86 | >>> cfg.filename 87 | "/home/kchen/projects/mmcv/tests/data/config/a.py" 88 | >>> cfg.item4 89 | 'test' 90 | >>> cfg 91 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " 92 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" 93 | """ 94 | 95 | @staticmethod 96 | def _validate_py_syntax(filename): 97 | with open(filename, 'r', encoding='utf-8') as f: 98 | # Setting encoding explicitly to resolve coding issue on windows 99 | content = f.read() 100 | try: 101 | ast.parse(content) 102 | except SyntaxError as e: 103 | raise SyntaxError('There are syntax errors in config ' 104 | f'file {filename}: {e}') 105 | 106 | @staticmethod 107 | def _substitute_predefined_vars(filename, temp_config_name): 108 | file_dirname = osp.dirname(filename) 109 | file_basename = osp.basename(filename) 110 | file_basename_no_extension = osp.splitext(file_basename)[0] 111 | file_extname = osp.splitext(filename)[1] 112 | support_templates = dict( 113 | fileDirname=file_dirname, 114 | fileBasename=file_basename, 115 | fileBasenameNoExtension=file_basename_no_extension, 116 | fileExtname=file_extname) 117 | with open(filename, 'r', encoding='utf-8') as f: 118 | # Setting encoding explicitly to resolve coding issue on windows 119 | config_file = f.read() 120 | for key, value in support_templates.items(): 121 | regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' 122 | value = value.replace('\\', '/') 123 | config_file = re.sub(regexp, value, config_file) 124 | with open(temp_config_name, 'w') as tmp_config_file: 125 | tmp_config_file.write(config_file) 126 | 127 | @staticmethod 128 | def _pre_substitute_base_vars(filename, temp_config_name): 129 | """Substitute base variable placeholders with strings, so that parsing 130 | would work.""" 131 | with open(filename, 'r', encoding='utf-8') as f: 132 | # Setting encoding explicitly to resolve coding issue on windows 133 | config_file = f.read() 134 | base_var_dict = {} 135 | regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' 136 | base_vars = set(re.findall(regexp, config_file)) 137 | for base_var in base_vars: 138 | randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' 139 | base_var_dict[randstr] = base_var 140 | regexp = r'\{\{\s*' + BASE_KEY + r'\.' 
+ base_var + r'\s*\}\}' 141 | config_file = re.sub(regexp, f'"{randstr}"', config_file) 142 | with open(temp_config_name, 'w') as tmp_config_file: 143 | tmp_config_file.write(config_file) 144 | return base_var_dict 145 | 146 | @staticmethod 147 | def _substitute_base_vars(cfg, base_var_dict, base_cfg): 148 | """Substitute variable strings to their actual values.""" 149 | cfg = copy.deepcopy(cfg) 150 | 151 | if isinstance(cfg, dict): 152 | for k, v in cfg.items(): 153 | if isinstance(v, str) and v in base_var_dict: 154 | new_v = base_cfg 155 | for new_k in base_var_dict[v].split('.'): 156 | new_v = new_v[new_k] 157 | cfg[k] = new_v 158 | elif isinstance(v, (list, tuple, dict)): 159 | cfg[k] = Config._substitute_base_vars( 160 | v, base_var_dict, base_cfg) 161 | elif isinstance(cfg, tuple): 162 | cfg = tuple( 163 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 164 | for c in cfg) 165 | elif isinstance(cfg, list): 166 | cfg = [ 167 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 168 | for c in cfg 169 | ] 170 | elif isinstance(cfg, str) and cfg in base_var_dict: 171 | new_v = base_cfg 172 | for new_k in base_var_dict[cfg].split('.'): 173 | new_v = new_v[new_k] 174 | cfg = new_v 175 | 176 | return cfg 177 | 178 | @staticmethod 179 | def _file2dict(filename, use_predefined_variables=True): 180 | filename = osp.abspath(osp.expanduser(filename)) 181 | check_file_exist(filename) 182 | fileExtname = osp.splitext(filename)[1] 183 | if fileExtname not in ['.py', '.json', '.yaml', '.yml']: 184 | raise IOError('Only py/yml/yaml/json type are supported now!') 185 | 186 | with tempfile.TemporaryDirectory() as temp_config_dir: 187 | temp_config_file = tempfile.NamedTemporaryFile( 188 | dir=temp_config_dir, suffix=fileExtname) 189 | if platform.system() == 'Windows': 190 | temp_config_file.close() 191 | temp_config_name = osp.basename(temp_config_file.name) 192 | # Substitute predefined variables 193 | if use_predefined_variables: 194 | Config._substitute_predefined_vars(filename, 195 | temp_config_file.name) 196 | else: 197 | shutil.copyfile(filename, temp_config_file.name) 198 | # Substitute base variables from placeholders to strings 199 | base_var_dict = Config._pre_substitute_base_vars( 200 | temp_config_file.name, temp_config_file.name) 201 | 202 | if filename.endswith('.py'): 203 | temp_module_name = osp.splitext(temp_config_name)[0] 204 | sys.path.insert(0, temp_config_dir) 205 | Config._validate_py_syntax(filename) 206 | mod = import_module(temp_module_name) 207 | sys.path.pop(0) 208 | cfg_dict = { 209 | name: value 210 | for name, value in mod.__dict__.items() 211 | if not name.startswith('__') 212 | } 213 | # delete imported module 214 | del sys.modules[temp_module_name] 215 | elif filename.endswith(('.yml', '.yaml', '.json')): 216 | import mmcv 217 | cfg_dict = mmcv.load(temp_config_file.name) 218 | # close temp file 219 | temp_config_file.close() 220 | 221 | cfg_text = filename + '\n' 222 | with open(filename, 'r', encoding='utf-8') as f: 223 | # Setting encoding explicitly to resolve coding issue on windows 224 | cfg_text += f.read() 225 | 226 | if BASE_KEY in cfg_dict: 227 | cfg_dir = osp.dirname(filename) 228 | base_filename = cfg_dict.pop(BASE_KEY) 229 | base_filename = base_filename if isinstance( 230 | base_filename, list) else [base_filename] 231 | 232 | cfg_dict_list = list() 233 | cfg_text_list = list() 234 | for f in base_filename: 235 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) 236 | cfg_dict_list.append(_cfg_dict) 237 | 
cfg_text_list.append(_cfg_text) 238 | 239 | base_cfg_dict = dict() 240 | for c in cfg_dict_list: 241 | if len(base_cfg_dict.keys() & c.keys()) > 0: 242 | raise KeyError('Duplicate key is not allowed among bases') 243 | base_cfg_dict.update(c) 244 | 245 | # Substitute base variables from strings to their actual values 246 | cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, 247 | base_cfg_dict) 248 | 249 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) 250 | cfg_dict = base_cfg_dict 251 | 252 | # merge cfg_text 253 | cfg_text_list.append(cfg_text) 254 | cfg_text = '\n'.join(cfg_text_list) 255 | 256 | return cfg_dict, cfg_text 257 | 258 | @staticmethod 259 | def _merge_a_into_b(a, b, allow_list_keys=False): 260 | """merge dict ``a`` into dict ``b`` (non-inplace). 261 | 262 | Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid 263 | in-place modifications. 264 | 265 | Args: 266 | a (dict): The source dict to be merged into ``b``. 267 | b (dict): The origin dict to be merged into. 268 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 269 | are allowed in source ``a`` and will replace the element of the 270 | corresponding index in b if b is a list. Default: False. 271 | 272 | Returns: 273 | dict: The modified dict of ``b`` using ``a``. 274 | 275 | Examples: 276 | # Normally merge a into b. 277 | >>> Config._merge_a_into_b( 278 | ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) 279 | {'obj': {'a': 2}} 280 | 281 | # Delete b first and merge a into b. 282 | >>> Config._merge_a_into_b( 283 | ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) 284 | {'obj': {'a': 2}} 285 | 286 | # b is a list 287 | >>> Config._merge_a_into_b( 288 | ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) 289 | [{'a': 2}, {'b': 2}] 290 | """ 291 | b = b.copy() 292 | for k, v in a.items(): 293 | if allow_list_keys and k.isdigit() and isinstance(b, list): 294 | k = int(k) 295 | if len(b) <= k: 296 | raise KeyError(f'Index {k} exceeds the length of list {b}') 297 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 298 | elif isinstance(v, 299 | dict) and k in b and not v.pop(DELETE_KEY, False): 300 | allowed_types = (dict, list) if allow_list_keys else dict 301 | if not isinstance(b[k], allowed_types): 302 | raise TypeError( 303 | f'{k}={v} in child config cannot inherit from base ' 304 | f'because {k} is a dict in the child config but is of ' 305 | f'type {type(b[k])} in base config. You may set ' 306 | f'`{DELETE_KEY}=True` to ignore the base config') 307 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 308 | else: 309 | b[k] = v 310 | return b 311 | 312 | @staticmethod 313 | def fromfile(filename, 314 | use_predefined_variables=True, 315 | import_custom_modules=True): 316 | cfg_dict, cfg_text = Config._file2dict(filename, 317 | use_predefined_variables) 318 | if import_custom_modules and cfg_dict.get('custom_imports', None): 319 | import_modules_from_strings(**cfg_dict['custom_imports']) 320 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename) 321 | 322 | @staticmethod 323 | def fromstring(cfg_str, file_format): 324 | """Generate config from config str. 325 | 326 | Args: 327 | cfg_str (str): Config str. 328 | file_format (str): Config file format corresponding to the 329 | config str. Only py/yml/yaml/json type are supported now! 330 | 331 | Returns: 332 | obj:`Config`: Config obj.
333 | """ 334 | if file_format not in ['.py', '.json', '.yaml', '.yml']: 335 | raise IOError('Only py/yml/yaml/json type are supported now!') 336 | if file_format != '.py' and 'dict(' in cfg_str: 337 | # check if users specify a wrong suffix for python 338 | warnings.warn( 339 | 'Please check "file_format", the file format may be .py') 340 | with tempfile.NamedTemporaryFile( 341 | 'w', suffix=file_format, delete=False) as temp_file: 342 | temp_file.write(cfg_str) 343 | # on windows, previous implementation cause error 344 | # see PR 1077 for details 345 | cfg = Config.fromfile(temp_file.name) 346 | os.remove(temp_file.name) 347 | return cfg 348 | 349 | @staticmethod 350 | def auto_argparser(description=None): 351 | """Generate argparser from config file automatically (experimental)""" 352 | partial_parser = ArgumentParser(description=description) 353 | partial_parser.add_argument('config', help='config file path') 354 | cfg_file = partial_parser.parse_known_args()[0].config 355 | cfg = Config.fromfile(cfg_file) 356 | parser = ArgumentParser(description=description) 357 | parser.add_argument('config', help='config file path') 358 | add_args(parser, cfg) 359 | return parser, cfg 360 | 361 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None): 362 | if cfg_dict is None: 363 | cfg_dict = dict() 364 | elif not isinstance(cfg_dict, dict): 365 | raise TypeError('cfg_dict must be a dict, but ' 366 | f'got {type(cfg_dict)}') 367 | for key in cfg_dict: 368 | if key in RESERVED_KEYS: 369 | raise KeyError(f'{key} is reserved for config file') 370 | 371 | super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) 372 | super(Config, self).__setattr__('_filename', filename) 373 | if cfg_text: 374 | text = cfg_text 375 | elif filename: 376 | with open(filename, 'r') as f: 377 | text = f.read() 378 | else: 379 | text = '' 380 | super(Config, self).__setattr__('_text', text) 381 | 382 | @property 383 | def filename(self): 384 | return self._filename 385 | 386 | @property 387 | def text(self): 388 | return self._text 389 | 390 | @property 391 | def pretty_text(self): 392 | 393 | indent = 4 394 | 395 | def _indent(s_, num_spaces): 396 | s = s_.split('\n') 397 | if len(s) == 1: 398 | return s_ 399 | first = s.pop(0) 400 | s = [(num_spaces * ' ') + line for line in s] 401 | s = '\n'.join(s) 402 | s = first + '\n' + s 403 | return s 404 | 405 | def _format_basic_types(k, v, use_mapping=False): 406 | if isinstance(v, str): 407 | v_str = f"'{v}'" 408 | else: 409 | v_str = str(v) 410 | 411 | if use_mapping: 412 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 413 | attr_str = f'{k_str}: {v_str}' 414 | else: 415 | attr_str = f'{str(k)}={v_str}' 416 | attr_str = _indent(attr_str, indent) 417 | 418 | return attr_str 419 | 420 | def _format_list(k, v, use_mapping=False): 421 | # check if all items in the list are dict 422 | if all(isinstance(_, dict) for _ in v): 423 | v_str = '[\n' 424 | v_str += '\n'.join( 425 | f'dict({_indent(_format_dict(v_), indent)}),' 426 | for v_ in v).rstrip(',') 427 | if use_mapping: 428 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 429 | attr_str = f'{k_str}: {v_str}' 430 | else: 431 | attr_str = f'{str(k)}={v_str}' 432 | attr_str = _indent(attr_str, indent) + ']' 433 | else: 434 | attr_str = _format_basic_types(k, v, use_mapping) 435 | return attr_str 436 | 437 | def _contain_invalid_identifier(dict_str): 438 | contain_invalid_identifier = False 439 | for key_name in dict_str: 440 | contain_invalid_identifier |= \ 441 | (not str(key_name).isidentifier()) 442 
| return contain_invalid_identifier 443 | 444 | def _format_dict(input_dict, outest_level=False): 445 | r = '' 446 | s = [] 447 | 448 | use_mapping = _contain_invalid_identifier(input_dict) 449 | if use_mapping: 450 | r += '{' 451 | for idx, (k, v) in enumerate(input_dict.items()): 452 | is_last = idx >= len(input_dict) - 1 453 | end = '' if outest_level or is_last else ',' 454 | if isinstance(v, dict): 455 | v_str = '\n' + _format_dict(v) 456 | if use_mapping: 457 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 458 | attr_str = f'{k_str}: dict({v_str}' 459 | else: 460 | attr_str = f'{str(k)}=dict({v_str}' 461 | attr_str = _indent(attr_str, indent) + ')' + end 462 | elif isinstance(v, list): 463 | attr_str = _format_list(k, v, use_mapping) + end 464 | else: 465 | attr_str = _format_basic_types(k, v, use_mapping) + end 466 | 467 | s.append(attr_str) 468 | r += '\n'.join(s) 469 | if use_mapping: 470 | r += '}' 471 | return r 472 | 473 | cfg_dict = self._cfg_dict.to_dict() 474 | text = _format_dict(cfg_dict, outest_level=True) 475 | # copied from setup.cfg 476 | yapf_style = dict( 477 | based_on_style='pep8', 478 | blank_line_before_nested_class_or_def=True, 479 | split_before_expression_after_opening_paren=True) 480 | text, _ = FormatCode(text, style_config=yapf_style, verify=True) 481 | 482 | return text 483 | 484 | def __repr__(self): 485 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' 486 | 487 | def __len__(self): 488 | return len(self._cfg_dict) 489 | 490 | def __getattr__(self, name): 491 | return getattr(self._cfg_dict, name) 492 | 493 | def __getitem__(self, name): 494 | return self._cfg_dict.__getitem__(name) 495 | 496 | def __setattr__(self, name, value): 497 | if isinstance(value, dict): 498 | value = ConfigDict(value) 499 | self._cfg_dict.__setattr__(name, value) 500 | 501 | def __setitem__(self, name, value): 502 | if isinstance(value, dict): 503 | value = ConfigDict(value) 504 | self._cfg_dict.__setitem__(name, value) 505 | 506 | def __iter__(self): 507 | return iter(self._cfg_dict) 508 | 509 | def __getstate__(self): 510 | return (self._cfg_dict, self._filename, self._text) 511 | 512 | def __setstate__(self, state): 513 | _cfg_dict, _filename, _text = state 514 | super(Config, self).__setattr__('_cfg_dict', _cfg_dict) 515 | super(Config, self).__setattr__('_filename', _filename) 516 | super(Config, self).__setattr__('_text', _text) 517 | 518 | def dump(self, file=None): 519 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() 520 | if self.filename.endswith('.py'): 521 | if file is None: 522 | return self.pretty_text 523 | else: 524 | with open(file, 'w') as f: 525 | f.write(self.pretty_text) 526 | else: 527 | import mmcv 528 | if file is None: 529 | file_format = self.filename.split('.')[-1] 530 | return mmcv.dump(cfg_dict, file_format=file_format) 531 | else: 532 | mmcv.dump(cfg_dict, file) 533 | 534 | def merge_from_dict(self, options, allow_list_keys=True): 535 | """Merge list into cfg_dict. 536 | 537 | Merge the dict parsed by MultipleKVAction into this cfg. 538 | 539 | Examples: 540 | >>> options = {'model.backbone.depth': 50, 541 | ... 'model.backbone.with_cp':True} 542 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) 543 | >>> cfg.merge_from_dict(options) 544 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 545 | >>> assert cfg_dict == dict( 546 | ... model=dict(backbone=dict(depth=50, with_cp=True))) 547 | 548 | # Merge list element 549 | >>> cfg = Config(dict(pipeline=[ 550 | ... 
dict(type='LoadImage'), dict(type='LoadAnnotations')])) 551 | >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) 552 | >>> cfg.merge_from_dict(options, allow_list_keys=True) 553 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 554 | >>> assert cfg_dict == dict(pipeline=[ 555 | ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) 556 | 557 | Args: 558 | options (dict): dict of configs to merge from. 559 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 560 | are allowed in ``options`` and will replace the element of the 561 | corresponding index in the config if the config is a list. 562 | Default: True. 563 | """ 564 | option_cfg_dict = {} 565 | for full_key, v in options.items(): 566 | d = option_cfg_dict 567 | key_list = full_key.split('.') 568 | for subkey in key_list[:-1]: 569 | d.setdefault(subkey, ConfigDict()) 570 | d = d[subkey] 571 | subkey = key_list[-1] 572 | d[subkey] = v 573 | 574 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 575 | super(Config, self).__setattr__( 576 | '_cfg_dict', 577 | Config._merge_a_into_b( 578 | option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) 579 | 580 | 581 | if __name__ == '__main__': 582 | 583 | Config.fromfile('./config/demo.py') -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/config/__init__.py
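The two `Config` entry points exercised in this repo are `fromfile` (used by `Evaluation.py` and the preprocessor) and `merge_from_dict` (documented above). A minimal sketch against the repo's own `config/demo.py`:

    from config.Config import Config

    cfg = Config.fromfile('./config/demo.py')
    print(cfg.model['type'])                            # -> 'stacked_transformer'
    cfg.merge_from_dict({'val_dataset.batch_size': 8})  # override a nested key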
-------------------------------------------------------------------------------- /config/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 2 | 3 | from importlib import import_module 4 | import os.path as osp 5 | import warnings  # needed by import_modules_from_strings below 6 | # From PyTorch internals 7 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 8 | if not osp.isfile(filename): 9 | raise FileNotFoundError(msg_tmpl.format(filename)) 10 | 11 | def import_modules_from_strings(imports, allow_failed_imports=False): 12 | """Import modules from the given list of strings. 13 | 14 | Args: 15 | imports (list | str | None): The given module names to be imported. 16 | allow_failed_imports (bool): If True, the failed imports will return 17 | None. Otherwise, an ImportError is raised. Default: False. 18 | 19 | Returns: 20 | list[module] | module | None: The imported modules. 21 | 22 | Examples: 23 | >>> osp, sys = import_modules_from_strings( 24 | ... ['os.path', 'sys']) 25 | >>> import os.path as osp_ 26 | >>> import sys as sys_ 27 | >>> assert osp == osp_ 28 | >>> assert sys == sys_ 29 | """ 30 | if not imports: 31 | return 32 | single_import = False 33 | if isinstance(imports, str): 34 | single_import = True 35 | imports = [imports] 36 | if not isinstance(imports, list): 37 | raise TypeError( 38 | f'custom_imports must be a list but got type {type(imports)}') 39 | imported = [] 40 | for imp in imports: 41 | if not isinstance(imp, str): 42 | raise TypeError( 43 | f'{imp} is of type {type(imp)} and cannot be imported.') 44 | try: 45 | imported_tmp = import_module(imp) 46 | except ImportError: 47 | if allow_failed_imports: 48 | warnings.warn(f'{imp} failed to import and is ignored.', 49 | UserWarning) 50 | imported_tmp = None 51 | else: 52 | raise  # re-raise the original ImportError 53 | imported.append(imported_tmp) 54 | if single_import: 55 | imported = imported[0] 56 | return imported
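`import_modules_from_strings` is what `Config.fromfile` calls when a config defines `custom_imports`. A small sketch of its two behaviors; the failing module name is made up for illustration:

    from config.config_utils import import_modules_from_strings

    # a single string yields a single module
    sys_mod = import_modules_from_strings('sys')

    # with allow_failed_imports=True, a failed import becomes None
    # (a UserWarning is emitted for the failure)
    mods = import_modules_from_strings(
        ['os.path', 'not_a_real_module'], allow_failed_imports=True)
    assert mods[1] is None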
-------------------------------------------------------------------------------- /config/demo.py: -------------------------------------------------------------------------------- 1 | # cfg for preprocess 2 | preprocess_dataset = dict( 3 | RAW_DATA_FORMAT={ 4 | "TIMESTAMP": 0, 5 | "TRACK_ID": 1, 6 | "OBJECT_TYPE": 2, 7 | "X": 3, 8 | "Y": 4, 9 | "CITY_NAME": 5, 10 | }, 11 | LANE_WIDTH={'MIA': 3.84, 'PIT': 3.97}, 12 | # to be considered as static 13 | VELOCITY_THRESHOLD=0.0, 14 | # number of timesteps the track should exist to be considered in social context 15 | EXIST_THRESHOLD=(5), 16 | # index of the sorted velocity to look at, to call it as stationary 17 | STATIONARY_THRESHOLD=(13), 18 | LANE_RADIUS=65, # nearby lanes 19 | OBJ_RADIUS=56, # nearby objects 20 | OBS_LEN=20, 21 | DATA_DIR='./data', 22 | INTERMEDIATE_DATA_DIR='./interm_data', 23 | info_prefix='argoverse_info_', 24 | VIS=False, 25 | # specify which folds in the data dir will be processed 26 | specific_data_fold_list = ['train','val','test','sample'], 27 | vectorization_cfg = dict( 28 | starighten = True, 29 | ) 30 | ) 31 | 32 | 33 | model = dict( 34 | type='stacked_transformer', 35 | history_num_frames= 20, 36 | future_num_frames= 30, 37 | # mode setting 38 | in_channels= 4, 39 | lane_channels= 7, 40 | out_channels= 60, # future_num_frames * 2; must be kept in sync with future_num_frames 41 | K= 6, 42 | increasetime= 3, 43 | queries= 6, 44 | num_guesses= 6, 45 | queries_dim= 64, 46 | enc_dim= 64, 47 | aux_task= False, 48 | 49 | 50 | #mmTrans main cfg 51 | subgraph_width = 32, 52 | num_subgraph_layres =2, 53 | lane_length = 10, 54 | ) 55 | 56 | 57 | dataset = dict( 58 | samples_per_gpu=1, 59 | workers_per_gpu=0, 60 | traj_processor_cfg=preprocess_dataset, 61 | ) 62 | 63 | from copy import deepcopy 64 | 65 | train_dataset= deepcopy(dataset) 66 | train_dataset.update(dict( 67 | type= "STFDataset", 68 | batch_size= 128, 69 | shuffle= True, 70 | num_workers= 4, 71 | Providing_GT= True, 72 | lane_length= 10, 73 | dataset_path= './data/train', 74 | processed_data_path= './interm_data/argoverse_info_train.pkl', 75 | processed_maps_path='./interm_data/map.pkl', 76 | )) 77 | 78 | val_dataset= deepcopy(dataset) 79 | val_dataset.update(dict( 80 | type= "STFDataset", 81 | batch_size= 32, 82 | shuffle= False, 83 | Providing_GT= True, 84 | lane_length= 10, 85 | dataset_path= './data/val', 86 | processed_data_path= './interm_data/argoverse_info_val.pkl', 87 | processed_maps_path='./interm_data/map.pkl' 88 | ) 89 | ) 90 | -------------------------------------------------------------------------------- /docs/assets/Figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/Figure1.png -------------------------------------------------------------------------------- /docs/assets/Figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/Figure4.png -------------------------------------------------------------------------------- /docs/assets/Figure5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/Figure5.png -------------------------------------------------------------------------------- /docs/assets/TP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/TP.png -------------------------------------------------------------------------------- /docs/assets/deciforce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/deciforce.png -------------------------------------------------------------------------------- /docs/assets/font.css: -------------------------------------------------------------------------------- 1 | /* Homepage Font */ 2 | 3 | /* latin-ext */ 4 | @font-face { 5 | font-family: 'Lato'; 6 | font-style: normal; 7 | font-weight: 400; 8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); 9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 10 | } 11 | 12 | /* latin */ 13 | @font-face { 14 | font-family: 'Lato'; 15 | font-style: normal; 16 | font-weight: 400; 17 | src: local('Lato Regular'), local('Lato-Regular'), 
url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); 18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 19 | } 20 | 21 | /* latin-ext */ 22 | @font-face { 23 | font-family: 'Lato'; 24 | font-style: normal; 25 | font-weight: 700; 26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); 27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 28 | } 29 | 30 | /* latin */ 31 | @font-face { 32 | font-family: 'Lato'; 33 | font-style: normal; 34 | font-weight: 700; 35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); 36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 37 | } 38 | -------------------------------------------------------------------------------- /docs/assets/region-based-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/region-based-training.png -------------------------------------------------------------------------------- /docs/assets/style.css: -------------------------------------------------------------------------------- 1 | /* Body */ 2 | body { 3 | background: #e3e5e8; 4 | color: #ffffff; 5 | font-family: 'Lato', Verdana, Helvetica, sans-serif; 6 | font-weight: 300; 7 | font-size: 14pt; 8 | } 9 | 10 | /* Hyperlinks */ 11 | a {text-decoration: none;} 12 | a:link {color: #1772d0;} 13 | a:visited {color: #1772d0;} 14 | a:active {color: red;} 15 | a:hover {color: #f09228;} 16 | 17 | /* Pre-formatted Text */ 18 | pre { 19 | margin: 5pt 0; 20 | border: 0; 21 | font-size: 12pt; 22 | background: #fcfcfc; 23 | } 24 | 25 | /* Project Page Style */ 26 | /* Section */ 27 | .section { 28 | width: 768pt; 29 | min-height: 100pt; 30 | margin: 15pt auto; 31 | padding: 20pt 30pt; 32 | border: 1pt hidden #000; 33 | text-align: justify; 34 | color: #000000; 35 | background: #ffffff; 36 | } 37 | 38 | /* Header (Title and Logo) */ 39 | .section .header { 40 | min-height: 80pt; 41 | margin-top: 30pt; 42 | } 43 | .section .header .logo { 44 | width: 80pt; 45 | margin-left: 10pt; 46 | float: left; 47 | } 48 | .section .header .logo img { 49 | width: 80pt; 50 | object-fit: cover; 51 | } 52 | .section .header .title { 53 | margin: 0 100pt; 54 | text-align: center; 55 | font-size: 20pt; 56 | } 57 | 58 | /* Author */ 59 | .section .author { 60 | margin: 5pt 0; 61 | text-align: center; 62 | font-size: 14pt; 63 | } 64 | 65 | /* Institution */ 66 | .section .institution { 67 | margin: 5pt 0; 68 | text-align: center; 69 | font-size: 14pt; 70 | } 71 | 72 | /* Conference */ 73 | .section .conference { 74 | margin: 5pt 0; 75 | text-align: center; 76 | font-size: 14pt; 77 | } 78 | 79 | /* Hyperlink (such as Paper and Code) */ 80 | .section .link { 81 | margin: 5pt 0; 82 | text-align: center; 83 | font-size: 16pt; 84 | } 85 | 86 | /* Teaser */ 87 | .section .teaser { 88 | margin: 20pt 0; 89 | text-align: center; 90 | } 91 | .section .teaser img { 92 | width: 95%; 93 | } 94 | 95 | /* Section Title */ 96 | .section .title { 97 | text-align: center; 98 
| font-size: 18pt; 99 | margin: 5pt 0 15pt 0; /* top right bottom left */ 100 | } 101 | 102 | /* Section Body */ 103 | .section .body { 104 | margin-bottom: 15pt; 105 | text-align: justify; 106 | font-size: 14pt; 107 | } 108 | 109 | /* BibTeX */ 110 | .section .bibtex { 111 | margin: 5pt 0; 112 | text-align: left; 113 | font-size: 22pt; 114 | } 115 | 116 | /* Related Work */ 117 | .section .ref { 118 | margin: 20pt 0 10pt 0; /* top right bottom left */ 119 | text-align: left; 120 | font-size: 18pt; 121 | font-weight: bold; 122 | } 123 | 124 | /* Citation */ 125 | .section .citation { 126 | min-height: 60pt; 127 | margin: 10pt 0; 128 | } 129 | .section .citation .image { 130 | width: 120pt; 131 | float: left; 132 | } 133 | .section .citation .image img { 134 | max-height: 60pt; 135 | width: 120pt; 136 | object-fit: cover; 137 | } 138 | .section .citation .comment{ 139 | margin-left: 130pt; 140 | text-align: left; 141 | font-size: 14pt; 142 | } 143 | -------------------------------------------------------------------------------- /docs/assets/teasar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/teasar.png -------------------------------------------------------------------------------- /docs/assets/tnt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/tnt.png -------------------------------------------------------------------------------- /docs/assets/vec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/vec.png -------------------------------------------------------------------------------- /docs/assets/vec2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/docs/assets/vec2.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Multimodal Motion Prediction with Stacked Transformers 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 |
25 | 28 |
29 | Multimodal Motion Prediction with Stacked Transformers 30 |
31 |
32 | 33 | 41 |
42 | The Chinese University of Hong Kong1, SenseTime Research2
44 |
45 | Computer Vision and Pattern Recognition (CVPR), 2021 46 |
47 | 51 |
52 | 53 |
54 |
55 | 56 | 57 | 58 | 59 |
60 |
Overview
61 |
62 | We propose a novel end-to-end framework (mmTransformer) for multimodal motion prediction. First, we utilize a stacked-transformer architecture to incorporate multiple channels of contextual information and to model the multimodality at the feature level with a set of trajectory proposals. Then, we induce the multimodality via a tailored region-based training strategy. By steering the training of the proposals, our model effectively mitigates the complexity of motion prediction while ensuring multimodal outputs.
64 |
65 | 66 | 67 | 68 | 69 |
70 |
Method
71 |
72 | 73 |
74 |
75 | The proposed architecture of mmTransformer (MultiModal Transformer). The backbone is composed of stacked transformers, which aggregate the contextual information progressively. The proposal feature decoder then generates the trajectory and confidence score for each learned trajectory proposal through the trajectory generator and selector, respectively.
77 |
78 | 79 |
80 | Overview of the Region-based Training Strategy (RTS). We distribute each proposal to one of the M regions. These proposals, shown as colored rectangles, learn their corresponding proposal features through the stacked transformers. In the training stage, we select the proposals assigned to the region where the GT endpoint is located, generate their trajectories and confidence scores, and then calculate the losses for them.
82 | 83 |
84 | Visualization of the multimodal prediction results on the Argoverse validation set. We utilize all trajectory proposals to generate multiple trajectories for each scenario and visualize all the predicted endpoints (black background) in the figures. Colored points indicate the prediction results of a specific group of proposals (after filtering by score). We observe that the endpoints generated by each group of regional proposals are within the associated region.
86 | 87 | 88 | 89 | 90 |
91 |
Results
92 |
93 | Qualitative comparison between mmTransformer (6 proposals) and mmTransformer+RTS (36 proposals): 94 | 95 | 96 | 97 | 98 | 99 |
100 | Demo video link: 101 |
102 | 103 | 104 |
105 | 111 |
112 |
113 | Demo video of multimodal motion prediction by mmTransformer. For each moving vehicle near the ego car, three plausible future trajectories are visualized.
115 |
116 | The video can also be found on: 117 | [bilibili]
119 |
120 | 121 | 122 | 123 | 124 |
125 |
BibTeX
126 |
127 | @article{liu2021multimodal,
128 |   title={Multimodal Motion Prediction with Stacked Transformers},
129 |   author={Liu, Yicheng and Zhang, Jinghuai and Fang, Liangji and Jiang, Qinhong and Zhou, Bolei},
130 |   journal={Computer Vision and Pattern Recognition},
131 |   year={2021}
132 | }
133 | 
134 | 135 | 136 | 137 |
Related Work
138 | 149 | 160 | 171 |
172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /figs/format1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/figs/format1.png -------------------------------------------------------------------------------- /figs/format2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/figs/format2.png -------------------------------------------------------------------------------- /figs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/figs/model.png -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/__init__.py -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/dataset/__init__.py -------------------------------------------------------------------------------- /lib/dataset/argoverse_convertor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | import sys 5 | from typing import Any, Dict, List 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from argoverse.data_loading.argoverse_forecasting_loader import \ 10 | ArgoverseForecastingLoader 11 | from argoverse.map_representation.map_api import ArgoverseMap 12 | from tqdm import tqdm 13 | 14 | from .preprocess_utils.feature_utils import compute_feature_for_one_seq, save_features 15 | from .preprocess_utils.map_utils_vec import save_map 16 | 17 | # vectorization 18 | from .vectorization import VectorizedCase 19 | 20 | 21 | class ArgoverseConvertor(object): 22 | 23 | def __init__(self, cfg): 24 | 25 | self.data_dir = cfg['DATA_DIR'] 26 | self.obs_len = cfg['OBS_LEN'] 27 | self.lane_radius = cfg['LANE_RADIUS'] 28 | self.object_radius = cfg['OBJ_RADIUS'] 29 | self.raw_dataformat = cfg['RAW_DATA_FORMAT'] 30 | self.am = ArgoverseMap() 31 | self.Afl = ArgoverseForecastingLoader 32 | self.out_dir = cfg['INTERMEDIATE_DATA_DIR'] 33 | self.save_dir_pretext = cfg['info_prefix'] 34 | self.specific_data_fold_list = cfg['specific_data_fold_list'] 35 | 36 | # vectorization 37 | self.vec_processor = VectorizedCase(cfg['vectorization_cfg']) 38 | 39 | def preprocess_map(self): 40 | 41 | os.makedirs(self.out_dir, exist_ok=True) 42 | if not os.path.exists(os.path.join(self.out_dir, 'map.pkl')): 43 | print("Processing maps ...") 44 | save_map(self.out_dir) 45 | print('Map is saved at ' + os.path.join(self.out_dir, 'map.pkl')) 46 | 47 | def process(self,): 48 | 49 | # preprocess the map 50 | self.preprocess_map() 51 | 52 | # store the case information 53 | data_info = {} 54 | 55 | for folder in os.listdir(self.data_dir): 56 | 57 | if folder not in self.specific_data_fold_list: 58 | continue 59 | 60 | afl = 
self.Afl(os.path.join(self.data_dir, folder, 'data')) 61 | info_dict = {} 62 | data_info[folder] = {} 63 | 64 | for path_name_ext in tqdm(afl.seq_list): 65 | 66 | afl_ = afl.get(path_name_ext) 67 | path, name_ext = os.path.split(path_name_ext) 68 | name, ext = os.path.splitext(name_ext) 69 | 70 | info_dict[name] = self.process_case(afl_.seq_df) 71 | 72 | out_path = os.path.join( 73 | self.out_dir, self.save_dir_pretext + f'{folder}.pkl') 74 | with open(out_path, 'wb') as f: 75 | pickle.dump(info_dict, f, pickle.HIGHEST_PROTOCOL) 76 | 77 | data_info[folder]['sample_num'] = len(afl.seq_list) 78 | print('Data is saved at ' + out_path) 79 | 80 | # print info 81 | print("Finished preprocessing.") 82 | for k in data_info.keys(): 83 | print('dataset name: ' + k + 84 | '\n sample num: {}'.format(data_info[k]['sample_num'])) 85 | 86 | def preprocess_case(self, seq_df): 87 | ''' 88 | Args: 89 | seq_df: DataFrame of a single Argoverse forecasting sequence (one CSV) 90 | 91 | ''' 92 | # retrieve info from csv 93 | agent_feature, obj_feature_ls, nearby_lane_ids, norm_center, city_name =\ 94 | compute_feature_for_one_seq( 95 | seq_df, 96 | self.am, 97 | self.obs_len, 98 | self.lane_radius, 99 | self.object_radius, 100 | self.raw_dataformat, 101 | viz=False, 102 | mode='nearby' 103 | ) 104 | 105 | # pack as the output 106 | dic = save_features( 107 | agent_feature, obj_feature_ls, nearby_lane_ids, norm_center, city_name 108 | ) 109 | 110 | return dic 111 | 112 | def process_case(self, seq_df): 113 | 114 | # tensorized 115 | data = self.preprocess_case(seq_df) 116 | # vectorized 117 | vec_dic = self.vec_processor.process_case(data) 118 | 119 | return vec_dic 120 | 121 | 122 | if __name__ == '__main__': 123 | 124 | import argparse 125 | parser = argparse.ArgumentParser( 126 | description='Preprocess argoverse dataset') 127 | parser.add_argument('config', help='config file path') 128 | args = parser.parse_args() 129 | from config.Config import Config 130 | cfg = Config.fromfile(args.config) 131 | preprocess_cfg = cfg.get('preprocess_dataset') 132 | processor = ArgoverseConvertor(preprocess_cfg) 133 | processor.process() 134 | -------------------------------------------------------------------------------- /lib/dataset/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | padding_keys = ['HISTORY', 'LANE', 'POS'] 6 | stacking_keys = ['VALID_LEN'] 7 | listing_keys = ['CITY_NAME', 'NAME', 'FUTURE', 8 | 'LANE_ID', 'THETA', 'NORM_CENTER'] 9 | 10 | 11 | def collate_single_cpu(batch): 12 | """ 13 | We only pad the HISTORY and LANE data. 14 | For all other data, we append the entries sharing a key into a list. 
15 | """ 16 | 17 | keys = batch[0].keys() 18 | 19 | out = {k: [] for k in keys} 20 | 21 | for data in batch: 22 | for k, v in data.items(): 23 | out[k].append(v) 24 | 25 | # stacking 26 | for k in stacking_keys: 27 | out[k] = torch.stack(out[k], dim=0) 28 | 29 | # padding 30 | for k in padding_keys: 31 | out[k] = pad_sequence(out[k], batch_first=True) 32 | 33 | return out 34 | -------------------------------------------------------------------------------- /lib/dataset/dataset_for_argoverse.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | # Tensorizerization & Vectorization 10 | from .argoverse_convertor import ArgoverseConvertor 11 | # Utilities 12 | from .utils import transform_coord, transform_coord_flip 13 | from .collate import collate_single_cpu 14 | 15 | 16 | class STFDataset(Dataset): 17 | """ 18 | dataset object similar to `torchvision` 19 | """ 20 | 21 | def __init__(self, cfg: dict): 22 | super(STFDataset, self).__init__() 23 | self.cfg = cfg 24 | 25 | self.processed_data_path = cfg['processed_data_path'] 26 | self.processed_maps_path = cfg['processed_maps_path'] 27 | 28 | self.traj_processor = ArgoverseConvertor(cfg['traj_processor_cfg']) 29 | 30 | # Load lane data 31 | with open(self.processed_maps_path, 'rb') as f: 32 | self.map, self.lane_id2idx = pickle.load(f) 33 | 34 | # Load processed trajs, land id and Misc. 35 | with open(self.processed_data_path, 'rb') as f: 36 | self.data = pickle.load(f) 37 | 38 | # get data list 39 | self.data_list = sorted(self.data.keys()) 40 | 41 | def __len__(self): 42 | return len(self.data_list) 43 | 44 | def __getitem__(self, idx): 45 | ''' 46 | Returns: 47 | the shape in here you can refer the format sheet at [here](./README.md) 48 | ''' 49 | 50 | name = self.data_list[idx] 51 | data_dict = {'NAME': name, 'MAX_LEN': [68, 248], } 52 | data_dict.update(self.get_data(name)) 53 | 54 | return data_dict 55 | 56 | @classmethod 57 | def get_data_path_ls(cls, dir_): 58 | return [os.path.join(dir_, data_path) for data_path in os.listdir(dir_)] 59 | 60 | def get_data(self, name): 61 | ''' 62 | the file name of the case 63 | 64 | Since we have processed the case. 65 | 66 | this function only needs to retrieve the lanes. 
67 | ''' 68 | out_dict = {} 69 | 70 | # load from pkl 71 | datadict = self.data[name] 72 | out_dict.update(datadict) 73 | 74 | # ----- LANE -------------- 75 | lane = self.get_lane( 76 | datadict['LANE_ID'], datadict['THETA'], datadict['NORM_CENTER'], datadict['CITY_NAME']) 77 | 78 | out_dict.update(dict(LANE=lane)) 79 | 80 | for k, v in out_dict.items(): 81 | if isinstance(v, np.ndarray): 82 | v = torch.from_numpy(v) 83 | 84 | if v.dtype == torch.double: 85 | v = v.type(torch.float32) 86 | 87 | out_dict[k] = v 88 | 89 | return out_dict 90 | 91 | def get_lane(self, lane_id, theta, center, city): 92 | ''' 93 | Args: 94 | lane_id: [lane_num] 95 | center: [2] 96 | theta: float 97 | city: str 98 | 99 | self.map: the preprocessed map data, 100 | Dict[city_name, np.ndarray of shape (num_lane, 10, 5)] 101 | 102 | Returns: 103 | lane_feature: num_lane, 10, 5 104 | ''' 105 | 106 | # Get lane 107 | # lane_feature: num_lane, 10, 5 108 | lane_id2idx = self.lane_id2idx[city] 109 | idx = list(map(lambda x: lane_id2idx[x], lane_id)) 110 | lane_feature = self.map[city][idx].copy() # (nline, 10, 5) 111 | lane = lane_feature[:, :, :2] 112 | 113 | # Location normalization 114 | lane = lane - center 115 | lane = transform_coord(lane, theta) 116 | lane_feature[:, :, :2] = lane 117 | 118 | return lane_feature 119 | 120 | 121 | if __name__ == '__main__': 122 | 123 | import argparse 124 | from config.Config import Config 125 | from torch.utils.data import DataLoader 126 | 127 | parser = argparse.ArgumentParser( 128 | description='Preprocess argoverse dataset') 129 | parser.add_argument('config', help='config file path') 130 | args = parser.parse_args() 131 | 132 | cfg = Config.fromfile(args.config) 133 | 134 | validation_cfg = cfg.get('val_dataset') 135 | val_dataset = STFDataset(validation_cfg) 136 | val_dataloader = DataLoader(val_dataset, 137 | shuffle=validation_cfg["shuffle"], 138 | batch_size=validation_cfg["batch_size"], 139 | num_workers=validation_cfg["workers_per_gpu"], 140 | collate_fn=collate_single_cpu) 141 | 142 | val_dataloader = iter(val_dataloader) 143 | DATA = next(val_dataloader) 144 | import ipdb 145 | ipdb.set_trace() 146 | -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/dataset/preprocess_utils/__init__.py -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/agent_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_agent_feature_ls(agent_df, obs_len, norm_center): 5 | """ 6 | args: agent_df, obs_len, norm_center 7 | returns: 8 | list of (track, timestamp, track_id, gt_track) 9 | """ 10 | xys, gt_xys = agent_df[["X", "Y"]].values[:obs_len], agent_df[[ 11 | "X", "Y"]].values[obs_len:] 12 | xys -= norm_center # normalize to last observed timestamp point of agent 13 | gt_xys -= norm_center # normalize to last observed timestamp point of agent 14 | ts = agent_df['TIMESTAMP'].values[:obs_len] 15 | 16 | return [xys, ts, agent_df['TRACK_ID'].iloc[0], gt_xys] 17 | -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/feature_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 
from argoverse.data_loading.argoverse_forecasting_loader import \ 7 | ArgoverseForecastingLoader 8 | from argoverse.map_representation.map_api import ArgoverseMap 9 | 10 | from .agent_utils import get_agent_feature_ls 11 | from .lane_utils import get_nearby_lane_feature_ls 12 | from .object_utils import get_nearby_moving_obj_feature_ls 13 | 14 | 15 | def compute_feature_for_one_seq( 16 | traj_df: pd.DataFrame, 17 | am: ArgoverseMap, 18 | obs_len: int = 20, 19 | lane_radius: int = 5, 20 | obj_radius: int = 10, 21 | raw_dataformat: dict = None, 22 | viz: bool = False, 23 | mode='nearby', 24 | query_bbox=[-65, 65, -65, 65]) -> List[List]: 25 | """ 26 | return lane & track features 27 | args: 28 | mode: 'rect' or 'nearby' 29 | returns: 30 | agent_feature_ls: 31 | list of target agent 32 | obj_feature_ls: 33 | list of (list of nearby agent feature) 34 | lane_feature_ls: 35 | list of (list of lane segment feature) 36 | 37 | norm_center np.ndarray: (2, ) 38 | """ 39 | # normalize timestamps 40 | traj_df['TIMESTAMP'] -= np.min(traj_df['TIMESTAMP'].values) 41 | seq_ts = np.unique(traj_df['TIMESTAMP'].values) 42 | 43 | city_name = traj_df['CITY_NAME'].iloc[0] 44 | agent_df = None 45 | agent_x_end, agent_y_end, start_x, start_y, query_x, query_y, norm_center = [ 46 | None] * 7 47 | 48 | # agent traj & its start/end point 49 | for obj_type, remain_df in traj_df.groupby('OBJECT_TYPE'): 50 | # sorted already according to timestamp 51 | if obj_type == 'AGENT': 52 | agent_df = remain_df 53 | start_x, start_y = agent_df[['X', 'Y']].values[0] 54 | agent_x_end, agent_y_end = agent_df[['X', 'Y']].values[-1] 55 | query_x, query_y = agent_df[['X', 'Y']].values[obs_len-1] 56 | norm_center = np.array([query_x, query_y]) 57 | break 58 | else: 59 | raise ValueError(f"cannot find 'agent' object type") 60 | 61 | # get agent features 62 | agent_feature = get_agent_feature_ls(agent_df, obs_len, norm_center) 63 | hist_xy = agent_feature[0] 64 | hist_len = np.sum(np.sqrt( 65 | (hist_xy[1:, 0]-hist_xy[:-1, 0])**2 + (hist_xy[1:, 1]-hist_xy[:-1, 1])**2)) 66 | 67 | # search lanes from the last observed point of agent 68 | nearby_lane_ids = get_nearby_lane_feature_ls( 69 | am, agent_df, obs_len, city_name, lane_radius, norm_center, mode=mode, query_bbox=query_bbox) 70 | 71 | # search nearby moving objects from the last observed point of agent 72 | obj_feature_ls = get_nearby_moving_obj_feature_ls( 73 | agent_df, traj_df, obs_len, seq_ts, obj_radius, norm_center, raw_dataformat) 74 | 75 | return [agent_feature, obj_feature_ls, nearby_lane_ids, norm_center, city_name] 76 | 77 | 78 | def save_features(agent_feature, obj_feature_ls, nearby_lane_ids, norm_center, city_name): 79 | """ 80 | args: 81 | agent_feature_ls: 82 | list of (xys, ts, agent_df['TRACK_ID'].iloc[0], gt_xys) 83 | obj_feature_ls: 84 | list of list of (xys, ts, mask, track_id, gt_xys, gt_mask) 85 | lane_feature_ls: 86 | list of list of lane a segment feature, centerline, lane_info1, lane_info2, lane_id 87 | returns: 88 | Dict[] 89 | """ 90 | nbrs_nd = np.empty((0, 4)) 91 | nbrs_gt = np.empty((0, 3)) 92 | lane_nd = np.empty((0, 7)) 93 | 94 | # agent features 95 | # input: xy,ts,mask 96 | agent_len = agent_feature[0].shape[0] 97 | agent_nd = np.hstack( 98 | (agent_feature[0], agent_feature[1].reshape((-1, 1)), np.ones((agent_len, 1)))) 99 | assert agent_nd.shape[1] == 4, "agent_traj feature dim 1 is not correct" 100 | # gt: xy, mask 101 | gt_len = agent_feature[-1].shape[0] 102 | agent_gt = np.hstack((agent_feature[-1], np.ones((gt_len, 1)))) 103 | assert 
agent_gt.shape[1] == 3 104 | 105 | # obj features 106 | # input: xy,ts,mask 107 | # gt: xy, mask 108 | if(len(obj_feature_ls) > 0): 109 | for obj_feature in obj_feature_ls: 110 | obj_len = obj_feature[0].shape[0] 111 | obj_nd = np.hstack((obj_feature[0], obj_feature[1].reshape( 112 | (-1, 1)), obj_feature[2].reshape((-1, 1)))) 113 | assert obj_nd.shape[1] == 4, "obj_traj feature dim 1 is not correct" 114 | nbrs_nd = np.vstack([nbrs_nd, obj_nd]) 115 | 116 | gt_len = obj_feature[4].shape[0] 117 | obj_gt = np.hstack( 118 | (obj_feature[4], obj_feature[5].reshape((-1, 1)))) 119 | assert obj_gt.shape[1] == 3, "obj_gt feature dim 1 is not correct" 120 | nbrs_gt = np.vstack([nbrs_gt, obj_gt]) 121 | # nbrs_nd [nbrs_num,20,4] 122 | nbrs_nd = nbrs_nd.reshape([-1, 20, 4]) 123 | # nbrs_gt [nbrs_num,30,3] 124 | nbrs_gt = nbrs_gt.reshape([-1, 30, 3]) 125 | 126 | # matrix of all agents 127 | if(len(obj_feature_ls)>0): 128 | all_agents_nd = np.concatenate([agent_nd.reshape(1,-1,4),nbrs_nd]) 129 | all_agents_gt = np.concatenate([agent_gt.reshape(1,-1,3),nbrs_gt]) 130 | else: 131 | all_agents_nd = agent_nd.reshape(1,-1,4) 132 | all_agents_gt = agent_gt.reshape(1,-1,3) 133 | 134 | # lane ids: (large integer) 135 | lane_id = np.array(nearby_lane_ids) 136 | 137 | # saving 138 | dic = { 139 | "HISTORY": all_agents_nd.astype(np.float32), 140 | "FUTURE": all_agents_gt.astype(np.float32), 141 | "LANE_ID": lane_id.astype(np.int32), 142 | "NORM_CENTER": norm_center.astype(np.float32), 143 | "VALID_LEN": np.array((len(all_agents_nd), len(lane_id))), 144 | "CITY_NAME": city_name 145 | } 146 | 147 | return dic 148 | -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/lane_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | # Only support nearby in our implementation 7 | def get_nearby_lane_feature_ls(am, agent_df, obs_len, city_name, lane_radius, norm_center, has_attr=False, mode='nearby', query_bbox=None): 8 | ''' 9 | compute lane features 10 | args: 11 | norm_center: np.ndarray 12 | mode: 'nearby' return nearby lanes within the radius; 'rect' return lanes within the query bbox 13 | **kwargs: query_bbox= List[int, int, int, int] 14 | returns: 15 | list of list of lane a segment feature, formatted in [centerline, is_intersection, turn_direction, is_traffic_control, lane_id, 16 | predecessor_lanes, successor_lanes, adjacent_lanes] 17 | ''' 18 | query_x, query_y = agent_df[['X', 'Y']].values[obs_len-1] 19 | nearby_lane_ids = am.get_lane_ids_in_xy_bbox( 20 | query_x, query_y, city_name, lane_radius) 21 | 22 | return nearby_lane_ids 23 | -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/map_utils_vec.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | from argoverse.map_representation.map_api import ArgoverseMap 6 | 7 | 8 | def save_map(dir_, name = f"map.pkl"): 9 | am = ArgoverseMap() 10 | lane_dict = am.build_centerline_index() 11 | 12 | # go through each lane segment 13 | dic = {"PIT": [], "MIA": []} 14 | lane_id2idx = {"PIT": {}, "MIA": {}} 15 | for city_name in ["PIT", "MIA"]: 16 | 17 | for i, lane_id in enumerate(lane_dict[city_name].keys()): 18 | # extract from API 19 | lane_cl = am.get_lane_segment_centerline(lane_id, city_name) 20 | centerline = lane_cl[:, :2] 21 | is_intersection = 
am.lane_is_in_intersection(lane_id, city_name) 22 | turn_direction = am.get_lane_turn_direction(lane_id, city_name) 23 | traffic_control = am.lane_has_traffic_control_measure( 24 | lane_id, city_name) 25 | lane_info1 = 1 26 | if(is_intersection): 27 | lane_info1 = 2 28 | lane_info2 = 1 29 | if(turn_direction == "LEFT"): 30 | lane_info2 = 2 31 | elif(turn_direction == "RIGHT"): 32 | lane_info2 = 3 33 | lane_info3 = 1 34 | if(traffic_control): 35 | lane_info3 = 2 36 | 37 | lane_len = lane_cl.shape[0] 38 | 39 | # there 61 lane is not enough for size 10 40 | if lane_len < 10: 41 | lane_cl = np.pad( 42 | lane_cl, ((0, 10-lane_len), (0, 0)), "edge") 43 | lane_len = 10 44 | 45 | lane_nd = np.concatenate( 46 | [lane_cl[:, :2], 47 | np.ones((lane_len, 1)) * lane_info1, 48 | np.ones((lane_len, 1)) * lane_info2, 49 | np.ones((lane_len, 1)) * lane_info3], axis=-1) 50 | dic[city_name].append(lane_nd) 51 | lane_id2idx[city_name][lane_id] = i 52 | 53 | dic[city_name] = np.stack(dic[city_name], axis=0) 54 | 55 | # saving 56 | with open(os.path.join(dir_, name), 'wb') as f: 57 | pickle.dump([dic, lane_id2idx], f, pickle.HIGHEST_PROTOCOL) 58 | -------------------------------------------------------------------------------- /lib/dataset/preprocess_utils/object_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | def pad_track( 8 | track_df: pd.DataFrame, 9 | seq_timestamps: np.ndarray, 10 | base: int, 11 | track_len: int, 12 | raw_data_format: Dict[str, int], 13 | ) -> np.ndarray: 14 | """Pad incomplete tracks. 15 | Args: 16 | track_df (Dataframe): Dataframe for the track 17 | seq_timestamps (numpy array): All timestamps in the sequence 18 | base: base frame id (0 for observed trajectory, 20 for future trajectory) 19 | track_len (int): Length of whole trajectory (observed + future) 20 | raw_data_format (Dict): Format of the sequence 21 | Returns: 22 | padded_track_array (numpy array) 23 | """ 24 | track_vals = track_df.values 25 | track_timestamps = track_df["TIMESTAMP"].values 26 | seq_timestamps = seq_timestamps[base:base+track_len] 27 | 28 | # start and index of the track in the sequence 29 | start_idx = np.where(seq_timestamps == track_timestamps[0])[0][0] 30 | end_idx = np.where(seq_timestamps == track_timestamps[-1])[0][0] 31 | 32 | # Edge padding in front and rear, i.e., repeat the first and last coordinates 33 | # if self.PADDING_TYPE == "REPEAT" 34 | padded_track_array = np.pad(track_vals, 35 | ((start_idx, track_len - end_idx - 1), 36 | (0, 0)), "edge") 37 | 38 | mask = np.ones((end_idx+1-start_idx)) 39 | mask = np.pad(mask, (start_idx, track_len - end_idx - 1), 'constant') 40 | if padded_track_array.shape[0] < track_len: 41 | # rare case, just ignore 42 | return None, None, False 43 | 44 | # Overwrite the timestamps in padded part 45 | for i in range(padded_track_array.shape[0]): 46 | padded_track_array[i, 0] = seq_timestamps[i] 47 | assert mask.shape[0] == padded_track_array.shape[0] 48 | return padded_track_array, mask, True 49 | 50 | 51 | # Important: We delete some invalid nearby agents according to our criterion!!! 
Refer to continue option 52 | def get_nearby_moving_obj_feature_ls(agent_df, traj_df, obs_len, seq_ts, obj_radius, norm_center, raw_dataformat, fut_len=30): 53 | """ 54 | args: 55 | returns: list of list, (track, timestamp, mask, track_id, gt_track, gt_mask) 56 | """ 57 | obj_feature_ls = [] 58 | query_x, query_y = agent_df[['X', 'Y']].values[obs_len-1] 59 | p0 = np.array([query_x, query_y]) 60 | for track_id, remain_df in traj_df.groupby('TRACK_ID'): 61 | if remain_df['OBJECT_TYPE'].iloc[0] == 'AGENT': 62 | continue 63 | 64 | hist_df = remain_df[remain_df['TIMESTAMP'] <= 65 | agent_df['TIMESTAMP'].values[obs_len-1]] 66 | if(len(hist_df) == 0): 67 | continue 68 | # pad hist 69 | xys, ts = None, None 70 | if len(hist_df) < obs_len: 71 | paded_nd, mask, flag = pad_track( 72 | hist_df, seq_ts, 0, obs_len, raw_dataformat) 73 | if flag == False: 74 | continue 75 | xys = np.array(paded_nd[:, 3:5], dtype=np.float64) 76 | ts = np.array(paded_nd[:, 0], dtype=np.float64) 77 | else: 78 | xys = hist_df[['X', 'Y']].values 79 | ts = hist_df["TIMESTAMP"].values 80 | mask = np.ones((obs_len)) 81 | 82 | p1 = xys[-1] 83 | if mask[-1] == 0 or np.linalg.norm(p0 - p1) > obj_radius: 84 | continue 85 | if(sum(mask) <= 3): 86 | continue 87 | 88 | fut_df = remain_df[remain_df['TIMESTAMP'] > 89 | agent_df['TIMESTAMP'].values[obs_len-1]] 90 | # pad future 91 | gt_xys = None 92 | if len(fut_df) == 0: 93 | gt_xys = np.zeros((fut_len, 2))+norm_center 94 | gt_mask = np.zeros((fut_len)) 95 | elif len(fut_df) < fut_len: 96 | paded_nd, gt_mask, flag = pad_track( 97 | fut_df, seq_ts, obs_len, fut_len, raw_dataformat) 98 | if flag == False: 99 | continue 100 | gt_xys = np.array(paded_nd[:, 3:5], dtype=np.float64) 101 | else: 102 | gt_xys = fut_df[['X', 'Y']].values 103 | gt_mask = np.ones((fut_len)) 104 | 105 | xys -= norm_center # normalize to last observed timestamp point of agent 106 | gt_xys -= norm_center 107 | 108 | assert xys.shape[0] == obs_len 109 | assert mask.shape[0] == obs_len 110 | assert gt_xys.shape[0] == fut_len 111 | assert gt_mask.shape[0] == fut_len 112 | obj_feature_ls.append( 113 | [xys, ts, mask, track_id, gt_xys, gt_mask]) 114 | return obj_feature_ls 115 | -------------------------------------------------------------------------------- /lib/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from sklearn.linear_model import LinearRegression 4 | 5 | 6 | def get_heading_angle(traj: np.ndarray): 7 | """ 8 | get the heading angle 9 | traj: [N,2] N>=6 10 | """ 11 | # length == 6 12 | # sort position 13 | _traj = traj.copy() 14 | traj = traj.copy() 15 | 16 | traj = traj[traj[:, 0].argsort()] 17 | traj = traj[traj[:, 1].argsort()] 18 | 19 | if traj.T[0].max()-traj.T[0].min() > traj.T[1].max()-traj.T[1].min(): # * dominated by x 20 | reg = LinearRegression().fit(traj[:, 0].reshape(-1, 1), traj[:, 1]) 21 | traj_dir = _traj[-2:].mean(0) - _traj[:2].mean(0) 22 | reg_dir = np.array([1, reg.coef_[0]]) 23 | angle = np.arctan(reg.coef_[0]) 24 | else: 25 | # using y as sample and x as the target to fit a line 26 | reg = LinearRegression().fit(traj[:, 1].reshape(-1, 1), traj[:, 0]) 27 | traj_dir = _traj[-2:].mean(0) - _traj[:2].mean(0) 28 | reg_dir = np.array([reg.coef_[0], 1])*np.sign(reg.coef_[0]) 29 | if reg.coef_[0] == 0: 30 | import pdb 31 | pdb.set_trace() 32 | angle = np.arctan(1/reg.coef_[0]) 33 | 34 | if angle < 0: 35 | angle = 2*np.pi + angle 36 | if (reg_dir*traj_dir).sum() < 0: # not same direction 37 | angle = 
(angle+np.pi) % (2*np.pi) 38 | # angle from y 39 | angle_to_y = angle-np.pi/2 40 | angle_to_y = -angle_to_y 41 | return angle_to_y 42 | 43 | 44 | def transform_coord(coords, angle): 45 | x = coords[..., 0] 46 | y = coords[..., 1] 47 | x_transform = np.cos(angle)*x-np.sin(angle)*y 48 | y_transform = np.cos(angle)*y+np.sin(angle)*x 49 | output_coords = np.stack((x_transform, y_transform), axis=-1) 50 | 51 | return output_coords 52 | 53 | 54 | def transform_coord_flip(coords, angle): 55 | x = coords[:, 0] 56 | y = coords[:, 1] 57 | x_transform = math.cos(angle)*x-math.sin(angle)*y 58 | y_transform = math.cos(angle)*y+math.sin(angle)*x 59 | x_transform = -1*x_transform # flip 60 | # y_transform = -1*y_transform # flip 61 | output_coords = np.stack((x_transform, y_transform), axis=-1) 62 | return output_coords 63 | -------------------------------------------------------------------------------- /lib/dataset/vectorization.py: -------------------------------------------------------------------------------- 1 | from sklearn.linear_model import LinearRegression 2 | from torch.utils.data import DataLoader, Dataset 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from .utils import get_heading_angle, transform_coord 9 | 10 | 11 | class VectorizedCase(object): 12 | 13 | def __init__(self, cfg): 14 | 15 | self.straighten = True 16 | self.max_agent_num = 68 17 | self.pad = False 18 | 19 | def get_straighten_angle(self, features): 20 | ''' 21 | agent_features: [20,5] 22 | -------------Calculate-Angle------------------- 23 | trajs fed into this func must satisfy the following conditions: 24 | 1. long enough (l > 2m) 25 | 2. consistent direction 26 | ''' 27 | agent_features = features[0] 28 | 29 | ct = 19 - 6 30 | coord1, coord2 = agent_features[ct, :2], agent_features[19, :2] 31 | traj_dir = agent_features[-1, :2] - agent_features[0, :2] 32 | current_dir = coord2 - coord1 33 | while (np.linalg.norm(coord1-coord2, ord=2) < 2 or (current_dir*traj_dir).sum() < 0) and ct > 0: 34 | ct -= 1 35 | coord1 = agent_features[ct, :2] 36 | current_dir = coord2 - coord1 37 | 38 | theta = get_heading_angle(agent_features[ct:, :2]) 39 | 40 | return theta 41 | 42 | def get_history_traj(self, features, theta): 43 | ''' 44 | features: trajectory features with size (number of agents, history frame num, 5) 45 | Notes: index 0 of axis 0 is the target agent. 
46 | ''' 47 | 48 | num_agent = features.shape[0] 49 | features = features[..., :4] 50 | 51 | if self.straighten: 52 | features = features.reshape(-1, 4) 53 | features[:, :2] = transform_coord(features[:, :2], theta) 54 | features = features.reshape(num_agent, 20, 4) 55 | 56 | v = features[:, 1:, :2] - features[:, :-1, :2] # na, 19, 2 57 | ts = (features[:, 1:, 2] + features[:, :-1, 2])/2 # na, 19 58 | mask = features[:, 1:, 3]*features[:, :-1, 3] # 1,1 =>1; 1,0 =>0; 0,0=>0 59 | 60 | hist_traj = np.concatenate( 61 | [v, ts.reshape(-1, 19, 1), mask.reshape(-1, 19, 1)], -1) 62 | pos = features[:, -1, :2] 63 | assert hist_traj.shape == (num_agent, 19, 4) 64 | assert pos.shape == (num_agent, 2) 65 | 66 | if self.pad: 67 | # padding data 68 | hist_traj = np.pad( 69 | hist_traj, ((0, self.max_agent_num - num_agent), (0, 0), (0, 0)), "constant") 70 | pos = np.pad( 71 | pos, ((0, self.max_agent_num - num_agent), (0, 0)), "constant") 72 | 73 | return dict( 74 | HISTORY=hist_traj, 75 | POS=pos, 76 | ) 77 | 78 | def get_future_traj(self, features, pos, theta): 79 | ''' 80 | pos: nbr2target_translate (n_agent, 2) 81 | features: trajectory features with size (number of agents, history frame num, 3) 82 | Notes: index 0 of axis 0 is the target agent. 83 | ''' 84 | 85 | n_agents = features.shape[0] 86 | 87 | if self.straighten: 88 | features = features.reshape(-1, 3) 89 | features[:, :2] = transform_coord(features[:, :2], theta) 90 | features = features.reshape(-1, 30, 3) 91 | 92 | v = np.concatenate([(features[:, 0, :2] - pos).reshape(-1, 1, 2), 93 | features[:, 1:, :2]-features[:, :-1, :2]], 1) 94 | mask = features[:, :, 2].reshape(-1, 30, 1) 95 | future_traj = np.concatenate([v, mask], axis=-1) 96 | 97 | assert future_traj.shape == (n_agents, 30, 3) 98 | 99 | if self.pad: 100 | future_traj = np.pad( 101 | future_traj, ((0, self.max_agent_num-n_agents), (0, 0), (0, 0)), "constant") 102 | 103 | return dict(FUTURE=future_traj,) 104 | 105 | def process_case(self, data): 106 | ''' 107 | vectorize each case 108 | data: [ 109 | FEATURE 110 | GT 111 | LANE 112 | ] 113 | 114 | out: 115 | ["HIST", "FUTURE", "POS", "VALID_AGENT", "LANE", "MAX_LEN", "THETA", "NAME"] 116 | 117 | ''' 118 | 119 | theta = self.get_straighten_angle(data['HISTORY']) 120 | data['THETA'] = theta 121 | 122 | # --------------Trajectory------------------------------------------------------ 123 | 124 | # History traj 125 | hist_dict = self.get_history_traj(data['HISTORY'], data['THETA']) 126 | data.update(hist_dict) 127 | 128 | # Future traj 129 | future_dict = self.get_future_traj(data['FUTURE'], data['POS'], data['THETA']) 130 | data.update(future_dict) 131 | 132 | return data 133 | 134 | def transform_coord(self, coords, angle): 135 | x = coords[:, 0] 136 | y = coords[:, 1] 137 | x_transform = np.cos(angle)*x-np.sin(angle)*y 138 | y_transform = np.cos(angle)*y+np.sin(angle)*x 139 | output_coords = np.stack((x_transform, y_transform), axis=-1) 140 | return output_coords 141 | -------------------------------------------------------------------------------- /lib/models/TF_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from copy import deepcopy 5 | import math 6 | 7 | 8 | class EncoderDecoder(nn.Module): 9 | """ 10 | A standard Encoder-Decoder architecture. Base for this and many 11 | other models. 
12 | """ 13 | 14 | def __init__(self, encoder, decoder, src_embed): 15 | super(EncoderDecoder, self).__init__() 16 | self.encoder = encoder 17 | self.decoder = decoder 18 | self.src_embed = src_embed 19 | 20 | def forward(self, src, tgt, src_mask, tgt_mask, query_pos=None): 21 | """ 22 | Take in and process masked src and target sequences. 23 | """ 24 | output = self.encode(src, src_mask) 25 | return self.decode(output, src_mask, tgt, tgt_mask, query_pos) 26 | 27 | def encode(self, src, src_mask): 28 | return self.encoder(self.src_embed(src), src_mask) 29 | 30 | def decode(self, memory, src_mask, tgt, tgt_mask, query_pos=None): 31 | return self.decoder(tgt, memory, src_mask, tgt_mask, query_pos) 32 | 33 | 34 | class Encoder(nn.Module): 35 | """ 36 | Core encoder is a stack of N layers 37 | """ 38 | 39 | def __init__(self, layer, n): 40 | super(Encoder, self).__init__() 41 | self.layers = clones(layer, n) 42 | self.norm = nn.LayerNorm(layer.size) 43 | 44 | def forward(self, x, x_mask): 45 | """ 46 | Pass the input (and mask) through each layer in turn. 47 | """ 48 | for layer in self.layers: 49 | x = layer(x, x_mask) 50 | return self.norm(x) 51 | 52 | 53 | class EncoderLayer(nn.Module): 54 | """ 55 | Encoder is made up of self-attn and feed forward (defined below) 56 | """ 57 | 58 | def __init__(self, size, self_attn, feed_forward, dropout): 59 | super(EncoderLayer, self).__init__() 60 | self.self_attn = self_attn 61 | self.feed_forward = feed_forward 62 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 63 | self.size = size 64 | 65 | def forward(self, x, mask): 66 | """ 67 | Follow Figure 1 (left) for connections. 68 | """ 69 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 70 | return self.sublayer[1](x, self.feed_forward) 71 | 72 | 73 | class Decoder(nn.Module): 74 | """ 75 | Generic N layer decoder with masking. 76 | """ 77 | 78 | def __init__(self, layer, n, return_intermediate=False): 79 | super(Decoder, self).__init__() 80 | self.layers = clones(layer, n) 81 | self.norm = nn.LayerNorm(layer.size) 82 | self.return_intermediate = return_intermediate 83 | 84 | def forward(self, x, memory, src_mask, tgt_mask, query_pos=None): 85 | 86 | intermediate = [] 87 | 88 | for layer in self.layers: 89 | x = layer(x, memory, src_mask, tgt_mask, query_pos) 90 | 91 | if self.return_intermediate: 92 | intermediate.append(self.norm(x)) 93 | 94 | if self.norm is not None: 95 | x = self.norm(x) 96 | if self.return_intermediate: 97 | intermediate.pop() 98 | intermediate.append(x) 99 | 100 | if self.return_intermediate: 101 | return torch.stack(intermediate) 102 | 103 | return x 104 | 105 | 106 | class DecoderLayer(nn.Module): 107 | """ 108 | Decoder is made of self-attn, src-attn, and feed forward (defined below) 109 | """ 110 | 111 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 112 | super(DecoderLayer, self).__init__() 113 | self.size = size 114 | self.self_attn = self_attn 115 | self.src_attn = src_attn 116 | self.feed_forward = feed_forward 117 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 118 | 119 | # TODO How to fusion the feature 120 | def with_pos_embed(self, tensor, pos=None): 121 | return tensor if pos is None else tensor + pos 122 | 123 | def forward(self, x, memory, src_mask, tgt_mask, query_pos=None): 124 | """ 125 | Follow Figure 1 (right) for connections. 
126 | """ 127 | m = memory 128 | q = k = self.with_pos_embed(x, query_pos) 129 | x = self.sublayer[0](x, lambda x: self.self_attn(q, k, x, tgt_mask)) 130 | x = self.with_pos_embed(x, query_pos) 131 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 132 | return self.sublayer[2](x, self.feed_forward) 133 | 134 | 135 | class MultiHeadAttention(nn.Module): 136 | def __init__(self, h, d_model, dropout=0.1): 137 | """ 138 | Take in model size and number of heads. 139 | """ 140 | super(MultiHeadAttention, self).__init__() 141 | assert d_model % h == 0 142 | # We assume d_v always equals d_k 143 | self.d_k = d_model // h 144 | self.h = h 145 | self.linears = clones(nn.Linear(d_model, d_model, bias=True), 4) 146 | self.attn = None 147 | self.dropout = nn.Dropout(p=dropout) 148 | 149 | def forward(self, query, key, value, mask=None): 150 | """ 151 | Implements Figure 2 152 | """ 153 | if len(query.shape) > 3: 154 | batch_dim = len(query.shape)-2 155 | batch = query.shape[:batch_dim] 156 | mask_dim = batch_dim 157 | else: 158 | batch = (query.shape[0],) 159 | mask_dim = 1 160 | if mask is not None: 161 | # Same mask applied to all h heads. 162 | mask = mask.unsqueeze(dim=mask_dim) 163 | 164 | # 1) Do all the linear projections in batch from d_model => h x d_k 165 | query, key, value = [l(x).view(*batch, -1, self.h, self.d_k).transpose(-3, -2) for l, x in 166 | zip(self.linears, (query, key, value))] 167 | 168 | # 2) Apply attention on all the projected vectors in batch. 169 | x, self.attn = attention( 170 | query, key, value, mask=mask, dropout=self.dropout) 171 | # 3) "Concat" using a view and apply a final linear. 172 | x = x.transpose(-3, -2).contiguous().view(* 173 | batch, -1, self.h * self.d_k) 174 | return self.linears[-1](x) 175 | 176 | 177 | class PointerwiseFeedforward(nn.Module): 178 | """ 179 | Implements FFN equation. 180 | """ 181 | 182 | def __init__(self, d_model, d_ff, dropout=0.1): 183 | super(PointerwiseFeedforward, self).__init__() 184 | self.w_1 = nn.Linear(d_model, d_ff, bias=True) 185 | self.w_2 = nn.Linear(d_ff, d_model, bias=True) 186 | self.dropout = nn.Dropout(dropout) 187 | self.relu = nn.ReLU() 188 | 189 | def forward(self, x): 190 | return self.w_2(self.dropout(self.relu(self.w_1(x)))) 191 | 192 | 193 | class SublayerConnection(nn.Module): 194 | """ 195 | A residual connection followed by a layer norm. 196 | Note for code simplicity the norm is first as opposed to last. 197 | """ 198 | 199 | def __init__(self, size, dropout): 200 | super(SublayerConnection, self).__init__() 201 | self.norm = nn.LayerNorm(size) 202 | self.dropout = nn.Dropout(dropout) 203 | 204 | def forward(self, x, sublayer): 205 | """ 206 | Apply residual connection to any sublayer with the same size. 207 | """ 208 | return x + self.dropout(sublayer(self.norm(x))) 209 | 210 | 211 | def clones(module, n): 212 | """ 213 | Produce N identical layers. 
214 | """ 215 | assert isinstance(module, nn.Module) 216 | return nn.ModuleList([deepcopy(module) for _ in range(n)]) 217 | 218 | 219 | def attention(query, key, value, mask=None, dropout=None): 220 | """ 221 | Compute 'Scaled Dot Product Attention' 222 | """ 223 | d_k = query.size(-1) 224 | 225 | # Q,K,V: [bs,h,num,dim] 226 | # scores: [bs,h,num1,num2] 227 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 228 | # mask: [bs,1,1,num2] => dimension expansion 229 | 230 | if mask is not None: 231 | scores = scores.masked_fill_(mask == 0, value=-1e9) 232 | p_attn = torch.softmax(scores, dim=-1) 233 | if dropout is not None: 234 | p_attn = dropout(p_attn) 235 | return torch.matmul(p_attn, value), p_attn 236 | 237 | 238 | class LinearEmbedding(nn.Module): 239 | def __init__(self, inp_size, d_model): 240 | super(LinearEmbedding, self).__init__() 241 | # lut => lookup table 242 | self.lut = nn.Linear(inp_size, d_model, bias=True) 243 | self.d_model = d_model 244 | 245 | def forward(self, x): 246 | return self.lut(x) * math.sqrt(self.d_model) 247 | 248 | 249 | class PositionalEncoding(nn.Module): 250 | """ 251 | Implement the PE function. 252 | """ 253 | 254 | def __init__(self, d_model, dropout, max_len=5000): 255 | super(PositionalEncoding, self).__init__() 256 | self.dropout = nn.Dropout(p=dropout) 257 | 258 | # Compute the positional encodings once in log space. 259 | pe = torch.zeros(max_len, d_model) 260 | position = torch.arange(0, max_len).unsqueeze(1).float() 261 | div_term = torch.exp(torch.arange( 262 | 0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 263 | pe[:, 0::2] = torch.sin(position * div_term) 264 | pe[:, 1::2] = torch.cos(position * div_term) 265 | self.register_buffer('pe', pe) 266 | 267 | def forward(self, x): 268 | # x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False) 269 | x = x + Variable(self.pe[:x.shape[-2]], requires_grad=False) 270 | return self.dropout(x) 271 | 272 | 273 | # for 626 274 | class GeneratorWithParallelHeads626(nn.Module): 275 | def __init__(self, d_model, out_size, dropout, reg_h_dim=128, dis_h_dim=128, cls_h_dim=128): 276 | super(GeneratorWithParallelHeads626, self).__init__() 277 | self.reg_mlp = nn.Sequential( 278 | nn.Linear(d_model, reg_h_dim*2, bias=True), 279 | nn.LayerNorm(reg_h_dim*2), 280 | nn.ReLU(), 281 | nn.Linear(reg_h_dim*2, reg_h_dim, bias=True), 282 | nn.Linear(reg_h_dim, out_size, bias=True)) 283 | self.dis_emb = nn.Linear(2, dis_h_dim, bias=True) 284 | self.cls_FFN = PointerwiseFeedforward( 285 | d_model, 2*d_model, dropout=dropout) 286 | self.classification_layer = nn.Sequential( 287 | nn.Linear(d_model, cls_h_dim), 288 | nn.Linear(cls_h_dim, 1, bias=True)) 289 | self.cls_opt = nn.Softmax(dim=-1) 290 | 291 | def forward(self, x): 292 | pred = self.reg_mlp(x) 293 | pred = pred.view(*pred.shape[0:3], -1, 2).cumsum(dim=-2) 294 | # return pred 295 | cls_h = self.cls_FFN(x) 296 | cls_h = self.classification_layer(cls_h).squeeze(dim=-1) 297 | conf = self.cls_opt(cls_h) 298 | return pred, conf 299 | 300 | 301 | class GeneratorWithParallelHeads(nn.Module): 302 | def __init__(self, d_model, out_size, dropout, reg_h_dim=128, region_proposal_num=6): 303 | super(GeneratorWithParallelHeads, self).__init__() 304 | self.reg_mlp = nn.Sequential( 305 | nn.Linear(d_model, reg_h_dim*2, bias=True), 306 | nn.ReLU(), 307 | nn.Linear(reg_h_dim*2, reg_h_dim, bias=True), 308 | nn.ReLU(), 309 | nn.Linear(reg_h_dim, out_size, bias=True)) 310 | # self.dis_emb = nn.Linear(2, dis_h_dim, bias=True) 311 | self.cls_FFN = 
PointerwiseFeedforward( 312 | d_model, 2*d_model, dropout=dropout) 313 | self.classification_layer = nn.Sequential( 314 | nn.Linear(d_model, d_model//2, bias=True), 315 | nn.Linear(d_model//2, 1, bias=True)) 316 | #self.cls_opt = nn.Softmax(dim=-1) 317 | self.cls_opt = torch.nn.LogSoftmax(dim=-1) 318 | 319 | def forward(self, x): 320 | pred = self.reg_mlp(x) 321 | pred = pred.view(*pred.shape[:-1], -1, 2).cumsum(dim=-2) 322 | # endpoint = pred[...,-1,:].squeeze(dim=-2).detach() 323 | # x = torch.cat((x, endpoint), dim=-1) 324 | cls_h = self.cls_FFN(x) 325 | cls_h = self.classification_layer(cls_h).squeeze(dim=-1) 326 | conf = self.cls_opt(cls_h) 327 | return pred, conf 328 | 329 | 330 | def split_dim(x: torch.Tensor, split_shape: tuple, dim: int): 331 | if dim < 0: 332 | dim = len(x.shape) + dim 333 | return x.reshape(*x.shape[:dim], *split_shape, *x.shape[dim+1:]) 334 | -------------------------------------------------------------------------------- /lib/models/TF_version/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/models/TF_version/__init__.py -------------------------------------------------------------------------------- /lib/models/TF_version/stacked_transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ..TF_utils import (Decoder, DecoderLayer, Encoder, EncoderDecoder, 9 | EncoderLayer, GeneratorWithParallelHeads626, 10 | LinearEmbedding, MultiHeadAttention, 11 | PointerwiseFeedforward, PositionalEncoding, 12 | SublayerConnection) 13 | 14 | 15 | class STF(nn.Module): 16 | def __init__(self, cfg): 17 | super(STF, self).__init__() 18 | "Helper: Construct a model from hyperparameters." 
19 | 20 | # Hyperparameters from cfg 21 | hist_inp_size = cfg['in_channels'] 22 | lane_inp_size = cfg['enc_dim'] 23 | num_queries = cfg['queries'] 24 | dec_inp_size = cfg['queries_dim'] 25 | dec_out_size = cfg['out_channels'] 26 | # Hyperparameters predefined 27 | N = 2 28 | N_lane = 2 29 | N_social = 2 30 | d_model = 128 31 | d_ff = 256 32 | pos_dim = 64 33 | dist_dim = 128 34 | h = 2 35 | dropout = 0 36 | # 37 | 38 | self.aux_loss = cfg['aux_task'] 39 | c = copy.deepcopy 40 | dropout_atten = dropout 41 | #dropout_atten = 0.1 42 | attn = MultiHeadAttention(h, d_model, dropout=dropout_atten) 43 | ff = PointerwiseFeedforward(d_model, d_ff, dropout) 44 | position = PositionalEncoding(d_model, dropout) 45 | 46 | self.hist_tf = EncoderDecoder( 47 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 48 | Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N), 49 | nn.Sequential(LinearEmbedding(hist_inp_size, d_model), c(position)) 50 | ) 51 | self.lane_enc = Encoder(EncoderLayer( 52 | d_model, c(attn), c(ff), dropout), N_lane) 53 | self.lane_dec = Decoder(DecoderLayer( 54 | d_model, c(attn), c(attn), c(ff), dropout), N_lane) 55 | self.lane_emb = LinearEmbedding(lane_inp_size, d_model) 56 | 57 | self.pos_emb = nn.Sequential( 58 | nn.Linear(2, pos_dim, bias=True), 59 | nn.LayerNorm(pos_dim), 60 | nn.ReLU(), 61 | nn.Linear(pos_dim, pos_dim, bias=True)) 62 | self.dist_emb = nn.Sequential( 63 | nn.Linear(num_queries*d_model, dist_dim, bias=True), 64 | nn.LayerNorm(dist_dim), 65 | nn.ReLU(), 66 | nn.Linear(dist_dim, dist_dim, bias=True)) 67 | 68 | self.fusion1 = nn.Sequential( 69 | nn.Linear(d_model+pos_dim, d_model, bias=True), 70 | nn.LayerNorm(d_model), 71 | nn.ReLU(), 72 | nn.Linear(d_model, d_model, bias=True)) 73 | self.fusion2 = nn.Sequential( 74 | nn.Linear(dist_dim+pos_dim, d_model, bias=True), 75 | nn.LayerNorm(d_model), 76 | nn.ReLU(), 77 | nn.Linear(d_model, d_model, bias=True)) 78 | self.social_enc = Encoder(EncoderLayer( 79 | d_model, c(attn), c(ff), dropout), N_social) 80 | self.social_dec = Decoder(DecoderLayer( 81 | d_model, c(attn), c(attn), c(ff), dropout), N_social) 82 | 83 | # self.g = Generator(d_model*2, dec_out_size) 84 | self.prediction_header = GeneratorWithParallelHeads626( 85 | d_model*2, dec_out_size, dropout) 86 | self.num_queries = num_queries 87 | self.query_embed = nn.Embedding(num_queries, d_model) 88 | 89 | # This was important from their code. 90 | # Initialize parameters with Glorot / fan_avg. 
91 | for name, param in self.named_parameters(): 92 | # print(name) 93 | if param.dim() > 1: 94 | nn.init.xavier_uniform_(param) 95 | 96 | self.query_embed = nn.Embedding(self.num_queries, d_model) 97 | self.query_embed.weight.requires_grad = False 98 | nn.init.orthogonal_(self.query_embed.weight) 99 | 100 | # input: [inp, dec_inp, src_att, trg_att] 101 | 102 | def forward(self, traj, pos, social_num, social_mask, lane_enc, lane_mask): 103 | ''' 104 | Args: 105 | traj: [batch size, max_agent_num, 19, 4] 106 | pos: [batch size, max_agent_num, 2] 107 | social_num: float = max_agent_num 108 | social_mask: [batch size, 1, max_agent_num] 109 | lane_enc: [batch size, max_lane_num, 64] 110 | lane_mask: [batch size, 1, max_lane_num] 111 | 112 | Returns: 113 | outputs_coord: [batch size, max_agent_num, num_query, 30, 2] 114 | outputs_class: [batch size, max_agent_num, num_query] 115 | ''' 116 | 117 | self.query_batches = self.query_embed.weight.view( 118 | 1, 1, *self.query_embed.weight.shape).repeat(*traj.shape[:2], 1, 1) 119 | 120 | # Trajectory transformer 121 | hist_out = self.hist_tf(traj, self.query_batches, None, None) 122 | pos = self.pos_emb(pos) 123 | hist_out = torch.cat([pos.unsqueeze(dim=2).repeat( 124 | 1, 1, self.num_queries, 1), hist_out], dim=-1) 125 | hist_out = self.fusion1(hist_out) 126 | 127 | # Lane encoder 128 | lane_mem = self.lane_enc(self.lane_emb(lane_enc), lane_mask) 129 | lane_mem = lane_mem.unsqueeze(1).repeat(1, social_num, 1, 1) 130 | lane_mask = lane_mask.unsqueeze(1).repeat(1, social_num, 1, 1) 131 | 132 | # Lane decoder 133 | lane_out = self.lane_dec(hist_out, lane_mem, lane_mask, None) 134 | 135 | # Fuse position information 136 | dist = lane_out.view(*traj.shape[0:2], -1) 137 | dist = self.dist_emb(dist) 138 | 139 | # Social layer 140 | social_inp = self.fusion2(torch.cat([pos, dist], -1)) 141 | social_mem = self.social_enc(social_inp, social_mask) 142 | social_out = social_mem.unsqueeze( 143 | dim=2).repeat(1, 1, self.num_queries, 1) 144 | out = torch.cat([social_out, lane_out], -1) 145 | 146 | # Prediction head 147 | outputs_coord, outputs_class = self.prediction_header(out) 148 | 149 | return outputs_coord, outputs_class 150 | 151 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/models/__init__.py -------------------------------------------------------------------------------- /lib/models/mmTransformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class LaneNet(nn.Module): 8 | def __init__(self, in_channels, hidden_unit, num_subgraph_layers): 9 | super(LaneNet, self).__init__() 10 | self.num_subgraph_layers = num_subgraph_layers 11 | self.layer_seq = nn.Sequential() 12 | for i in range(num_subgraph_layers): 13 | self.layer_seq.add_module( 14 | f'lmlp_{i}', MLP(in_channels, hidden_unit)) 15 | in_channels = hidden_unit*2 16 | 17 | def forward(self, lane): 18 | ''' 19 | Extract lane_feature from vectorized lane representation 20 | 21 | Args: 22 | lane: [batch size, max_lane_num, 9, 7] (vectorized representation) 23 | 24 | Returns: 25 | x_max: [batch size, max_lane_num, 64] 26 | ''' 27 | x = lane 28 | for name, layer in self.layer_seq.named_modules(): 29 | if 
isinstance(layer, MLP): 30 | # x [bs,max_lane_num,9,dim] 31 | x = layer(x) 32 | x_max = torch.max(x, -2)[0] 33 | x_max = x_max.unsqueeze(2).repeat(1, 1, x.shape[2], 1) 34 | x = torch.cat([x, x_max], dim=-1) 35 | x_max = torch.max(x, -2)[0] 36 | return x_max 37 | 38 | 39 | class MLP(nn.Module): 40 | def __init__(self, in_channels, hidden_unit, verbose=False): 41 | super(MLP, self).__init__() 42 | self.mlp = nn.Sequential( 43 | nn.Linear(in_channels, hidden_unit), 44 | nn.LayerNorm(hidden_unit), 45 | nn.ReLU() 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.mlp(x) 50 | return x 51 | 52 | 53 | class mmTrans(nn.Module): 54 | 55 | def __init__(self, stacked_transformer, cfg): 56 | super(mmTrans, self).__init__() 57 | # stacked transformer class 58 | self.stacked_transformer = stacked_transformer(cfg) 59 | 60 | lane_channels = cfg['lane_channels'] 61 | self.hist_feature_size = cfg['in_channels'] 62 | 63 | self.polyline_vec_shape = 2*cfg['subgraph_width'] 64 | self.subgraph = LaneNet( 65 | lane_channels, cfg['subgraph_width'], cfg['num_subgraph_layres']) 66 | 67 | self.FUTURE_LEN = cfg['future_num_frames'] 68 | self.OBS_LEN = cfg['history_num_frames'] - 1 69 | self.lane_length = cfg['lane_length'] 70 | 71 | def preprocess_traj(self, traj): 72 | ''' 73 | Generate the trajectory mask for all agents (including target agent) 74 | 75 | Args: 76 | traj: [batch, max_agent_num, obs_len, 4] 77 | 78 | Returns: 79 | social mask: [batch, 1, max_agent_num] 80 | 81 | ''' 82 | # social mask 83 | social_valid_len = self.traj_valid_len 84 | social_mask = torch.zeros( 85 | (self.B, 1, int(self.max_agent_num))).to(traj.device) 86 | for i in range(self.B): 87 | social_mask[i, 0, :social_valid_len[i]] = 1 88 | 89 | return social_mask 90 | 91 | def preprocess_lane(self, lane): 92 | ''' 93 | preprocess lane segments using LaneNet 94 | 95 | Args: 96 | lane: [batch size, max_lane_num, 10, 5] 97 | 98 | Returns: 99 | lane_feature: [batch size, max_lane_num, 64 (feature_dim)] 100 | lane_mask: [batch size, 1, max_lane_num] 101 | 102 | ''' 103 | 104 | # transform lane to vector 105 | lane_v = torch.cat( 106 | [lane[:, :, :-1, :2], 107 | lane[:, :, 1:, :2], 108 | lane[:, :, 1:, 2:]], dim=-1) # bxnlinex9x7 109 | 110 | # lane mask 111 | lane_valid_len = self.lane_valid_len 112 | lane_mask = torch.zeros( 113 | (self.B, 1, int(self.max_lane_num))).to(lane_v.device) 114 | for i in range(lane_valid_len.shape[0]): 115 | lane_mask[i, 0, :lane_valid_len[i]] = 1 116 | 117 | # use vector like structure process lane 118 | lane_feature = self.subgraph(lane_v) # [batch size, max_lane_num, 64] 119 | 120 | return lane_feature, lane_mask 121 | 122 | def forward(self, data: dict): 123 | """ 124 | Args: 125 | data (Data): 126 | HIST: [batch size, max_agent_num, 19, 4] 127 | POS: [batch size, max_agent_num, 2] 128 | LANE: [batch size, max_lane_num, 10, 5] 129 | VALID_LEN: [batch size, 2] (number of valid agents & valid lanes) 130 | 131 | Note: 132 | max_lane_num/max_agent_num indicates maximum number of agents/lanes after padding in a single batch 133 | """ 134 | # initialized 135 | self.B = data['HISTORY'].shape[0] 136 | 137 | self.traj_valid_len = data['VALID_LEN'][:, 0] 138 | self.max_agent_num = torch.max(self.traj_valid_len) 139 | 140 | self.lane_valid_len = data['VALID_LEN'][:, 1] 141 | self.max_lane_num = torch.max(self.lane_valid_len) 142 | 143 | # preprocess 144 | pos = data['POS'] 145 | trajs = data['HISTORY'] 146 | social_mask = self.preprocess_traj(data['HISTORY']) 147 | lane_enc, lane_mask = self.preprocess_lane(data['LANE']) 
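# Shape check (from the docstrings above): lane_enc is [batch, max_lane_num, 64]
# and lane_mask is [batch, 1, max_lane_num] at this point, while trajs and pos
# keep their input shapes [batch, max_agent_num, 19, 4] and
# [batch, max_agent_num, 2] before being handed to the stacked transformer below.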
148 | 149 | out = self.stacked_transformer(trajs, pos, self.max_agent_num, 150 | social_mask, lane_enc, lane_mask) 151 | 152 | return out 153 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/utils/__init__.py -------------------------------------------------------------------------------- /lib/utils/evaluation_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | 3 | """This module evaluates the forecasted trajectories against the ground truth.""" 4 | 5 | import math 6 | from typing import Dict, List, Optional 7 | 8 | import numpy as np 9 | from argoverse.map_representation.map_api import ArgoverseMap 10 | from lib.dataset.utils import transform_coord 11 | 12 | from scipy.special import softmax 13 | 14 | LOW_PROB_THRESHOLD_FOR_METRICS = 0.05 15 | 16 | 17 | def get_ade(forecasted_trajectory: np.ndarray, gt_trajectory: np.ndarray) -> float: 18 | """Compute Average Displacement Error. 19 | 20 | Args: 21 | forecasted_trajectory: Predicted trajectory with shape (pred_len x 2) 22 | gt_trajectory: Ground truth trajectory with shape (pred_len x 2) 23 | 24 | Returns: 25 | ade: Average Displacement Error 26 | 27 | """ 28 | pred_len = forecasted_trajectory.shape[0] 29 | ade = float( 30 | sum( 31 | math.sqrt( 32 | (forecasted_trajectory[i, 0] - gt_trajectory[i, 0]) ** 2 33 | + (forecasted_trajectory[i, 1] - gt_trajectory[i, 1]) ** 2 34 | ) 35 | for i in range(pred_len) 36 | ) 37 | / pred_len 38 | ) 39 | return ade 40 | 41 | 42 | def get_fde(forecasted_trajectory: np.ndarray, gt_trajectory: np.ndarray) -> float: 43 | """Compute Final Displacement Error. 44 | 45 | Args: 46 | forecasted_trajectory: Predicted trajectory with shape (pred_len x 2) 47 | gt_trajectory: Ground truth trajectory with shape (pred_len x 2) 48 | 49 | Returns: 50 | fde: Final Displacement Error 51 | 52 | """ 53 | fde = math.sqrt( 54 | (forecasted_trajectory[-1, 0] - gt_trajectory[-1, 0]) ** 2 55 | + (forecasted_trajectory[-1, 1] - gt_trajectory[-1, 1]) ** 2 56 | ) 57 | return fde 58 | 59 | 60 | def get_displacement_errors_and_miss_rate( 61 | forecasted_trajectories: Dict[int, List[np.ndarray]], 62 | gt_trajectories: Dict[int, np.ndarray], 63 | max_guesses: int, 64 | horizon: int, 65 | miss_threshold: float, 66 | forecasted_probabilities: Optional[Dict[int, List[float]]] = None, 67 | ) -> Dict[str, float]: 68 | """Compute min fde and ade for each sample. 69 | 70 | Note: Both min_fde and min_ade values correspond to the trajectory which has minimum fde. 71 | The Brier Score is defined here: 72 | Brier, G. W. Verification of forecasts expressed in terms of probability. Monthly weather review, 1950. 73 | https://journals.ametsoc.org/view/journals/mwre/78/1/1520-0493_1950_078_0001_vofeit_2_0_co_2.xml 74 | 75 | Args: 76 | forecasted_trajectories: Predicted top-k trajectory dict with key as seq_id and value as list of trajectories. 77 | Each element of the list is of shape (pred_len x 2). 78 | gt_trajectories: Ground Truth Trajectory dict with key as seq_id and values as trajectory of 79 | shape (pred_len x 2) 80 | max_guesses: Number of guesses allowed 81 | horizon: Prediction horizon 82 | miss_threshold: Distance threshold for the last predicted coordinate 83 | forecasted_probabilities: Probabilites associated with forecasted trajectories. 
84 | 85 | Returns: 86 | metric_results: Metric values for minADE, minFDE, MR, p-minADE, p-minFDE, p-MR, brier-minADE, brier-minFDE 87 | """ 88 | metric_results: Dict[str, float] = {} 89 | min_ade, prob_min_ade, brier_min_ade = [], [], [] 90 | min_fde, prob_min_fde, brier_min_fde = [], [], [] 91 | n_misses, prob_n_misses = [], [] 92 | for k, v in gt_trajectories.items(): 93 | curr_min_ade = float("inf") 94 | curr_min_fde = float("inf") 95 | min_idx = 0 96 | max_num_traj = min(max_guesses, len(forecasted_trajectories[k])) 97 | 98 | # If probabilities available, use the most likely trajectories, else use the first few 99 | if forecasted_probabilities is not None: 100 | sorted_idx = np.argsort( 101 | [-x for x in forecasted_probabilities[k]], kind="stable") 102 | # sorted_idx = np.argsort(forecasted_probabilities[k])[::-1] 103 | pruned_probabilities = [forecasted_probabilities[k][t] 104 | for t in sorted_idx[:max_num_traj]] 105 | # Normalize 106 | prob_sum = sum(pruned_probabilities) 107 | pruned_probabilities = [p / prob_sum for p in pruned_probabilities] 108 | else: 109 | sorted_idx = np.arange(len(forecasted_trajectories[k])) 110 | pruned_trajectories = [forecasted_trajectories[k][t] 111 | for t in sorted_idx[:max_num_traj]] 112 | 113 | for j in range(len(pruned_trajectories)): 114 | fde = get_fde(pruned_trajectories[j][:horizon], v[:horizon]) 115 | if fde < curr_min_fde: 116 | min_idx = j 117 | curr_min_fde = fde 118 | curr_min_ade = get_ade( 119 | pruned_trajectories[min_idx][:horizon], v[:horizon]) 120 | min_ade.append(curr_min_ade) 121 | min_fde.append(curr_min_fde) 122 | n_misses.append(curr_min_fde > miss_threshold) 123 | 124 | if forecasted_probabilities is not None: 125 | prob_n_misses.append(1.0 if curr_min_fde > miss_threshold else ( 126 | 1.0 - pruned_probabilities[min_idx])) 127 | prob_min_ade.append( 128 | min( 129 | -np.log(pruned_probabilities[min_idx]), 130 | -np.log(LOW_PROB_THRESHOLD_FOR_METRICS), 131 | ) 132 | + curr_min_ade 133 | ) 134 | brier_min_ade.append( 135 | (1 - pruned_probabilities[min_idx]) ** 2 + curr_min_ade) 136 | prob_min_fde.append( 137 | min( 138 | -np.log(pruned_probabilities[min_idx]), 139 | -np.log(LOW_PROB_THRESHOLD_FOR_METRICS), 140 | ) 141 | + curr_min_fde 142 | ) 143 | brier_min_fde.append( 144 | (1 - pruned_probabilities[min_idx]) ** 2 + curr_min_fde) 145 | 146 | metric_results["minADE"] = sum(min_ade) / len(min_ade) 147 | metric_results["minFDE"] = sum(min_fde) / len(min_fde) 148 | metric_results["MR"] = sum(n_misses) / len(n_misses) 149 | if forecasted_probabilities is not None: 150 | metric_results["p-minADE"] = sum(prob_min_ade) / len(prob_min_ade) 151 | metric_results["p-minFDE"] = sum(prob_min_fde) / len(prob_min_fde) 152 | metric_results["p-MR"] = sum(prob_n_misses) / len(prob_n_misses) 153 | metric_results["brier-minADE"] = sum(brier_min_ade) / \ 154 | len(brier_min_ade) 155 | metric_results["brier-minFDE"] = sum(brier_min_fde) / \ 156 | len(brier_min_fde) 157 | return metric_results 158 | 159 | 160 | def get_drivable_area_compliance( 161 | forecasted_trajectories: Dict[int, List[np.ndarray]], 162 | city_names: Dict[int, str], 163 | max_n_guesses: int, 164 | ) -> float: 165 | """Compute drivable area compliance metric. 166 | 167 | Args: 168 | forecasted_trajectories: Predicted top-k trajectory dict with key as seq_id and value as list of trajectories. 169 | Each element of the list is of shape (pred_len x 2). 170 | city_names: Dict mapping sequence id to city name. 171 | max_n_guesses: Maximum number of guesses allowed. 
172 |
173 |     Returns:
174 |         Mean drivable area compliance
175 |
176 |     """
177 |     avm = ArgoverseMap()
178 |
179 |     dac_score = []
180 |
181 |     for seq_id, trajectories in forecasted_trajectories.items():
182 |         city_name = city_names[seq_id]
183 |         num_dac_trajectories = 0
184 |         n_guesses = min(max_n_guesses, len(trajectories))
185 |         for trajectory in trajectories[:n_guesses]:
186 |             raster_layer = avm.get_raster_layer_points_boolean(
187 |                 trajectory, city_name, "driveable_area")
188 |             if np.sum(raster_layer) == raster_layer.shape[0]:
189 |                 num_dac_trajectories += 1
190 |
191 |         dac_score.append(num_dac_trajectories / n_guesses)
192 |
193 |     return sum(dac_score) / len(dac_score)
194 |
195 |
196 | def compute_forecasting_metrics(
197 |     forecasted_trajectories: Dict[int, List[np.ndarray]],
198 |     gt_trajectories: Dict[int, np.ndarray],
199 |     city_names: Dict[int, str],
200 |     max_n_guesses: int,
201 |     horizon: int,
202 |     miss_threshold: float,
203 |     forecasted_probabilities: Optional[Dict[int, List[float]]] = None,
204 | ) -> Dict[str, float]:
205 |     """Compute all the forecasting metrics.
206 |
207 |     Args:
208 |         forecasted_trajectories: Predicted top-k trajectory dict with key as seq_id and value as list of trajectories.
209 |             Each element of the list is of shape (pred_len x 2).
210 |         gt_trajectories: Ground Truth Trajectory dict with key as seq_id and values as trajectory of
211 |             shape (pred_len x 2)
212 |         city_names: Dict mapping sequence id to city name.
213 |         max_n_guesses: Number of guesses allowed
214 |         horizon: Prediction horizon
215 |         miss_threshold: Miss threshold
216 |         forecasted_probabilities: Normalized Probabilities associated with each of the forecasted trajectories.
217 |
218 |     Returns:
219 |         metric_results: Dictionary containing values for all metrics.
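    Example (a minimal sketch with hypothetical toy inputs; the DAC metric
    additionally requires the Argoverse map files to be available locally):

        >>> fcst = {101: [np.zeros((30, 2)) for _ in range(6)]}
        >>> gt = {101: np.zeros((30, 2))}
        >>> cities = {101: "PIT"}
        >>> metrics = compute_forecasting_metrics(
        ...     fcst, gt, cities, max_n_guesses=6, horizon=30, miss_threshold=2.0)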
220 |     """
221 |     metric_results = get_displacement_errors_and_miss_rate(
222 |         forecasted_trajectories,
223 |         gt_trajectories,
224 |         max_n_guesses,
225 |         horizon,
226 |         miss_threshold,
227 |         forecasted_probabilities,
228 |     )
229 |     metric_results["DAC"] = get_drivable_area_compliance(
230 |         forecasted_trajectories, city_names, max_n_guesses)
231 |
232 |     print("------------------------------------------------")
233 |     print(f"Prediction Horizon: {horizon}, Max #guesses (K): {max_n_guesses}")
234 |     print("------------------------------------------------")
235 |     print(metric_results)
236 |     print("------------------------------------------------")
237 |
238 |     return metric_results
239 |
240 |
241 | class FormatData:
242 |     def __init__(self):
243 |
244 |         self.forecasted_trajectories = {}
245 |         self.gt_trajectories = {}
246 |         self.forecasted_probabilities = {}
247 |         self.city_names = {}
248 |
249 |     def __call__(self, data, predictions):
250 |         '''
251 |         Format the data for the Argoverse evaluation program.
252 |
253 |         '''
254 |         names = data['NAME']
255 |         to_global = data['NORM_CENTER']
256 |         theta = data['THETA']
257 |         city_name = data['CITY_NAME']
258 |
259 |         # Format predictions
260 |         pred_trajs, pred_confs = predictions
261 |         pred_trajs = pred_trajs[:, 0].detach().cpu().numpy()
262 |         pred_confs = pred_confs[:, 0].detach().cpu().numpy()
263 |
264 |         # Save data
265 |         for i, name in enumerate(names):
266 |
267 |             # Rotate the predicted trajectory back and translate it to global coordinates
268 |             _pred_traj = transform_coord(pred_trajs[i], -theta[i])
269 |             _pred_traj = _pred_traj + to_global[i].numpy()
270 |
271 |             # Recover the ground truth in global coordinates
272 |             gt = data['FUTURE'][i][0, :, :2]
273 |             gt = gt.cumsum(axis=-2)
274 |             gt = transform_coord(gt, -theta[i])
275 |             gt = gt + to_global[i].numpy()
276 |
277 |             self.city_names[name] = city_name[i]
278 |             self.gt_trajectories[name] = gt
279 |             self.forecasted_trajectories[name] = list(_pred_traj)
280 |             self.forecasted_probabilities[name] = list(pred_confs[i])
281 |
282 |     @property
283 |     def results(self):
284 |         return dict(
285 |             forecasted_trajectories=self.forecasted_trajectories,
286 |             gt_trajectories=self.gt_trajectories,
287 |             city_names=self.city_names,
288 |             forecasted_probabilities=self.forecasted_probabilities)
--------------------------------------------------------------------------------
/lib/utils/parallel/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/decisionforce/mmTransformer/be25d26118d2dfdac72b1d1e0cf6cbf14f7f4a0b/lib/utils/parallel/__init__.py
--------------------------------------------------------------------------------
/lib/utils/parallel/data_container.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import functools
3 |
4 | import torch
5 |
6 |
7 | def assert_tensor_type(func):
8 |
9 |     @functools.wraps(func)
10 |     def wrapper(*args, **kwargs):
11 |         if not isinstance(args[0].data, torch.Tensor):
12 |             raise AttributeError(
13 |                 f'{args[0].__class__.__name__} has no attribute '
14 |                 f'{func.__name__} for type {args[0].datatype}')
15 |         return func(*args, **kwargs)
16 |
17 |     return wrapper
18 |
19 |
20 | class DataContainer:
21 |     """A container for any type of objects.
22 |
23 |     Typically tensors will be stacked in the collate function and sliced along
24 |     some dimension in the scatter function. This behavior has some limitations.
25 |     1. All tensors have to be the same size.
26 |     2. Types are limited (numpy array or Tensor).
27 |
28 |     We design `DataContainer` and `MMDataParallel` to overcome these
29 |     limitations. The behavior can be either of the following.
30 |
31 |     - copy to GPU, pad all tensors to the same size and stack them
32 |     - copy to GPU without stacking
33 |     - leave the objects as-is and pass them to the model
34 |     - pad_dims specifies the number of last few dimensions to do padding
35 |     """
36 |
37 |     def __init__(self,
38 |                  data,
39 |                  stack=False,
40 |                  padding_value=0,
41 |                  cpu_only=False,
42 |                  pad_dims=2):
43 |         self._data = data
44 |         self._cpu_only = cpu_only
45 |         self._stack = stack
46 |         self._padding_value = padding_value
47 |         assert pad_dims in [None, 1, 2, 3]
48 |         self._pad_dims = pad_dims
49 |
50 |     def __repr__(self):
51 |         return f'{self.__class__.__name__}({repr(self.data)})'
52 |
53 |     def __len__(self):
54 |         return len(self._data)
55 |
56 |     @property
57 |     def data(self):
58 |         return self._data
59 |
60 |     @property
61 |     def datatype(self):
62 |         if isinstance(self.data, torch.Tensor):
63 |             return self.data.type()
64 |         else:
65 |             return type(self.data)
66 |
67 |     @property
68 |     def cpu_only(self):
69 |         return self._cpu_only
70 |
71 |     @property
72 |     def stack(self):
73 |         return self._stack
74 |
75 |     @property
76 |     def padding_value(self):
77 |         return self._padding_value
78 |
79 |     @property
80 |     def pad_dims(self):
81 |         return self._pad_dims
82 |
83 |     @assert_tensor_type
84 |     def size(self, *args, **kwargs):
85 |         return self.data.size(*args, **kwargs)
86 |
87 |     @assert_tensor_type
88 |     def dim(self):
89 |         return self.data.dim()
90 |
--------------------------------------------------------------------------------
/lib/utils/parallel/dataparallel.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import warnings
3 | from itertools import chain
4 |
5 | import torch
6 | from torch._utils import (_get_all_device_indices, _get_available_device_type,
7 |                           _get_device_index, _get_devices_properties)
8 | from torch.nn.modules import Module
9 | from torch.nn.parallel import DataParallel
10 | from torch.nn.parallel.parallel_apply import parallel_apply
11 | from torch.nn.parallel.replicate import replicate
12 | from torch.nn.parallel.scatter_gather import gather
13 |
14 | from .scatter import scatter_kwargs
15 |
16 | class MMTransDataParallel(DataParallel):
17 |
18 |     def forward(self, *inputs, **kwargs):
19 |
20 |         with torch.autograd.profiler.record_function("DataParallel.forward"):
21 |             if not self.device_ids:
22 |                 return self.module(*inputs, **kwargs)
23 |
24 |             for t in chain(self.module.parameters(), self.module.buffers()):
25 |                 if t.device != self.src_device_obj:
26 |                     raise RuntimeError("module must have its parameters and buffers "
27 |                                        "on device {} (device_ids[0]) but found one of "
28 |                                        "them on device: {}".format(self.src_device_obj, t.device))
29 |
30 |             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
31 |             # for forward function without any inputs, empty list and dict will be created
32 |             # so the module can be executed on one device which is the first one in device_ids
33 |             if not inputs and not kwargs:
34 |                 inputs = ((),)
35 |                 kwargs = ({},)
36 |
37 |             if len(self.device_ids) == 1:
38 |                 return self.module(*inputs[0], **kwargs[0])
39 |             replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
40 |             outputs = self.parallel_apply(replicas, inputs, kwargs)
41 |             return self.gather(outputs, self.output_device)
42 |
43 |     def scatter(self, inputs, kwargs, device_ids):
44 |         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
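# Usage sketch (hypothetical wiring; MMTransDataParallel behaves like
# torch.nn.DataParallel, but its scatter step understands DataContainer):
#
#     model = MMTransDataParallel(net.cuda(), device_ids=[0, 1])
#     outputs = model(batch)   # batch may mix plain Tensors and DataContainers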
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/lib/utils/parallel/scatter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | import torch
3 | from torch.nn.parallel._functions import Scatter as OrigScatter
4 |
5 | from .data_container import DataContainer
6 | from .scatter_utils import Scatter
7 |
8 | def scatter(inputs, target_gpus, dim=0):
9 |     """Scatter inputs to target gpus.
10 |
11 |     The only difference from the original :func:`scatter` is the added support
12 |     for :type:`~mmcv.parallel.DataContainer`.
13 |
14 |     In mmTransformer the history and lane tensors are padded separately, so they
15 |     need to be chunked directly; all other metadata follows the standard mmcv process.
16 |     """
17 |
18 |     def scatter_map(obj):
19 |         if isinstance(obj, torch.Tensor):
20 |             if target_gpus != [-1]:
21 |                 return OrigScatter.apply(target_gpus, None, dim, obj)
22 |             else:
23 |                 # for CPU inference we use a self-implemented scatter
24 |                 return Scatter.forward(target_gpus, obj)
25 |         if isinstance(obj, DataContainer):
26 |             if obj.cpu_only:
27 |                 return obj.data
28 |             else:
29 |                 return Scatter.forward(target_gpus, obj.data)
30 |         if isinstance(obj, tuple) and len(obj) > 0:
31 |             return list(zip(*map(scatter_map, obj)))
32 |         if isinstance(obj, list) and len(obj) > 0:
33 |             out = list(map(list, zip(*map(scatter_map, obj))))
34 |             return out
35 |         if isinstance(obj, dict) and len(obj) > 0:
36 |             out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
37 |             return out
38 |         return [obj for _ in target_gpus]
39 |
40 |     # After scatter_map is called, a scatter_map cell will exist. This cell
41 |     # has a reference to the actual function scatter_map, which has references
42 |     # to a closure that has a reference to the scatter_map cell (because the
43 |     # fn is recursive). To avoid this reference cycle, we set the function to
44 |     # None, clearing the cell
45 |     try:
46 |         return scatter_map(inputs)
47 |     finally:
48 |         scatter_map = None
49 |
50 |
51 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
52 |     """Scatter with support for kwargs dictionary."""
53 |     inputs = scatter(inputs, target_gpus, dim) if inputs else []
54 |     kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
55 |     if len(inputs) < len(kwargs):
56 |         inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
57 |     elif len(kwargs) < len(inputs):
58 |         kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
59 |     inputs = tuple(inputs)
60 |     kwargs = tuple(kwargs)
61 |     return inputs, kwargs
62 |
--------------------------------------------------------------------------------
/lib/utils/parallel/scatter_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
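# This module backs the DataContainer-aware scatter above: tensors headed for
# GPUs are copied on background streams, while CPU-only runs (devices == [-1])
# get the tensor unsqueezed so its layout matches the GPU-scattered case.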
2 | # Parallel code borrowed from mmcv
3 | import torch
4 | from torch.nn.parallel._functions import _get_stream
5 |
6 |
7 | def scatter(input, devices, streams=None):
8 |     """Scatters tensor across multiple GPUs."""
9 |     if streams is None:
10 |         streams = [None] * len(devices)
11 |
12 |     if isinstance(input, list):
13 |         chunk_size = (len(input) - 1) // len(devices) + 1
14 |         outputs = [
15 |             scatter(input[i], [devices[i // chunk_size]],
16 |                     [streams[i // chunk_size]]) for i in range(len(input))
17 |         ]
18 |         return outputs
19 |     elif isinstance(input, torch.Tensor):
20 |         output = input.contiguous()
21 |         # TODO: copy to a pinned buffer first (if copying from CPU)
22 |         stream = streams[0] if output.numel() > 0 else None
23 |         if devices != [-1]:
24 |             with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
25 |                 output = output.cuda(devices[0], non_blocking=True)
26 |         else:
27 |             # unsqueeze the first dimension so that the tensor's shape is the
28 |             # same as those scattered to GPU.
29 |             output = output.unsqueeze(0)
30 |         return output
31 |     else:
32 |         raise Exception(f'Unknown type {type(input)}.')
33 |
34 |
35 | def synchronize_stream(output, devices, streams):
36 |     if isinstance(output, list):
37 |         chunk_size = len(output) // len(devices)
38 |         for i in range(len(devices)):
39 |             for j in range(chunk_size):
40 |                 synchronize_stream(output[i * chunk_size + j], [devices[i]],
41 |                                    [streams[i]])
42 |     elif isinstance(output, torch.Tensor):
43 |         if output.numel() != 0:
44 |             with torch.cuda.device(devices[0]):
45 |                 main_stream = torch.cuda.current_stream()
46 |                 main_stream.wait_stream(streams[0])
47 |                 output.record_stream(main_stream)
48 |     else:
49 |         raise Exception(f'Unknown type {type(output)}.')
50 |
51 |
52 | def get_input_device(input):
53 |     if isinstance(input, list):
54 |         for item in input:
55 |             input_device = get_input_device(item)
56 |             if input_device != -1:
57 |                 return input_device
58 |         return -1
59 |     elif isinstance(input, torch.Tensor):
60 |         return input.get_device() if input.is_cuda else -1
61 |     else:
62 |         raise Exception(f'Unknown type {type(input)}.')
63 |
64 |
65 | class Scatter:
66 |
67 |     @staticmethod
68 |     def forward(target_gpus, input):
69 |         input_device = get_input_device(input)
70 |         streams = None
71 |         if input_device == -1 and target_gpus != [-1]:
72 |             # Perform CPU to GPU copies in a background stream
73 |             streams = [_get_stream(device) for device in target_gpus]
74 |
75 |         outputs = scatter(input, target_gpus, streams)
76 |         # Synchronize with the copy stream
77 |         if streams is not None:
78 |             synchronize_stream(outputs, target_gpus, streams)
79 |
80 |         return tuple(outputs)
--------------------------------------------------------------------------------
/lib/utils/utilities.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Dict, List, Optional, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | import yaml
8 |
9 |
10 | def load_config_data(path: str) -> dict:
11 |     """Load config data from a given path
12 |     :param path: the path as a string
13 |     :return: the config as a dict
14 |     """
15 |     with open(path) as f:
16 |         cfg: dict = yaml.load(f, Loader=yaml.FullLoader)
17 |     return cfg
18 |
19 |
20 | def save_checkpoint(checkpoint_dir, model, optimizer, MR=1.0):
21 |     # state_dict: a Python dictionary object that:
22 |     #   - for a model, maps each layer to its parameter tensor;
23 |     #   - for an optimizer, contains info about the optimizer's states and hyperparameters used.
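    #   Besides these two state dicts, the checkpoint below also records the
    #   best miss rate (MR) observed so far under the key 'BestMissRate'.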
24 |     state = {
25 |         'state_dict': model.state_dict(),
26 |         'optimizer': optimizer.state_dict(),
27 |         'BestMissRate': MR
28 |     }
29 |
30 |     torch.save(state, checkpoint_dir)
31 |     print('model saved to %s' % checkpoint_dir)
32 |
33 |
34 | def load_checkpoint(checkpoint_path, model, optimizer=None):
35 |     state = torch.load(checkpoint_path)
36 |     model.load_state_dict(state['state_dict'])
37 |     print('model loaded from %s' % checkpoint_path)
38 |     return model
39 |
40 |
41 | def load_model_class(model_name):
42 |     import importlib
43 |     module_path = f'lib.models.TF_version.{model_name}'
44 |     module_name = 'STF'
45 |     target_module = importlib.import_module(module_path)
46 |     target_class = getattr(target_module, module_name)
47 |     return target_class
48 |
49 |
50 | if __name__ == "__main__":
51 |
52 |     state = torch.load('./models/demo.pt')
53 |     state_dict = state['state_dict']
54 |
55 |     from collections import OrderedDict
56 |
57 |     new_state_dict = OrderedDict()
58 |
59 |     for k, v in state_dict.items():
60 |
61 |         components = k.split('.')
62 |
63 |         if components[1] == 'STF':
64 |             new_parts = ['stacked_transformer'] + components[2:]
65 |             new_k = '.'.join(new_parts)
66 |         else:
67 |             new_parts = components[1:]
68 |             new_k = '.'.join(new_parts)
69 |
70 |         new_state_dict[new_k] = v
71 |
72 |     new_state = {'state_dict': new_state_dict}
73 |     torch.save(new_state, './models/new_demo.pt')
74 |
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | addict
2 | yapf
--------------------------------------------------------------------------------
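A minimal usage sketch for the helpers in /lib/utils/utilities.py (the paths,
the optimizer choice, and the model configuration are assumptions for
illustration; 'stacked_transformer' must name a module under
lib/models/TF_version that defines a class called STF):

    import torch
    from lib.models.mmTransformer import mmTrans
    from lib.utils.utilities import (load_checkpoint, load_model_class,
                                     save_checkpoint)

    # Resolve the STF class by module name and build the model around it.
    stf = load_model_class('stacked_transformer')
    model = mmTrans(stf, model_cfg)  # model_cfg: the 'model' section of a config file

    # Restore weights (the optimizer state stored in the file is currently ignored).
    model = load_checkpoint('./models/demo.pt', model)

    # Persist weights, optimizer state, and the best miss rate seen so far.
    optimizer = torch.optim.Adam(model.parameters())
    save_checkpoint('./models/checkpoint.pt', model, optimizer, MR=0.12)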