├── LICENSE
├── README.md
├── SmartRefine_talk.pdf
├── assets
│   ├── pipeline.png
│   └── visualization.png
├── ckpts
│   └── version_5709064
│       ├── checkpoints
│       │   └── epoch=31-step=205951.ckpt
│       └── events.out.tfevents.1695496506.SH-IDC1-10-5-36-118.21080.0
├── datamodules
│   ├── __init__.py
│   └── argoverse_v1_datamodule.py
├── datasets
│   ├── __init__.py
│   └── argoverse_v1_dataset.py
├── eval.py
├── eval.sh
├── eval_store.py
├── losses
│   ├── __init__.py
│   ├── laplace_nll_loss.py
│   ├── score_reg_l1_loss.py
│   └── soft_target_cross_entropy_loss.py
├── metrics
│   ├── __init__.py
│   ├── ade.py
│   ├── fde.py
│   └── mr.py
├── models
│   ├── __init__.py
│   ├── decoder.py
│   ├── embedding.py
│   ├── local_encoder.py
│   ├── refine.py
│   └── target_region.py
├── requirements.txt
├── train.py
├── train.sh
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SmartRefine: A Scenario-Adaptive Refinement Framework for Efficient Motion Prediction
2 |
3 | **_Fast Takeaway_:** We introduce a novel approach to refining motion predictions in autonomous vehicle navigation with minimal additional computation by leveraging scenario-specific properties and adaptive refinement iterations.
 4 | ![pipeline](assets/pipeline.png)
5 | > Yang Zhou\* , [Hao Shao](http://hao-shao.com/)\* , [Letian Wang](https://letianwang0.wixsite.com/myhome) , [Steven L. Waslander](https://www.trailab.utias.utoronto.ca/stevenwaslander) , [Hongsheng Li](http://www.ee.cuhk.edu.hk/~hsli/) , [Yu Liu](https://liuyu.us/)$^\dagger$.
6 |
7 | This repository contains the official implementation of [SmartRefine: A Scenario-Adaptive Refinement Framework for Efficient Motion Prediction](https://arxiv.org/abs/2403.11492) published in _CVPR 2024_.
8 |
 9 | If you have any questions, feel free to contact us: kmzy at hnu.edu.cn or kmzy99 at gmail.com.
10 |
11 | [](https://hits.seeyoufarm.com)
12 | [](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
13 |
14 | ## News
15 |
16 | - `[04 Jun., 2024]` We gave a talk at [自动驾驶之心](https://www.zdjszx.com/); the slides are available [here](https://github.com/opendilab/SmartRefine/blob/main/SmartRefine_talk.pdf).
17 |
18 | - `[22 Mar., 2024]` We released our code for [Argoverse 1](https://github.com/argoverse/argoverse-api). Give it a try!
19 | - `[18 Mar., 2024]` We released our SmartRefine paper on [_arXiv_](https://arxiv.org/abs/2403.11492).
20 | - `[27 Feb., 2024]` SmartRefine was accepted to _CVPR 2024_.
21 |
22 | ## Getting Started
23 | 1\. Clone this repository:
24 | ```bash
25 | cd $YOUR_WORK_SPACE
26 | git clone https://github.com/opendilab/SmartRefine.git
27 | cd SmartRefine
28 | ```
29 | 2\. Install the dependencies:
30 | ```bash
31 | pip install -r requirements.txt
32 | cd ../
33 | ```
34 | Alternatively, you can set up the environment in whatever way you prefer.
35 |
36 | 3\. Install the [Argoverse-API](https://github.com/argoverse/argoverse-api?tab=readme-ov-file#installation) and download the [Argoverse Motion Forecasting Dataset v1.1](https://www.argoverse.org/av1.html) under `$YOUR_WORK_SPACE`, following the corresponding user guides. Here is an example of extracting the downloaded Argoverse data:
37 |
38 | ```bash
39 | cd $YOUR_WORK_SPACE
40 | mkdir argo1_data
41 | tar xzvf forecasting_train_v1.1.tar.gz -C ./argo1_data
42 | tar xzvf forecasting_val_v1.1.tar.gz -C ./argo1_data
43 | ```
44 |
45 | 4\. Download the prediction backbone's outputs from [here](https://openxlab.org.cn/datasets/kmzy99/SmartRefine/tree/main/prediction_data) and extract them:
46 |
47 | ```bash
48 | cd $YOUR_WORK_SPACE
49 | mkdir p1_data
50 | unzip hivt_p1_data.zip -d ./p1_data
51 | ```
52 |
53 | The final files inside `$YOUR_WORK_SPACE` should be organized as follows:
54 |
55 | ```
56 | $YOUR_WORK_SPACE
57 | ├── argoverse-api
58 | ├── argo1_data
59 | │   ├── train
60 | │   │   ├── data
61 | │   │   │   ├── 1.csv
62 | │   │   │   ├── 2.csv
63 | │   │   │   └── ...
64 | │   └── val
65 | │       ├── data
66 | │       │   ├── 1.csv
67 | │       │   ├── 2.csv
68 | │       │   └── ...
69 | │       └── Argoverse-Terms_of_Use.txt
70 | ├── p1_data
71 | │   ├── train
72 | │   │   ├── 1.pkl
73 | │   │   ├── 2.pkl
74 | │   │   └── ...
75 | │   └── val
76 | │       ├── 1.pkl
77 | │       ├── 2.pkl
78 | │       └── ...
79 | └── SmartRefine
80 | ```
81 | Here, each pickle file inside p1_data contains the backbone model's outputs: predicted trajectories with a shape of $[K, T, 2]$ and trajectory features shaped as $[K, -1]$, where $K$ is the number of modalities and $T$ is the trajectory length.
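
For example, a minimal sketch of inspecting one of these pickle files (the path and sequence id below are placeholders; any file under `./p1_data/val/` works):

```python
import pickle

# Hypothetical example path; substitute any sequence id from ./p1_data/val/.
with open('./p1_data/val/1.pkl', 'rb') as handle:
    p1 = pickle.load(handle)

print(p1['traj'].shape)   # [K, T, 2] predicted trajectories
print(p1['embed'].shape)  # [K, D] trajectory features
```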
82 |
83 | 5\. **[Optional]** Generate your own model's prediction outputs.
84 |
85 | As mentioned in our paper, SmartRefine is decoupled from the primary prediction backbone and only requires a generic interface to it (predicted trajectories and trajectory features). We therefore provide the script `eval_store.py` as an example of how to store your backbone's outputs: for each sequence, put the predicted trajectories under the key 'traj' and the trajectory features under the key 'embed' in a dictionary, and save it as a pickle file named after the sequence id. A minimal sketch is shown below.
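
The snippet below is only an illustrative sketch of this storage format; `my_trajs`, `my_feats`, and `seq_id` are placeholders for your own backbone's outputs and the corresponding Argoverse sequence id (see `eval_store.py` for the full template):

```python
import os
import pickle

import torch

# Placeholders for a single sequence; replace with your backbone's real outputs.
my_trajs = torch.zeros(6, 30, 2)   # [K, T, 2] predicted trajectories
my_feats = torch.zeros(6, 64)      # [K, D]   trajectory features
seq_id = 1                         # must match the original Argoverse sequence id

os.makedirs('./p1_data/val', exist_ok=True)
with open(os.path.join('./p1_data/val', f'{seq_id}.pkl'), 'wb') as handle:
    pickle.dump({'traj': my_trajs, 'embed': my_feats}, handle, protocol=pickle.HIGHEST_PROTOCOL)
```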
86 |
87 | ## Training
88 | You can train the model on a single GPU or multiple GPUs to accelerate the training process:
89 |
90 | ```bash
91 | cd $YOUR_WORK_SPACE
92 | cd SmartRefine
93 | bash train.sh
94 | ```
95 |
96 | You can adjust the training settings as needed. The default `train.sh` looks as follows:
97 | ```bash
98 | set -x
99 | # change data_root to the path of your dataset root.
100 | data_root=../argo1_data/
101 | # change p1_root to the path of your prediction outputs root.
102 | p1_root=../p1_data/
103 | # experiment name used for logging.
104 | exp=smartref_hivt_argo1
105 | # number of GPUs to use.
106 | ngpus=1
107 | pwd
108 |
109 | python train.py \
110 | --data_root $data_root --p1_root $p1_root --exp $exp \
111 | --train_batch_size 32 --val_batch_size 32 \
112 | --gpus $ngpus --embed_dim 64 --refine_num 5 --seg_num 2 \
113 | --refine_radius -1 --r_lo 2 --r_hi 10 \
114 | ```
115 |
116 | **_Note_**: The first training epoch takes longer because the data is preprocessed at the same time. After that, training takes roughly 20~40 minutes per epoch, depending on your hardware.
117 |
118 | Training logs are saved to `$exp/lightning_logs/` automatically. To monitor them:
119 | ```bash
120 | cd $exp
121 | tensorboard --logdir lightning_logs/
122 | ```
123 |
124 | ## Evaluation
125 | To evaluate the model performance:
126 | ```bash
127 | cd $YOUR_WORK_SPACE
128 | cd SmartRefine
129 | bash eval.sh
130 | ```
131 |
132 | ## Results
133 | ### Tabular Results
134 | The expected performance is:
135 | | Methods | minFDE | minADE | MR |
136 | | ------------ | ------ | ------ | ---- |
137 | | HiVT | 0.969 | 0.661 | 0.092 |
138 | | HiVT w/ Ours | 0.913 | 0.646 | 0.083 |
139 | ### Visualization Results
140 | The dark blue arrows are the model's multi-modal predictions for the agent, and the pink arrow is the ground-truth future trajectory. After refinement, the shortest trajectory becomes better aligned with the ground-truth direction, and the trajectory closest to the ground truth moves even closer to it.
141 | ![visualization](assets/visualization.png)
142 |
143 | ## Citation
144 | If you find our repo or paper useful, please cite us as:
145 |
146 | ```bibtex
147 | @misc{zhou2024smartrefine,
148 | title={SmartRefine: A Scenario-Adaptive Refinement Framework for Efficient Motion Prediction},
149 | author={Yang Zhou and Hao Shao and Letian Wang and Steven L. Waslander and Hongsheng Li and Yu Liu},
150 | year={2024},
151 | eprint={2403.11492},
152 | archivePrefix={arXiv},
153 | primaryClass={cs.CV}
154 | }
155 | ```
156 |
157 | ## Acknowledgements
158 |
159 | This implementation is based on code from the following repositories:
160 | - [HiVT](https://github.com/ZikangZhou/HiVT)
161 | - [LMDrive](https://github.com/opendilab/LMDrive)
162 | - [Forecast-MAE](https://github.com/jchengai/forecast-mae)
163 |
164 | ## License
165 |
166 | All code within this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
--------------------------------------------------------------------------------
/SmartRefine_talk.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/SmartRefine/a4561c348cf8c5b93ff888e543b2358e98f00a32/SmartRefine_talk.pdf
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/SmartRefine/a4561c348cf8c5b93ff888e543b2358e98f00a32/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/SmartRefine/a4561c348cf8c5b93ff888e543b2358e98f00a32/assets/visualization.png
--------------------------------------------------------------------------------
/ckpts/version_5709064/checkpoints/epoch=31-step=205951.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/SmartRefine/a4561c348cf8c5b93ff888e543b2358e98f00a32/ckpts/version_5709064/checkpoints/epoch=31-step=205951.ckpt
--------------------------------------------------------------------------------
/ckpts/version_5709064/events.out.tfevents.1695496506.SH-IDC1-10-5-36-118.21080.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opendilab/SmartRefine/a4561c348cf8c5b93ff888e543b2358e98f00a32/ckpts/version_5709064/events.out.tfevents.1695496506.SH-IDC1-10-5-36-118.21080.0
--------------------------------------------------------------------------------
/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | from datamodules.argoverse_v1_datamodule import ArgoverseV1DataModule
2 |
--------------------------------------------------------------------------------
/datamodules/argoverse_v1_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | from pytorch_lightning import LightningDataModule
4 | from torch_geometric.data import DataLoader
5 |
6 | from datasets import ArgoverseV1Dataset
7 |
8 |
9 | class ArgoverseV1DataModule(LightningDataModule):
10 |
11 | def __init__(self,
12 | data_root: str,
13 | p1_root: str,
14 | train_batch_size: int,
15 | val_batch_size: int,
16 | shuffle: bool = True,
17 | num_workers: int = 8,
18 | pin_memory: bool = True,
19 | persistent_workers: bool = True,
20 | train_transform: Optional[Callable] = None,
21 | val_transform: Optional[Callable] = None,
22 | # used to pre-process map data
23 | local_radius: float = 150) -> None:
24 | super(ArgoverseV1DataModule, self).__init__()
25 | self.data_root = data_root
26 | self.p1_root = p1_root
27 | self.train_batch_size = train_batch_size
28 | self.val_batch_size = val_batch_size
29 | self.shuffle = shuffle
30 | self.pin_memory = pin_memory
31 | self.persistent_workers = persistent_workers
32 | self.num_workers = num_workers
33 | self.train_transform = train_transform
34 | self.val_transform = val_transform
35 | self.local_radius = local_radius
36 |
37 | def prepare_data(self) -> None:
38 | ArgoverseV1Dataset(self.data_root, self.p1_root, 'train', self.train_transform, self.local_radius)
39 | ArgoverseV1Dataset(self.data_root, self.p1_root, 'val', self.val_transform, self.local_radius)
40 |
41 | def setup(self, stage: Optional[str] = None) -> None:
42 | self.train_dataset = ArgoverseV1Dataset(self.data_root, self.p1_root, 'train', self.train_transform, self.local_radius)
43 | self.val_dataset = ArgoverseV1Dataset(self.data_root, self.p1_root, 'val', self.val_transform, self.local_radius)
44 |
45 | def train_dataloader(self):
46 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
47 | num_workers=self.num_workers, pin_memory=self.pin_memory,
48 | persistent_workers=self.persistent_workers)
49 |
50 | def val_dataloader(self):
51 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers,
52 | pin_memory=self.pin_memory, persistent_workers=self.persistent_workers)
53 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from datasets.argoverse_v1_dataset import ArgoverseV1Dataset
2 |
--------------------------------------------------------------------------------
/datasets/argoverse_v1_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import permutations
3 | from itertools import product
4 | from typing import Callable, Dict, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from argoverse.map_representation.map_api import ArgoverseMap
10 | from torch_geometric.data import Data
11 | from torch_geometric.data import Dataset
12 | from tqdm import tqdm
13 |
14 | from utils import TemporalData
15 |
16 | import pickle
17 |
18 | class ArgoverseV1Dataset(Dataset):
19 |
20 | def __init__(self,
21 | data_root: str,
22 | p1_root: str,
23 | split: str,
24 | transform: Optional[Callable] = None,
25 | local_radius: float = 150) -> None:
26 | self._split = split
27 | self._local_radius = local_radius
28 |
29 | if split == 'sample':
30 | self._directory = 'forecasting_sample'
31 | elif split == 'train':
32 | self._directory = 'train'
33 | elif split == 'val':
34 | self._directory = 'val'
35 | elif split == 'test':
36 | self._directory = 'test_obs'
37 | else:
38 | raise ValueError(split + ' is not valid')
39 |
40 | self.data_root = data_root
41 | self.p1_root = p1_root
42 | self._raw_file_names = os.listdir(self.raw_dir)
43 |
44 | self._processed_file_names = [os.path.splitext(f)[0] + '.pkl' for f in self.raw_file_names]
45 | self._processed_paths = [os.path.join(self.processed_dir, f) for f in self._processed_file_names]
46 |
47 | self._p1_paths = [os.path.join(self.p1_root, self._directory, f) for f in self._processed_file_names]
48 |
49 | super(ArgoverseV1Dataset, self).__init__(data_root, transform=transform)
50 |
51 | @property
52 | def raw_dir(self) -> str:
53 | return os.path.join(self.data_root, self._directory, 'data')
54 |
55 | @property
56 | def processed_dir(self) -> str:
57 | return os.path.join(self.data_root, self._directory, 'processed')
58 |
59 | @property
60 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
61 | return self._raw_file_names
62 |
63 | @property
64 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
65 | return self._processed_file_names
66 |
67 | @property
68 | def processed_paths(self) -> List[str]:
69 | return self._processed_paths
70 |
71 | def process(self) -> None:
72 | am = ArgoverseMap()
73 | for raw_path in tqdm(self.raw_paths):
74 | data = process_argoverse(self._split, raw_path, am, self._local_radius)
75 | with open(os.path.join(self.processed_dir, str(data['seq_id']) + '.pkl'), 'wb') as handle:
76 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
77 |
78 | def len(self) -> int:
79 | return len(self._raw_file_names)
80 |
81 | def get(self, idx) -> Data:
82 | with open(self.processed_paths[idx], 'rb') as handle:
83 | data = pickle.load(handle)
84 | data = Data.from_dict(data)
85 | with open(self._p1_paths[idx], 'rb') as handle:
86 | p1_data = pickle.load(handle)
87 | return data, p1_data
88 |
89 |
90 | def process_argoverse(split: str,
91 | raw_path: str,
92 | am: ArgoverseMap,
93 | radius: float) -> Dict:
94 | df = pd.read_csv(raw_path)
95 |
96 | # filter out actors that are unseen during the historical time steps
97 | timestamps = list(np.sort(df['TIMESTAMP'].unique()))
98 | historical_timestamps = timestamps[: 20]
99 | historical_df = df[df['TIMESTAMP'].isin(historical_timestamps)]
100 | actor_ids = list(historical_df['TRACK_ID'].unique())
101 | df = df[df['TRACK_ID'].isin(actor_ids)]
102 | num_nodes = len(actor_ids)
103 |
104 | av_df = df[df['OBJECT_TYPE'] == 'AV'].iloc
105 | av_index = actor_ids.index(av_df[0]['TRACK_ID'])
106 |
107 | agent_df = df[df['OBJECT_TYPE'] == 'AGENT'].iloc
108 | agent_index = actor_ids.index(agent_df[0]['TRACK_ID'])
109 | city = df['CITY_NAME'].values[0]
110 |
111 | # make the scene centered at AV
112 | origin = torch.tensor([av_df[19]['X'], av_df[19]['Y']], dtype=torch.float)
113 | av_heading_vector = origin - torch.tensor([av_df[18]['X'], av_df[18]['Y']], dtype=torch.float)
114 | theta = torch.atan2(av_heading_vector[1], av_heading_vector[0])
115 | rotate_mat = torch.tensor([[torch.cos(theta), -torch.sin(theta)],
116 | [torch.sin(theta), torch.cos(theta)]])
117 |
118 | # initialization
119 | x = torch.zeros(num_nodes, 50, 2, dtype=torch.float)
120 | edge_index = torch.LongTensor(list(permutations(range(num_nodes), 2))).t().contiguous()
121 | positions_global = torch.zeros(num_nodes, 50, 2, dtype=torch.float)
122 | padding_mask = torch.ones(num_nodes, 50, dtype=torch.bool)
123 | rotate_angles = torch.zeros(num_nodes, dtype=torch.float)
124 | rotate_angles_global = torch.zeros(num_nodes, dtype=torch.float)
125 |
126 | for actor_id, actor_df in df.groupby('TRACK_ID'):
127 | node_idx = actor_ids.index(actor_id)
128 | node_steps = [timestamps.index(timestamp) for timestamp in actor_df['TIMESTAMP']]
129 | padding_mask[node_idx, node_steps] = False
130 | if padding_mask[node_idx, 19]: # make no predictions for actors that are unseen at the current time step
131 | padding_mask[node_idx, 20:] = True
132 | xy = torch.from_numpy(np.stack([actor_df['X'].values, actor_df['Y'].values], axis=-1)).float()
133 | x[node_idx, node_steps] = torch.matmul(xy - origin, rotate_mat)
134 | positions_global[node_idx, node_steps] = xy
135 | node_historical_steps = list(filter(lambda node_step: node_step < 20, node_steps))
136 | if len(node_historical_steps) > 1: # calculate the heading of the actor (approximately)
137 | heading_vector = x[node_idx, node_historical_steps[-1]] - x[node_idx, node_historical_steps[-2]]
138 | rotate_angles[node_idx] = torch.atan2(heading_vector[1], heading_vector[0])
139 | heading_vector_global = positions_global[node_idx, node_historical_steps[-1]] - positions_global[node_idx, node_historical_steps[-2]]
140 | rotate_angles_global[node_idx] = torch.atan2(heading_vector_global[1], heading_vector_global[0])
141 | else: # make no predictions for the actor if the number of valid time steps is less than 2
142 | padding_mask[node_idx, 20:] = True
143 |
144 | positions = x.clone()
145 | x[:, 20:] = torch.where((padding_mask[:, 19].unsqueeze(-1) | padding_mask[:, 20:]).unsqueeze(-1),
146 | torch.zeros(num_nodes, 30, 2),
147 | x[:, 20:] - x[:, 19].unsqueeze(-2))
148 | x[:, 1: 20] = torch.where((padding_mask[:, : 19] | padding_mask[:, 1: 20]).unsqueeze(-1),
149 | torch.zeros(num_nodes, 19, 2),
150 | x[:, 1: 20] - x[:, : 19])
151 | x[:, 0] = torch.zeros(num_nodes, 2)
152 |
153 | agent_pos = torch.tensor([agent_df[19]['X'], agent_df[19]['Y']], dtype=torch.float).reshape(1, 2)
154 | agent_ind = [agent_index]
155 | (tar_lane_positions, tar_lane_vectors, tar_is_intersections, tar_turn_directions, tar_traffic_controls, tar_id_2_idx, tar_counts, tar_len_counts) = \
156 | get_lane_features_preload(am,
157 | agent_ind,
158 | agent_pos,
159 | origin,
160 | rotate_mat,
161 | city,
162 | radius)
163 |
164 | y = None if split == 'test' else x[:, 20:]
165 | seq_id = os.path.splitext(os.path.basename(raw_path))[0]
166 |
167 | return {
168 | 'x': x[:, :20],
169 | 'positions': positions, # [N, 50, 2]
170 | 'positions_global': positions_global,
171 | 'edge_index': edge_index,
172 | 'y': y, # [N, 30, 2]
173 | 'num_nodes': num_nodes,
174 | 'padding_mask': padding_mask, # [N, 50]
175 | 'rotate_angles': rotate_angles, # [N] # av->agent
176 | 'rotate_angles_global': rotate_angles_global, # global->agent
177 |
178 | 'seq_id': int(seq_id),
179 | 'av_index': av_index,
180 | 'agent_index': agent_index,
181 | 'city': city,
182 | 'origin': origin.unsqueeze(0),
183 | 'theta': theta,
184 |
185 | #! all in av' coord
186 | 'tar_lane_positions': tar_lane_positions, # [L_, 2]
187 | 'tar_lane_vectors': tar_lane_vectors, # [L_, 2]
188 | 'tar_is_intersections': tar_is_intersections, # [L_]
189 | 'tar_turn_directions': tar_turn_directions, # [L_]
190 | 'tar_traffic_controls': tar_traffic_controls, # [L_]
191 | 'tar_lane_points_num': sum(tar_counts),
192 | }
193 |
194 |
195 | def get_lane_features_preload(am: ArgoverseMap,
196 | node_inds: List[int], # node index: origin coord
197 | node_positions: torch.Tensor, # query place: origin coord
198 | origin: torch.Tensor, # origin: origin coord
199 | rotate_mat: torch.Tensor, # rotate_mat
200 | city: str, # city: str
201 | radius: float # radius: int
202 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
203 |                          Dict[str, int], List[int], int]:
204 | lane_positions, lane_vectors, is_intersections, turn_directions, traffic_controls = [], [], [], [], []
205 | lane_ids = set()
206 | counts = []
207 | id_2_idx = {}
208 | for node_position in node_positions:
209 | # in range radius
210 | lane_ids.update(am.get_lane_ids_in_xy_bbox(node_position[0], node_position[1], city, radius))
211 | # relative pos
212 | node_positions = torch.matmul(node_positions - origin, rotate_mat).float()
213 | for i, lane_id in enumerate(lane_ids):
214 | id_2_idx[f'{lane_id}'] = i
215 | lane_centerline = torch.from_numpy(am.get_lane_segment_centerline(lane_id, city)[:, : 2]).float()
216 | lane_centerline = torch.matmul(lane_centerline - origin, rotate_mat)
217 | is_intersection = am.lane_is_in_intersection(lane_id, city)
218 | turn_direction = am.get_lane_turn_direction(lane_id, city)
219 | traffic_control = am.lane_has_traffic_control_measure(lane_id, city)
220 |
221 | lane_positions.append(lane_centerline[:-1])
222 | lane_vectors.append(lane_centerline[1:] - lane_centerline[:-1])
223 | count = len(lane_centerline) - 1
224 | counts.append(count)
225 | # broadcast the lane segment's attributes to all of its points
226 | is_intersections.append(is_intersection * torch.ones(count, dtype=torch.uint8))
227 | if turn_direction == 'NONE':
228 | turn_direction = 0
229 | elif turn_direction == 'LEFT':
230 | turn_direction = 1
231 | elif turn_direction == 'RIGHT':
232 | turn_direction = 2
233 | else:
234 | raise ValueError('turn direction is not valid')
235 | turn_directions.append(turn_direction * torch.ones(count, dtype=torch.uint8))
236 | traffic_controls.append(traffic_control * torch.ones(count, dtype=torch.uint8))
237 |
238 | lane_positions = torch.cat(lane_positions, dim=0) # ok
239 | lane_vectors = torch.cat(lane_vectors, dim=0) # ok
240 | is_intersections = torch.cat(is_intersections, dim=0) # ok
241 | turn_directions = torch.cat(turn_directions, dim=0) # ok
242 | traffic_controls = torch.cat(traffic_controls, dim=0) # ok
243 |
244 | return lane_positions, lane_vectors, is_intersections, turn_directions, traffic_controls, id_2_idx, counts, len(counts)
245 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | from torch_geometric.data import DataLoader
5 |
6 | from datasets import ArgoverseV1Dataset
7 | from models.refine import Refine
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | import os
11 | import numpy as np
12 | import torch
13 | import matplotlib.pyplot as plt
14 |
15 | if __name__ == '__main__':
16 | pl.seed_everything(2024)
17 |
18 | parser = ArgumentParser()
19 | parser.add_argument('--data_root', type=str, required=True)
20 | parser.add_argument('--p1_root', type=str, required=True)
21 | parser.add_argument('--batch_size', type=int, default=32)
22 | parser.add_argument('--num_workers', type=int, default=8)
23 | parser.add_argument('--pin_memory', type=bool, default=True)
24 | parser.add_argument('--persistent_workers', type=bool, default=True)
25 | parser.add_argument('--gpus', type=int, default=1)
26 | parser.add_argument('--ckpt_dir', type=str, required=True)
27 | parser = Refine.add_model_specific_args(parser)
28 | args = parser.parse_args()
29 |
30 | trainer = pl.Trainer.from_argparse_args(args)
31 | ckpt_dir=args.ckpt_dir+'checkpoints/'
32 | ckpt_paths = [ckpt_dir+p for p in os.listdir(ckpt_dir) if p.endswith('ckpt')]
33 | ckpt_paths.sort()
34 | ckpt_path = ckpt_paths[-1]
35 |
36 | model = Refine.load_from_checkpoint(checkpoint_path=ckpt_path, seg_num=2, r_lo=2, r_hi=10, embed_dim=64, strict=False)
37 | model.eval()
38 | val_dataset = ArgoverseV1Dataset(data_root=args.data_root, p1_root=args.p1_root, split='val')
39 | dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers,
40 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers)
41 |
42 | trainer.validate(model, dataloader)
43 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | # change data_root to the path of your dataset root.
3 | data_root=../argo1_data/
4 | p1_root=../p1_data/
5 | # the version directory of the experiment name used in training.
6 | ckpt_version=./ckpts/version_6191823/
7 | pwd
8 |
9 | python eval.py \
10 | --data_root $data_root --p1_root $p1_root \
11 | --ckpt_dir $ckpt_version \
12 | --refine_num 5 --refine_radius -1 \
13 | --embed_dim 64 \
14 |
--------------------------------------------------------------------------------
/eval_store.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | from torch_geometric.data import DataLoader
5 |
6 | from datasets import ArgoverseV1Dataset
7 | from models.hivt import HiVT
8 | import logging
9 | logger = logging.getLogger(__name__)
10 | import os
11 | import numpy as np
12 | import torch
13 | import matplotlib.pyplot as plt
15 | from typing import Dict, List, Tuple, NamedTuple, Any, Union, Optional
17 | import pickle
18 | from tqdm import tqdm
19 |
20 |
21 | def compute_ade(forecasted_trajectories, gt_trajectory):
22 | """Compute the average displacement error for a set of K predicted trajectories (for the same actor).
23 |
24 | Args:
25 | forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length.
26 | gt_trajectory: (N, 2) ground truth trajectory.
27 |
28 | Returns:
29 | (K,) Average displacement error for each of the predicted trajectories.
30 | """
31 | # displacement_errors = np.mean(np.linalg.norm(forecasted_trajectories - gt_trajectory, axis=-1), 1)
32 | displacement_errors = np.sqrt(np.sum((forecasted_trajectories - gt_trajectory)**2, -1))
33 | ade = np.mean(displacement_errors, axis=-1)
34 | return ade
35 |
36 |
37 | def compute_fde(forecasted_trajectories, gt_trajectory):
38 | """Compute the final displacement error for a set of K predicted trajectories (for the same actor).
39 |
40 | Args:
41 | forecasted_trajectories: (K, N, 2) predicted trajectories, each N timestamps in length.
42 | gt_trajectory: (N, 2) ground truth trajectory, FDE will be evaluated against true position at index `N-1`.
43 |
44 | Returns:
45 | (K,) Final displacement error for each of the predicted trajectories.
46 | """
47 | # Compute final displacement error for all K trajectories
48 | error_vector = forecasted_trajectories - gt_trajectory
49 | fde_vector = error_vector[:, -1]
50 | fde = np.linalg.norm(fde_vector, axis=-1)
51 | return fde
52 |
53 |
54 | class Metric:
55 | def __init__(self):
56 | self.values = []
57 |
58 | def accumulate(self, value):
59 | if value is not None:
60 | self.values.append(value)
61 |
62 | def get_mean(self):
63 | if len(self.values) > 0:
64 | return np.mean(self.values)
65 | else:
66 | return 0.0
67 |
68 | def get_sum(self):
69 | return np.sum(self.values)
70 |
71 |
72 | class PredictionMetrics:
73 | def __init__(self):
74 | self.minADE = Metric()
75 | self.minFDE = Metric()
76 | self.MR = Metric()
77 | self.brier_minFDE = Metric()
78 |
79 | def serialize(self) -> Dict[str, Any]:
80 | return dict(
81 | minADE=float(self.minADE.get_mean()),
82 | minFDE=float(self.minFDE.get_mean()),
83 | MR=float(self.MR.get_mean()),
84 | brier_minFDE=float(self.brier_minFDE.get_mean()),
85 | )
86 |
87 |
88 | if __name__ == '__main__':
89 |
90 | #! set split first.
91 | split='train'
92 |
93 | #! prepare your model, dataloader configuration here.
94 | model = None
95 | dataloader = None
96 |
97 | processed_dir = './p1/'
98 | model.to("cuda")
99 | model.eval()
100 | metrics = PredictionMetrics()
101 | for data in tqdm(dataloader):
102 | data.to("cuda")
103 | with torch.no_grad():
104 | #! infer your model here and output trajectory and embeddings.
105 | pred_trajectory = None # [K, N, T, 2]
106 | embeds = None # [K, N, -1]
107 |
108 | file_names = None # data ids
109 | gt_eval = None # ground-truth: [N, T, 2]
110 |
111 | embeds = embeds.transpose(0,1).detach().cpu().numpy()
112 | pred_trajectory = pred_trajectory.detach().cpu().numpy()
113 | gt_eval = gt_eval.detach().cpu().numpy()
114 | for i in range(gt_eval.shape[0]):
115 | forecasted_trajectories = pred_trajectory[i][:, :, :]
116 | gt_trajectory = gt_eval[i][:,:]
117 | #! make sure the file name is the same with original id in dataset.
118 | raw_file_name = file_names[i]
119 |
120 | #! dict to store..
121 | dict_data = {
122 | 'traj': torch.from_numpy(pred_trajectory[i].copy().astype(np.float32)),
123 | 'embed': torch.from_numpy(embeds[i].copy().astype(np.float32)),
124 | }
125 | with open(os.path.join(processed_dir, split, f'{raw_file_name}.pkl'), 'wb') as handle:
126 | pickle.dump(dict_data, handle, protocol=pickle.HIGHEST_PROTOCOL)
127 |
128 | assert forecasted_trajectories.shape == (6, 30, 2)
129 | assert gt_trajectory.shape == (30, 2)
130 |
131 | fde = compute_fde(forecasted_trajectories, gt_trajectory)
132 | idx = fde.argmin()
133 | ade = compute_ade(forecasted_trajectories[idx], gt_trajectory)
134 |
135 | metrics.minADE.accumulate(ade.min())
136 | metrics.minFDE.accumulate(fde.min())
137 | metrics.MR.accumulate(fde.min() > 2.0)
138 | import json
139 | print('Metrics:')
140 | print(json.dumps(metrics.serialize(), indent=4))
141 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from losses.laplace_nll_loss import LaplaceNLLLoss
2 | from losses.soft_target_cross_entropy_loss import SoftTargetCrossEntropyLoss
3 | from losses.score_reg_l1_loss import ScoreRegL1Loss
4 |
--------------------------------------------------------------------------------
/losses/laplace_nll_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LaplaceNLLLoss(nn.Module):
6 |
7 | def __init__(self,
8 | eps: float = 1e-6,
9 | reduction: str = 'mean') -> None:
10 | super(LaplaceNLLLoss, self).__init__()
11 | self.eps = eps
12 | self.reduction = reduction
13 |
14 | def forward(self,
15 | pred: torch.Tensor,
16 | target: torch.Tensor) -> torch.Tensor:
17 | loc, scale = pred.chunk(2, dim=-1)
18 | scale = scale.clone()
19 | with torch.no_grad():
20 | scale.clamp_(min=self.eps)
21 | nll = torch.log(2 * scale) + torch.abs(target - loc) / scale  # negative log-likelihood of Laplace(loc, scale)
22 | if self.reduction == 'mean':
23 | return nll.mean()
24 | elif self.reduction == 'sum':
25 | return nll.sum()
26 | elif self.reduction == 'none':
27 | return nll
28 | else:
29 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
30 |
--------------------------------------------------------------------------------
/losses/score_reg_l1_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ScoreRegL1Loss(nn.Module):
7 |
8 | def __init__(self, reduction: str = 'mean') -> None:
9 | super(ScoreRegL1Loss, self).__init__()
10 | self.loss = nn.L1Loss(reduction=reduction)
11 |
12 | def forward(self,
13 | pred: torch.Tensor,
14 | target: torch.Tensor) -> torch.Tensor:
15 | if pred.shape[0] == 0:
16 | return 0
17 | return self.loss(pred, target)
18 |
--------------------------------------------------------------------------------
/losses/soft_target_cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SoftTargetCrossEntropyLoss(nn.Module):
7 |
8 | def __init__(self, reduction: str = 'mean') -> None:
9 | super(SoftTargetCrossEntropyLoss, self).__init__()
10 | self.reduction = reduction
11 |
12 | def forward(self,
13 | pred: torch.Tensor,
14 | target: torch.Tensor) -> torch.Tensor:
15 | cross_entropy = torch.sum(-target * F.log_softmax(pred, dim=-1), dim=-1)
16 | if self.reduction == 'mean':
17 | return cross_entropy.mean()
18 | elif self.reduction == 'sum':
19 | return cross_entropy.sum()
20 | elif self.reduction == 'none':
21 | return cross_entropy
22 | else:
23 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
24 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from metrics.ade import ADE
2 | from metrics.fde import FDE
3 | from metrics.mr import MR
4 |
--------------------------------------------------------------------------------
/metrics/ade.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class ADE(Metric):
8 |
9 | def __init__(self,
10 | **kwargs) -> None:
11 | super(ADE, self).__init__(**kwargs)
12 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
13 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
14 |
15 | def update(self,
16 | pred: torch.Tensor,
17 | target: torch.Tensor) -> None:
18 | self.sum += torch.norm(pred - target, p=2, dim=-1).mean(dim=-1).sum()
19 | self.count += pred.size(0)
20 |
21 | def compute(self) -> torch.Tensor:
22 | return self.sum / self.count
23 |
--------------------------------------------------------------------------------
/metrics/fde.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class FDE(Metric):
8 |
9 | def __init__(self,
10 | **kwargs) -> None:
11 | super(FDE, self).__init__(**kwargs)
12 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
13 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
14 |
15 | def update(self,
16 | pred: torch.Tensor,
17 | target: torch.Tensor=None) -> None:
18 | if target is not None:
19 | self.sum += torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1).sum()
20 | self.count += pred.size(0)
21 | else:
22 | self.sum += pred.sum()
23 | self.count += pred.size(0)
24 |
25 | def compute(self) -> torch.Tensor:
26 | return self.sum / self.count
27 |
--------------------------------------------------------------------------------
/metrics/mr.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class MR(Metric):
8 |
9 | def __init__(self,
10 | miss_threshold: float = 2.0,
11 | **kwargs,) -> None:
12 | super(MR, self).__init__(**kwargs)
13 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
14 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
15 | self.miss_threshold = miss_threshold
16 |
17 | def update(self,
18 | pred: torch.Tensor,
19 | target: torch.Tensor) -> None:
20 | self.sum += (torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1) > self.miss_threshold).sum()
21 | self.count += pred.size(0)
22 |
23 | def compute(self) -> torch.Tensor:
24 | return self.sum / self.count
25 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.decoder import GRUDecoder
2 | from models.decoder import MLPDecoder, MLPDeltaDecoder, MLPDeltaDecoderPi, MLPDeltaDecoderScore
3 | from models.embedding import MultipleInputEmbedding
4 | from models.embedding import SingleInputEmbedding
5 | from models.local_encoder import ALEncoder, ALEncoderWithAo
6 | from models.target_region import TargetRegion
7 |
--------------------------------------------------------------------------------
/models/decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from utils import init_weights
8 |
9 |
10 | class GRUDecoder(nn.Module):
11 |
12 | def __init__(self,
13 | local_channels: int,
14 | global_channels: int,
15 | future_steps: int,
16 | num_modes: int,
17 | uncertain: bool = True,
18 | min_scale: float = 1e-3) -> None:
19 | super(GRUDecoder, self).__init__()
20 | self.input_size = global_channels
21 | self.hidden_size = local_channels
22 | self.future_steps = future_steps
23 | self.num_modes = num_modes
24 | self.uncertain = uncertain
25 | self.min_scale = min_scale
26 |
27 | self.gru = nn.GRU(input_size=self.input_size,
28 | hidden_size=self.hidden_size,
29 | num_layers=1,
30 | bias=True,
31 | batch_first=False,
32 | dropout=0,
33 | bidirectional=False)
34 | self.loc = nn.Sequential(
35 | nn.Linear(self.hidden_size, self.hidden_size),
36 | nn.LayerNorm(self.hidden_size),
37 | nn.ReLU(inplace=True),
38 | nn.Linear(self.hidden_size, 2))
39 | if uncertain:
40 | self.scale = nn.Sequential(
41 | nn.Linear(self.hidden_size, self.hidden_size),
42 | nn.LayerNorm(self.hidden_size),
43 | nn.ReLU(inplace=True),
44 | nn.Linear(self.hidden_size, 2))
45 | self.pi = nn.Sequential(
46 | nn.Linear(self.hidden_size + self.input_size, self.hidden_size),
47 | nn.LayerNorm(self.hidden_size),
48 | nn.ReLU(inplace=True),
49 | nn.Linear(self.hidden_size, self.hidden_size),
50 | nn.LayerNorm(self.hidden_size),
51 | nn.ReLU(inplace=True),
52 | nn.Linear(self.hidden_size, 1))
53 | self.apply(init_weights)
54 |
55 | def forward(self,
56 | local_embed: torch.Tensor,
57 | global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
58 | pi = self.pi(torch.cat((local_embed.expand(self.num_modes, *local_embed.shape),
59 | global_embed), dim=-1)).squeeze(-1).t()
60 | global_embed = global_embed.reshape(-1, self.input_size) # [F x N, D]
61 | global_embed = global_embed.expand(self.future_steps, *global_embed.shape) # [H, F x N, D]
62 | local_embed = local_embed.repeat(self.num_modes, 1).unsqueeze(0) # [1, F x N, D]
63 | out, _ = self.gru(global_embed, local_embed)
64 | out = out.transpose(0, 1) # [F x N, H, D]
65 | loc = self.loc(out) # [F x N, H, 2]
66 | if self.uncertain:
67 | scale = F.elu_(self.scale(out), alpha=1.0) + 1.0 + self.min_scale # [F x N, H, 2]
68 | return torch.cat((loc, scale),
69 | dim=-1).view(self.num_modes, -1, self.future_steps, 4), pi # [F, N, H, 4], [N, F]
70 | else:
71 | return loc.view(self.num_modes, -1, self.future_steps, 2), pi # [F, N, H, 2], [N, F]
72 |
73 |
74 | class MLPDecoder(nn.Module):
75 |
76 | def __init__(self,
77 | local_channels: int,
78 | global_channels: int,
79 | future_steps: int,
80 | num_modes: int,
81 | uncertain: bool = True,
82 | min_scale: float = 1e-3) -> None:
83 | super(MLPDecoder, self).__init__()
84 | self.input_size = global_channels
85 | self.hidden_size = local_channels
86 | self.future_steps = future_steps
87 | self.num_modes = num_modes
88 | self.uncertain = uncertain
89 | self.min_scale = min_scale
90 |
91 | self.aggr_embed = nn.Sequential(
92 | nn.Linear(self.input_size + self.hidden_size, self.hidden_size),
93 | nn.LayerNorm(self.hidden_size),
94 | nn.ReLU(inplace=True))
95 | self.loc = nn.Sequential(
96 | nn.Linear(self.hidden_size, self.hidden_size),
97 | nn.LayerNorm(self.hidden_size),
98 | nn.ReLU(inplace=True),
99 | nn.Linear(self.hidden_size, self.future_steps * 2))
100 | if uncertain:
101 | self.scale = nn.Sequential(
102 | nn.Linear(self.hidden_size, self.hidden_size),
103 | nn.LayerNorm(self.hidden_size),
104 | nn.ReLU(inplace=True),
105 | nn.Linear(self.hidden_size, self.future_steps * 2))
106 | self.pi = nn.Sequential(
107 | nn.Linear(self.hidden_size + self.input_size, self.hidden_size),
108 | nn.LayerNorm(self.hidden_size),
109 | nn.ReLU(inplace=True),
110 | nn.Linear(self.hidden_size, self.hidden_size),
111 | nn.LayerNorm(self.hidden_size),
112 | nn.ReLU(inplace=True),
113 | nn.Linear(self.hidden_size, 1))
114 | self.apply(init_weights)
115 |
116 | def forward(self,
117 | local_embed: torch.Tensor,
118 | global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
119 | pi = self.pi(torch.cat((local_embed.expand(self.num_modes, *local_embed.shape),
120 | global_embed), dim=-1)).squeeze(-1).t()
121 | out = self.aggr_embed(torch.cat((global_embed, local_embed.expand(self.num_modes, *local_embed.shape)), dim=-1))
122 | loc = self.loc(out).view(self.num_modes, -1, self.future_steps, 2) # [F, N, H, 2]
123 | if self.uncertain:
124 | scale = F.elu_(self.scale(out), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2) + 1.0
125 | scale = scale + self.min_scale # [F, N, H, 2]
126 |             return torch.cat((loc, scale), dim=-1), pi, out  # [F, N, H, 4], [N, F], [F, N, D]
127 | else:
128 | return loc, pi # [F, N, H, 2], [N, F]
129 |
130 |
131 | class MLPDeltaDecoder(nn.Module):
132 |
133 | def __init__(self,
134 | local_channels: int,
135 | global_channels: int,
136 | future_steps: int,
137 | num_modes: int,
138 | with_cumsum:int=0,
139 | uncertain: bool = True,
140 | min_scale: float = 1e-3) -> None:
141 | super(MLPDeltaDecoder, self).__init__()
142 | self.input_size = global_channels
143 | self.hidden_size = local_channels
144 | self.future_steps = future_steps
145 | self.num_modes = num_modes
146 | self.uncertain = uncertain
147 | self.min_scale = min_scale
148 | self.with_cumsum = False if with_cumsum==0 else True
149 |
150 | self.loc = nn.Sequential(
151 | nn.Linear(self.hidden_size, self.hidden_size),
152 | nn.LayerNorm(self.hidden_size),
153 | nn.ReLU(inplace=True),
154 | nn.Linear(self.hidden_size, self.future_steps * 2))
155 | if uncertain:
156 | self.scale = nn.Sequential(
157 | nn.Linear(self.hidden_size, self.hidden_size),
158 | nn.LayerNorm(self.hidden_size),
159 | nn.ReLU(inplace=True),
160 | nn.Linear(self.hidden_size, self.future_steps * 2))
161 | # self.pi = nn.Sequential(
162 | # nn.Linear(self.hidden_size, self.hidden_size),
163 | # nn.LayerNorm(self.hidden_size),
164 | # nn.ReLU(inplace=True),
165 | # nn.Linear(self.hidden_size, 1))
166 | self.apply(init_weights)
167 |
168 | def forward(self,
169 | global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
170 | # pi = self.pi(global_embed).squeeze(-1).t()
171 | loc = self.loc(global_embed).view(self.num_modes, -1, self.future_steps, 2) # [F, N, H, 2]
172 | if self.uncertain:
173 | if not self.with_cumsum:
174 | scale = F.elu_(self.scale(global_embed), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2) + 1.0
175 | # scale = F.elu_(self.scale(global_embed), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2)
176 | scale = scale + self.min_scale # [F, N, H, 2]
177 | else:
178 | # only to (0,+inf)
179 | scale = F.elu_(self.scale(global_embed), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2) + 1.0
180 | return torch.cat((loc, scale), dim=-1) # [F, N, H, 4], [N, F]
181 | else:
182 | return loc # [F, N, H, 2], [N, F]
183 |
184 |
185 |
186 | class MLPDeltaDecoderPi(nn.Module):
187 |
188 | def __init__(self,
189 | embed_dim:int) -> None:
190 | super(MLPDeltaDecoderPi, self).__init__()
191 | self.hidden_size = embed_dim
192 |
193 | self.pi = nn.Sequential(
194 | nn.Linear(self.hidden_size, self.hidden_size),
195 | nn.LayerNorm(self.hidden_size),
196 | nn.ReLU(inplace=True),
197 | nn.Linear(self.hidden_size, 1))
198 | self.apply(init_weights)
199 |
200 | def forward(self,
201 | global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
202 | pi = self.pi(global_embed).squeeze(-1).t()
203 | return pi
204 |
205 |
206 |
207 | class MLPDeltaDecoderScore(nn.Module):
208 |
209 | def __init__(self,
210 | embed_dim:int,
211 | with_last:bool=False) -> None:
212 | super(MLPDeltaDecoderScore, self).__init__()
213 | self.hidden_size = embed_dim
214 | self.with_last = with_last
215 | if not self.with_last:
216 | self.pi = nn.Sequential(
217 | nn.Linear(self.hidden_size, self.hidden_size),
218 | nn.LayerNorm(self.hidden_size),
219 | nn.ReLU(inplace=True),
220 | nn.Linear(self.hidden_size, 1),
221 | # nn.Sigmoid()
222 | nn.Tanh()
223 | )
224 | # cross entropy
225 | else:
226 | self.pi = nn.Sequential(
227 | nn.Linear(self.hidden_size*2, self.hidden_size),
228 | nn.LayerNorm(self.hidden_size),
229 | nn.ReLU(inplace=True),
230 | nn.Linear(self.hidden_size, 1),
231 | nn.Sigmoid()
232 | )
233 | self.apply(init_weights)
234 |
235 | def forward(self,
236 |                 global_embed: torch.Tensor) -> torch.Tensor:
237 | pi = self.pi(global_embed).squeeze(-1).transpose(0,1)
238 |         # map the Tanh output from (-1, 1) into (0, 1)
239 |         pi = (pi + 1) / 2
240 | return pi
--------------------------------------------------------------------------------
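Note (illustrative only, not part of the repository): a minimal sketch of how the delta decoders above could be exercised, assuming 64-dim embeddings, 6 modes, and 15-step chunks as configured in TargetRegion further below; shapes follow the comments in the code.

    import torch
    from models import MLPDeltaDecoder, MLPDeltaDecoderPi, MLPDeltaDecoderScore

    num_modes, num_agents, chunk_steps, embed_dim = 6, 4, 15, 64
    global_embed = torch.randn(num_modes * num_agents, embed_dim)

    delta_dec = MLPDeltaDecoder(local_channels=embed_dim, global_channels=embed_dim,
                                future_steps=chunk_steps, num_modes=num_modes, with_cumsum=0)
    deltas = delta_dec(global_embed)  # [6, 4, 15, 4]: per-mode (dx, dy) offsets plus Laplace scales

    mode_embed = global_embed.view(num_modes, num_agents, embed_dim)
    pi = MLPDeltaDecoderPi(embed_dim=embed_dim)(mode_embed)        # [4, 6] mode logits
    score = MLPDeltaDecoderScore(embed_dim=embed_dim)(mode_embed)  # [4, 6] scores in (0, 1)
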
/models/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from utils import init_weights
7 |
8 |
9 | class SingleInputEmbedding(nn.Module):
10 |
11 | def __init__(self,
12 | in_channel: int,
13 | out_channel: int) -> None:
14 | super(SingleInputEmbedding, self).__init__()
15 | self.embed = nn.Sequential(
16 | nn.Linear(in_channel, out_channel),
17 | nn.LayerNorm(out_channel),
18 | nn.ReLU(inplace=True),
19 | nn.Linear(out_channel, out_channel),
20 | nn.LayerNorm(out_channel),
21 | nn.ReLU(inplace=True),
22 | nn.Linear(out_channel, out_channel),
23 | nn.LayerNorm(out_channel))
24 | self.apply(init_weights)
25 |
26 | def forward(self, x: torch.Tensor) -> torch.Tensor:
27 | return self.embed(x)
28 |
29 |
30 | class MultipleInputEmbedding(nn.Module):
31 |
32 | def __init__(self,
33 | in_channels: List[int],
34 | out_channel: int) -> None:
35 | super(MultipleInputEmbedding, self).__init__()
36 | self.module_list = nn.ModuleList(
37 | [nn.Sequential(nn.Linear(in_channel, out_channel),
38 | nn.LayerNorm(out_channel),
39 | nn.ReLU(inplace=True),
40 | nn.Linear(out_channel, out_channel))
41 | for in_channel in in_channels])
42 | self.aggr_embed = nn.Sequential(
43 | nn.LayerNorm(out_channel),
44 | nn.ReLU(inplace=True),
45 | nn.Linear(out_channel, out_channel),
46 | nn.LayerNorm(out_channel))
47 | self.apply(init_weights)
48 |
49 | def forward(self,
50 | continuous_inputs: List[torch.Tensor],
51 | categorical_inputs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
52 | for i in range(len(self.module_list)):
53 | continuous_inputs[i] = self.module_list[i](continuous_inputs[i])
54 | output = torch.stack(continuous_inputs).sum(dim=0)
55 | if categorical_inputs is not None:
56 | output += torch.stack(categorical_inputs).sum(dim=0)
57 | return self.aggr_embed(output)
58 |
--------------------------------------------------------------------------------
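Note (illustrative only): a small sketch of feeding the two embedding modules, assuming 2-dim continuous features and a 64-dim output, as used by the lane encoders below.

    import torch
    from models import SingleInputEmbedding, MultipleInputEmbedding

    N, embed_dim = 8, 64
    single = SingleInputEmbedding(in_channel=2, out_channel=embed_dim)
    x = single(torch.randn(N, 2))  # [8, 64]

    multi = MultipleInputEmbedding(in_channels=[2, 2], out_channel=embed_dim)
    out = multi([torch.randn(N, 2), torch.randn(N, 2)],  # continuous inputs, e.g. lane vector and relative position
                [torch.randn(N, embed_dim)])             # optional categorical embeddings, already embed_dim-sized
    # out: [8, 64]; continuous branches are embedded and summed, categorical embeddings are added before the final norm
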
/models/local_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_geometric.data import Batch
7 | from torch_geometric.data import Data
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.typing import Adj
10 | from torch_geometric.typing import OptTensor
11 | from torch_geometric.typing import Size
12 | from torch_geometric.utils import softmax
13 | from torch_geometric.utils import subgraph
14 |
15 | from models import MultipleInputEmbedding
16 | from models import SingleInputEmbedding
17 | from utils import DistanceDropEdge
18 | from utils import TemporalData
19 | from utils import init_weights
20 |
21 |
22 | class ALEncoder(MessagePassing):
23 |
24 | def __init__(self,
25 | node_dim: int,
26 | edge_dim: int,
27 | embed_dim: int,
28 | num_heads: int = 8,
29 | dropout: float = 0.1,
30 | **kwargs) -> None:
31 | super(ALEncoder, self).__init__(aggr='add', node_dim=0, **kwargs)
32 | self.embed_dim = embed_dim
33 | self.num_heads = num_heads
34 |
35 | self.lane_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim], out_channel=embed_dim)
36 | self.lin_q = nn.Linear(embed_dim, embed_dim)
37 | self.lin_k = nn.Linear(embed_dim, embed_dim)
38 | self.lin_v = nn.Linear(embed_dim, embed_dim)
39 | self.lin_self = nn.Linear(embed_dim, embed_dim)
40 | self.attn_drop = nn.Dropout(dropout)
41 | self.lin_ih = nn.Linear(embed_dim, embed_dim)
42 | self.lin_hh = nn.Linear(embed_dim, embed_dim)
43 | self.out_proj = nn.Linear(embed_dim, embed_dim)
44 | self.proj_drop = nn.Dropout(dropout)
45 | self.norm1 = nn.LayerNorm(embed_dim)
46 | self.norm2 = nn.LayerNorm(embed_dim)
47 | self.mlp = nn.Sequential(
48 | nn.Linear(embed_dim, embed_dim * 4),
49 | nn.ReLU(inplace=True),
50 | nn.Dropout(dropout),
51 | nn.Linear(embed_dim * 4, embed_dim),
52 | nn.Dropout(dropout))
53 | self.is_intersection_embed = nn.Parameter(torch.Tensor(2, embed_dim))
54 | self.turn_direction_embed = nn.Parameter(torch.Tensor(3, embed_dim))
55 | self.traffic_control_embed = nn.Parameter(torch.Tensor(2, embed_dim))
56 | nn.init.normal_(self.is_intersection_embed, mean=0., std=.02)
57 | nn.init.normal_(self.turn_direction_embed, mean=0., std=.02)
58 | nn.init.normal_(self.traffic_control_embed, mean=0., std=.02)
59 | self.apply(init_weights)
60 |
61 | def forward(self,
62 | x: Tuple[torch.Tensor, torch.Tensor],
63 | edge_index: Adj,
64 | edge_attr: torch.Tensor,
65 | is_intersections: torch.Tensor,
66 | turn_directions: torch.Tensor,
67 | traffic_controls: torch.Tensor,
68 | rotate_mat: Optional[torch.Tensor] = None,
69 | size: Size = None) -> torch.Tensor:
70 | x_lane, x_actor = x
71 | is_intersections = is_intersections.long()
72 | turn_directions = turn_directions.long()
73 | traffic_controls = traffic_controls.long()
74 | x_actor = x_actor + self._mha_block(self.norm1(x_actor), x_lane, edge_index, edge_attr, is_intersections,
75 | turn_directions, traffic_controls, rotate_mat, size)
76 | x_actor = x_actor + self._ff_block(self.norm2(x_actor))
77 | return x_actor
78 |
79 | def message(self,
80 | edge_index: Adj,
81 | x_i: torch.Tensor,
82 | x_j: torch.Tensor,
83 | edge_attr: torch.Tensor,
84 | is_intersections_j,
85 | turn_directions_j,
86 | traffic_controls_j,
87 | rotate_mat: Optional[torch.Tensor],
88 | index: torch.Tensor,
89 | ptr: OptTensor,
90 | size_i: Optional[int]) -> torch.Tensor:
91 | if rotate_mat is None:
92 | x_j = self.lane_embed([x_j, edge_attr],
93 | [self.is_intersection_embed[is_intersections_j],
94 | self.turn_direction_embed[turn_directions_j],
95 | self.traffic_control_embed[traffic_controls_j]])
96 | else:
97 | # import pdb
98 | # pdb.set_trace()
99 | rotate_mat = rotate_mat[edge_index[1]]
100 | x_j = self.lane_embed([torch.bmm(x_j.unsqueeze(-2), rotate_mat).squeeze(-2),
101 | torch.bmm(edge_attr.unsqueeze(-2), rotate_mat).squeeze(-2)],
102 | [self.is_intersection_embed[is_intersections_j],
103 | self.turn_direction_embed[turn_directions_j],
104 | self.traffic_control_embed[traffic_controls_j]])
105 | query = self.lin_q(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
106 | key = self.lin_k(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
107 | value = self.lin_v(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
108 | scale = (self.embed_dim // self.num_heads) ** 0.5
109 | alpha = (query * key).sum(dim=-1) / scale
110 | alpha = softmax(alpha, index, ptr, size_i)
111 | alpha = self.attn_drop(alpha)
112 | return value * alpha.unsqueeze(-1)
113 |
114 | def update(self,
115 | inputs: torch.Tensor,
116 | x: torch.Tensor) -> torch.Tensor:
117 | x_actor = x[1]
118 | inputs = inputs.view(-1, self.embed_dim)
119 | gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x_actor))
120 | return inputs + gate * (self.lin_self(x_actor) - inputs)
121 |
122 | def _mha_block(self,
123 | x_actor: torch.Tensor,
124 | x_lane: torch.Tensor,
125 | edge_index: Adj,
126 | edge_attr: torch.Tensor,
127 | is_intersections: torch.Tensor,
128 | turn_directions: torch.Tensor,
129 | traffic_controls: torch.Tensor,
130 | rotate_mat: Optional[torch.Tensor],
131 | size: Size) -> torch.Tensor:
132 | # import pdb
133 | # pdb.set_trace()
134 | x_actor = self.out_proj(self.propagate(edge_index=edge_index, x=(x_lane, x_actor), edge_attr=edge_attr,
135 | is_intersections=is_intersections, turn_directions=turn_directions,
136 | traffic_controls=traffic_controls, rotate_mat=rotate_mat, size=size))
137 | return self.proj_drop(x_actor)
138 |
139 | def _ff_block(self, x_actor: torch.Tensor) -> torch.Tensor:
140 | return self.mlp(x_actor)
141 |
142 |
143 | class ALEncoderWithAo(MessagePassing):
144 |
145 | def __init__(self,
146 | node_dim: int,
147 | edge_dim: int,
148 | embed_dim: int,
149 | num_heads: int = 8,
150 | dropout: float = 0.1,
151 | **kwargs) -> None:
152 | super(ALEncoderWithAo, self).__init__(aggr='add', node_dim=0, **kwargs)
153 | self.embed_dim = embed_dim
154 | self.num_heads = num_heads
155 |
156 | self.lane_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim, node_dim], out_channel=embed_dim)
157 | self.lin_q = nn.Linear(embed_dim, embed_dim)
158 | self.lin_k = nn.Linear(embed_dim, embed_dim)
159 | self.lin_v = nn.Linear(embed_dim, embed_dim)
160 | self.lin_self = nn.Linear(embed_dim, embed_dim)
161 | self.attn_drop = nn.Dropout(dropout)
162 | self.lin_ih = nn.Linear(embed_dim, embed_dim)
163 | self.lin_hh = nn.Linear(embed_dim, embed_dim)
164 | self.out_proj = nn.Linear(embed_dim, embed_dim)
165 | self.proj_drop = nn.Dropout(dropout)
166 | self.norm1 = nn.LayerNorm(embed_dim)
167 | self.norm2 = nn.LayerNorm(embed_dim)
168 | self.mlp = nn.Sequential(
169 | nn.Linear(embed_dim, embed_dim * 4),
170 | nn.ReLU(inplace=True),
171 | nn.Dropout(dropout),
172 | nn.Linear(embed_dim * 4, embed_dim),
173 | nn.Dropout(dropout))
174 | self.is_intersection_embed = nn.Parameter(torch.Tensor(2, embed_dim))
175 | self.turn_direction_embed = nn.Parameter(torch.Tensor(3, embed_dim))
176 | self.traffic_control_embed = nn.Parameter(torch.Tensor(2, embed_dim))
177 | nn.init.normal_(self.is_intersection_embed, mean=0., std=.02)
178 | nn.init.normal_(self.turn_direction_embed, mean=0., std=.02)
179 | nn.init.normal_(self.traffic_control_embed, mean=0., std=.02)
180 | self.apply(init_weights)
181 |
182 | def forward(self,
183 | x: Tuple[torch.Tensor, torch.Tensor],
184 | edge_index: Adj,
185 | edge_attr: torch.Tensor,
186 | is_intersections: torch.Tensor,
187 | turn_directions: torch.Tensor,
188 | traffic_controls: torch.Tensor,
189 | vec_ao:torch.Tensor,
190 | rotate_mat: Optional[torch.Tensor] = None,
191 | size: Size = None) -> torch.Tensor:
192 | x_lane, x_actor = x
193 | is_intersections = is_intersections.long()
194 | turn_directions = turn_directions.long()
195 | traffic_controls = traffic_controls.long()
196 | x_actor = x_actor + self._mha_block(self.norm1(x_actor), x_lane, edge_index, edge_attr, is_intersections,
197 | turn_directions, traffic_controls, vec_ao, rotate_mat, size)
198 | x_actor = x_actor + self._ff_block(self.norm2(x_actor))
199 | return x_actor
200 |
201 | def message(self,
202 | edge_index: Adj,
203 | x_i: torch.Tensor,
204 | x_j: torch.Tensor,
205 | edge_attr: torch.Tensor,
206 | is_intersections_j,
207 | turn_directions_j,
208 | traffic_controls_j,
209 | vec_ao,
210 | rotate_mat: Optional[torch.Tensor],
211 | index: torch.Tensor,
212 | ptr: OptTensor,
213 | size_i: Optional[int]) -> torch.Tensor:
214 | if rotate_mat is None:
215 |             x_j = self.lane_embed([x_j, edge_attr, vec_ao[edge_index[1]]],
216 | [self.is_intersection_embed[is_intersections_j],
217 | self.turn_direction_embed[turn_directions_j],
218 | self.traffic_control_embed[traffic_controls_j]])
219 | else:
220 | # import pdb
221 | # pdb.set_trace()
222 | rotate_mat = rotate_mat[edge_index[1]]
223 |
224 | vec_ao = vec_ao[edge_index[1]]
225 | # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
226 | # starter.record()
227 | # test = self.is_intersection_embed[is_intersections_j]
228 | # # ender.record()
229 | # # # WAIT FOR GPU SYNC
230 | # # torch.cuda.synchronize()
231 | # # curr_time = starter.elapsed_time(ender)
232 | # # print(f'para index: {curr_time}')
233 | x_j = self.lane_embed([torch.bmm(x_j.unsqueeze(-2), rotate_mat).squeeze(-2),
234 | torch.bmm(edge_attr.unsqueeze(-2), rotate_mat).squeeze(-2),
235 | torch.bmm(vec_ao.unsqueeze(-2), rotate_mat).squeeze(-2)],
236 | [self.is_intersection_embed[is_intersections_j],
237 | self.turn_direction_embed[turn_directions_j],
238 | self.traffic_control_embed[traffic_controls_j]])
239 | query = self.lin_q(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
240 | key = self.lin_k(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
241 | value = self.lin_v(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
242 | scale = (self.embed_dim // self.num_heads) ** 0.5
243 | alpha = (query * key).sum(dim=-1) / scale
244 | alpha = softmax(alpha, index, ptr, size_i)
245 | alpha = self.attn_drop(alpha)
246 | return value * alpha.unsqueeze(-1)
247 |
248 | def update(self,
249 | inputs: torch.Tensor,
250 | x: torch.Tensor) -> torch.Tensor:
251 | x_actor = x[1]
252 | inputs = inputs.view(-1, self.embed_dim)
253 | gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x_actor))
254 | return inputs + gate * (self.lin_self(x_actor) - inputs)
255 |
256 | def _mha_block(self,
257 | x_actor: torch.Tensor,
258 | x_lane: torch.Tensor,
259 | edge_index: Adj,
260 | edge_attr: torch.Tensor,
261 | is_intersections: torch.Tensor,
262 | turn_directions: torch.Tensor,
263 | traffic_controls: torch.Tensor,
264 | vec_ao:torch.Tensor,
265 | rotate_mat: Optional[torch.Tensor],
266 | size: Size) -> torch.Tensor:
267 | # import pdb
268 | # pdb.set_trace()
269 | x_actor = self.out_proj(self.propagate(edge_index=edge_index, x=(x_lane, x_actor), edge_attr=edge_attr,
270 | is_intersections=is_intersections, turn_directions=turn_directions,
271 | traffic_controls=traffic_controls, vec_ao = vec_ao,
272 | rotate_mat=rotate_mat, size=size))
273 | return self.proj_drop(x_actor)
274 |
275 | def _ff_block(self, x_actor: torch.Tensor) -> torch.Tensor:
276 | return self.mlp(x_actor)
277 |
278 |
279 | class AttentionLayer(MessagePassing):
280 |
281 | def __init__(self,
282 | hidden_dim: int,
283 | num_heads: int,
284 | head_dim: int,
285 | dropout: float,
286 | bipartite: bool,
287 | has_pos_emb: bool,
288 | **kwargs) -> None:
289 | super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
290 | self.num_heads = num_heads
291 | self.head_dim = head_dim
292 | self.has_pos_emb = has_pos_emb
293 | self.scale = head_dim ** -0.5
294 |
295 | self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
296 | self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
297 | self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
298 | if has_pos_emb:
299 | self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
300 | self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
301 | self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
302 | self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
303 | self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
304 | self.attn_drop = nn.Dropout(dropout)
305 | self.ff_mlp = nn.Sequential(
306 | nn.Linear(hidden_dim, hidden_dim * 4),
307 | nn.ReLU(inplace=True),
308 | nn.Dropout(dropout),
309 | nn.Linear(hidden_dim * 4, hidden_dim),
310 | )
311 | if bipartite:
312 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
313 | self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
314 | else:
315 | self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
316 | self.attn_prenorm_x_dst = self.attn_prenorm_x_src
317 | if has_pos_emb:
318 | self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
319 | self.attn_postnorm = nn.LayerNorm(hidden_dim)
320 | self.ff_prenorm = nn.LayerNorm(hidden_dim)
321 | self.ff_postnorm = nn.LayerNorm(hidden_dim)
322 | self.apply(init_weights)
323 |
324 | def forward(self,
325 | x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
326 | r: Optional[torch.Tensor],
327 | edge_index: torch.Tensor) -> torch.Tensor:
328 | if isinstance(x, torch.Tensor):
329 | x_src = x_dst = self.attn_prenorm_x_src(x)
330 | else:
331 | x_src, x_dst = x
332 | x_src = self.attn_prenorm_x_src(x_src)
333 | x_dst = self.attn_prenorm_x_dst(x_dst)
334 | x = x[1]
335 | if self.has_pos_emb and r is not None:
336 | r = self.attn_prenorm_r(r)
337 | x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
338 | x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
339 | return x
340 |
341 | def message(self,
342 | q_i: torch.Tensor,
343 | k_j: torch.Tensor,
344 | v_j: torch.Tensor,
345 | r: Optional[torch.Tensor],
346 | index: torch.Tensor,
347 | ptr: Optional[torch.Tensor]) -> torch.Tensor:
348 | if self.has_pos_emb and r is not None:
349 | k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
350 | v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
351 | sim = (q_i * k_j).sum(dim=-1) * self.scale
352 | attn = softmax(sim, index, ptr)
353 | attn = self.attn_drop(attn)
354 | return v_j * attn.unsqueeze(-1)
355 |
356 | def update(self,
357 | inputs: torch.Tensor,
358 | x_dst: torch.Tensor) -> torch.Tensor:
359 | inputs = inputs.view(-1, self.num_heads * self.head_dim)
360 | g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
361 | return inputs + g * (self.to_s(x_dst) - inputs)
362 |
363 | def _attn_block(self,
364 | x_src: torch.Tensor,
365 | x_dst: torch.Tensor,
366 | r: Optional[torch.Tensor],
367 | edge_index: torch.Tensor) -> torch.Tensor:
368 | q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
369 | k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
370 | v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
371 | agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
372 | return self.to_out(agg)
373 |
374 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
375 | return self.ff_mlp(x)
--------------------------------------------------------------------------------
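Note (illustrative only, not from the repository): running AttentionLayer on a tiny fully-connected graph; the hidden_dim, num_heads, and head_dim values are assumptions.

    import torch
    from models.local_encoder import AttentionLayer

    N, hidden_dim, num_heads, head_dim = 5, 64, 8, 8
    layer = AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim,
                           dropout=0.1, bipartite=False, has_pos_emb=False)
    x = torch.randn(N, hidden_dim)
    # all-pairs edges in COO format (source row, destination row), as expected by MessagePassing
    row = torch.arange(N).repeat_interleave(N)
    col = torch.arange(N).repeat(N)
    edge_index = torch.stack([row, col], dim=0)
    out = layer(x, None, edge_index)  # [5, 64]: pre-norm attention + gated update + feed-forward, with residuals
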
/models/refine.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from losses import LaplaceNLLLoss
7 | from losses import SoftTargetCrossEntropyLoss
8 | from losses import ScoreRegL1Loss
9 | from metrics import ADE
10 | from metrics import FDE
11 | from metrics import MR
12 | from models import TargetRegion
13 | from collections import OrderedDict
14 |
15 | from utils import TemporalData
16 |
17 |
18 |
19 | class Refine(pl.LightningModule):
20 |
21 | def __init__(self,
22 | cls_temperture: int,
23 | lr: float,
24 | weight_decay: float,
25 | T_max: int,
26 | rotate: bool,
27 |
28 | future_steps: int,
29 | num_modes: int,
30 | node_dim: int,
31 | edge_dim: int,
32 | embed_dim: int,
33 | seg_num: int,
34 | refine_num: int,
35 | refine_radius: int,
36 | r_lo: int,
37 | r_hi: int,
38 | **kwargs) -> None:
39 | super(Refine, self).__init__()
40 | self.save_hyperparameters()
41 |
42 | self.cls_temperture = cls_temperture
43 |
44 | self.lr = lr
45 | self.weight_decay = weight_decay
46 | self.T_max = T_max
47 |
48 | self.future_steps = future_steps
49 | self.num_modes = num_modes
50 | self.rotate = rotate
51 |
52 | self.refine_num = refine_num
53 |
54 | self.target_encoder = TargetRegion(
55 | future_steps=future_steps,
56 | num_modes=num_modes,
57 | node_dim=node_dim,
58 | edge_dim=edge_dim,
59 | embed_dim=embed_dim,
60 | refine_num=refine_num,
61 | seg_num=seg_num,
62 | refine_radius=refine_radius,
63 | r_lo=r_lo,
64 | r_hi=r_hi,
65 | **kwargs)
66 |
67 |
68 | self.reg_loss = LaplaceNLLLoss(reduction='mean')
69 | self.cls_loss = SoftTargetCrossEntropyLoss(reduction='mean')
70 | self.score_loss = ScoreRegL1Loss()
71 |
72 | self.minADE = ADE()
73 | self.minFDE = FDE()
74 | self.minMR = MR()
75 |
76 | def to_global_coord(self, data):
77 | data_angles = data['theta']
78 | data_rotate_angle = data['rotate_angles'][data['agent_index']]
79 |
80 | rotate_local = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
81 | sin_vals_angle = torch.sin(-data_rotate_angle)
82 | cos_vals_angle = torch.cos(-data_rotate_angle)
83 | rotate_local[:, 0, 0] = cos_vals_angle
84 | rotate_local[:, 0, 1] = -sin_vals_angle
85 | rotate_local[:, 1, 0] = sin_vals_angle
86 | rotate_local[:, 1, 1] = cos_vals_angle
87 | # agent to av
88 | data.rotate_local = rotate_local
89 |
90 | rotate_mat = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
91 | sin_vals = torch.sin(-data_angles)
92 | cos_vals = torch.cos(-data_angles)
93 | rotate_mat[:, 0, 0] = cos_vals
94 | rotate_mat[:, 0, 1] = -sin_vals
95 | rotate_mat[:, 1, 0] = sin_vals
96 | rotate_mat[:, 1, 1] = cos_vals
97 | # av to global
98 | data.rotate_mat_ = rotate_mat
99 |
100 | rotate_mat_ = torch.empty(data['agent_index'].shape[0], 2, 2, device=self.device)
101 | sin_vals = torch.sin(data_angles)
102 | cos_vals = torch.cos(data_angles)
103 | rotate_mat_[:, 0, 0] = cos_vals
104 | rotate_mat_[:, 0, 1] = -sin_vals
105 | rotate_mat_[:, 1, 0] = sin_vals
106 | rotate_mat_[:, 1, 1] = cos_vals
107 | # global to av
108 | data.r_rotate_mat_ = rotate_mat_
109 |
110 | def refine(self, data, ys_hat, embed):
111 |
112 | assert ys_hat.shape[-1] == 2
113 |
114 | y_hat_ego = ys_hat.reshape(ys_hat.shape[1]*self.num_modes, -1, 2) # n*f, t,2
115 |
116 | refine_y_hat, refine_pi = self.target_encoder(data, y_hat_ego, embed)
117 |
118 | return refine_y_hat, refine_pi
119 |
120 | def forward(self, data: TemporalData, p1_data=None):
121 | if self.rotate:
122 | rotate_mat = torch.empty(data.num_nodes, 2, 2, device=self.device)
123 | sin_vals = torch.sin(data['rotate_angles'])
124 | cos_vals = torch.cos(data['rotate_angles'])
125 | rotate_mat[:, 0, 0] = cos_vals
126 | rotate_mat[:, 0, 1] = -sin_vals
127 | rotate_mat[:, 1, 0] = sin_vals
128 | rotate_mat[:, 1, 1] = cos_vals
129 | if data.y is not None:
130 | data.y = torch.bmm(data.y, rotate_mat)
131 | data['rotate_mat'] = rotate_mat
132 | else:
133 | data['rotate_mat'] = None
134 |
135 | self.to_global_coord(data)
136 |
137 | ys_hat_ego = p1_data['traj'].transpose(0,1)
138 | traj_embed = p1_data['embed'].transpose(0,1)
139 |
140 | ys_refine, pis_refine = self.refine(data, ys_hat_ego, traj_embed)
141 |
142 | # concat for later laplace sigma computation.
143 | return torch.cat((ys_hat_ego, ys_hat_ego), -1), None, ys_refine, pis_refine
144 |
145 | def training_step(self, data, batch_idx):
146 | data, p1_data = data
147 | reg_mask = ~data['padding_mask'][:, -self.future_steps:]
148 | reg_mask_ego = ~data['padding_mask'][data.agent_index][:, -self.future_steps:]
149 | valid_steps = reg_mask.sum(dim=-1)
150 | valid_steps_ego = reg_mask_ego.sum(dim=-1)
151 | cls_mask = valid_steps > 0
152 | cls_mask_ego = valid_steps_ego > 0
153 |
154 | ys_hat_ego, _, refine_y_hat_deltas, refine_pis = self(data, p1_data)
155 |
156 | refine_pi, refine_score = refine_pis
157 | y_agent = data.y[data.agent_index]
158 |
159 | reg_loss_refines = 0
160 | cls_loss_refines = 0
161 | score_loss_refines=0
162 |
163 | max_val = (torch.norm(ys_hat_ego[..., :2] - y_agent, p=2, dim=-1) * reg_mask_ego).sum(dim=-1)
164 | max_val = max_val.min(0)[0]
165 | y_i = ys_hat_ego.clone()
166 | min_vals = []
167 | min_vals.append(max_val)
168 | for i in range(self.refine_num):
169 | y_i = y_i + refine_y_hat_deltas[i]
170 | l2_norm = (torch.norm(y_i[..., :2] - y_agent, p=2, dim=-1) * reg_mask_ego).sum(dim=-1)
171 | min_vals.append(l2_norm.min(0)[0])
172 | min_vals = torch.stack(min_vals)
173 |
174 | min_val = min_vals.min(0)[0]
175 | max_val = min_vals.max(0)[0]
176 | min_id = min_vals.min(0)[1]
177 | max_id = min_vals.max(0)[1]
178 |
179 |         refine_y_hat = ys_hat_ego.clone()
180 | refine_score_i = refine_score[0].transpose(0,1)
181 | l2_norm = (torch.norm(refine_y_hat[..., :2] - y_agent, p=2, dim=-1) * reg_mask_ego).sum(dim=-1) # [F, N]
182 | best_mode = l2_norm.argmin(dim=0)
183 |
184 | target_score_i = ((max_val - l2_norm.min(0)[0]) / ((max_val - min_val)+1e-6))
185 | target_score_i = torch.clamp(target_score_i,0,1)
186 | refine_score_i = refine_score_i[best_mode, torch.arange(data.num_graphs)]
187 | score_loss_refine = self.score_loss(refine_score_i, target_score_i)
188 | score_loss_refines += score_loss_refine
189 |
190 | for i in range(self.refine_num):
191 | refine_y_hat_i = refine_y_hat_deltas[i]
192 | refine_pi_i = refine_pi[i]
193 | refine_score_i = refine_score[i+1].transpose(0,1)
194 |
195 | refine_y_hat[...,:2] = refine_y_hat[...,:2] + refine_y_hat_i[...,:2]
196 | refine_y_hat[...,2:] = refine_y_hat_i[...,2:]
197 |
198 | l2_norm = (torch.norm(refine_y_hat[..., :2] - y_agent, p=2, dim=-1) * reg_mask_ego).sum(dim=-1) # [F, N]
199 | best_mode = l2_norm.argmin(dim=0)
200 | refine_y_hat_best = refine_y_hat[best_mode, torch.arange(data.num_graphs)] # n, t, 4
201 | reg_loss_refine = self.reg_loss(refine_y_hat_best[reg_mask_ego], y_agent[reg_mask_ego])
202 | reg_loss_refines += reg_loss_refine
203 |
204 | soft_target = F.softmax((-l2_norm[:, cls_mask_ego] / valid_steps_ego[cls_mask_ego])/self.cls_temperture, dim=0).t().detach()
205 | cls_loss_refine = self.cls_loss(refine_pi_i[cls_mask_ego], soft_target)
206 | cls_loss_refines += cls_loss_refine
207 |
208 | target_score_i = ((max_val - l2_norm.min(0)[0]) / ((max_val - min_val)+1e-6))
209 | refine_score_i = refine_score_i[best_mode, torch.arange(data.num_graphs)]
210 | target_score_i = torch.clamp(target_score_i,0,1)
211 | score_loss_refine = self.score_loss(refine_score_i, target_score_i)
212 | score_loss_refines += score_loss_refine
213 |
214 | self.log('refine_reg_loss', reg_loss_refines/self.refine_num, prog_bar=False, on_step=True, on_epoch=True, batch_size=1)
215 | self.log('refine_cls_loss', cls_loss_refines/self.refine_num, prog_bar=False, on_step=True, on_epoch=True, batch_size=1)
216 |         self.log('refine_score_loss', score_loss_refines/(self.refine_num+1), prog_bar=False, on_step=True, on_epoch=True, batch_size=1)
217 |
218 | loss = reg_loss_refines/self.refine_num + cls_loss_refines/self.refine_num
219 |
220 | loss += 0.01*(score_loss_refines)/(self.refine_num+1)
221 |
222 | return loss
223 |
224 | def validation_step(self, data, batch_idx):
225 | data, p1_data = data
226 | reg_mask = ~data['padding_mask'][data.agent_index][:, -self.future_steps:]
227 | valid_steps = reg_mask.sum(dim=-1)
228 | cls_mask = valid_steps > 0
229 |
230 | y_hat_init_ego, _, refine_y_hat_deltas, refine_pis = self(data, p1_data)
231 |
232 | refine_pis, refine_scores = refine_pis
233 |
234 | y_agent = data.y[data.agent_index]
235 |
236 | max_val = (torch.norm(y_hat_init_ego[..., :2] - y_agent, p=2, dim=-1) * reg_mask).sum(dim=-1)
237 |
238 | y_i = y_hat_init_ego.clone()
239 | min_vals = []
240 | max_val = max_val.min(0)[0]
241 | min_vals.append(max_val)
242 | for i in range(self.refine_num):
243 | y_i += refine_y_hat_deltas[i]
244 | l2_norm = (torch.norm(y_i[..., :2] - y_agent, p=2, dim=-1) * reg_mask).sum(dim=-1)
245 | min_vals.append(l2_norm.min(0)[0])
246 | min_vals = torch.stack(min_vals)
247 | min_val = min_vals.min(0)[0]
248 | max_val = min_vals.max(0)[0]
249 |
250 | score_loss_refines=0
251 | refine_y_hat = y_hat_init_ego.clone()
252 |
253 | refine_score_i = refine_scores[0].transpose(0,1)
254 | l2_norm = (torch.norm(refine_y_hat[..., :2] - y_agent, p=2, dim=-1) * reg_mask).sum(dim=-1) # [F, N]
255 | best_mode = l2_norm.argmin(0)
256 |
257 |         target_score_i = (max_val - l2_norm.min(0)[0]) / ((max_val - min_val) + 1e-6)
258 | refine_score_i = refine_score_i[best_mode, torch.arange(data.num_graphs)]
259 | score_loss_refine = self.score_loss(refine_score_i, target_score_i)
260 | score_loss_refines += score_loss_refine
261 | for i in range(self.refine_num):
262 | refine_y_hat[...,:2] += refine_y_hat_deltas[i][...,:2]
263 | refine_y_hat[...,2:] = refine_y_hat_deltas[i][...,2:]
264 |
265 | refine_score_i = refine_scores[i+1].transpose(0,1)
266 | l2_norm = (torch.norm(refine_y_hat[..., :2] - y_agent, p=2, dim=-1) * reg_mask).sum(dim=-1) # [F, N]
267 | best_mode = l2_norm.argmin(0)
268 |             target_score_i = (max_val - l2_norm.min(0)[0]) / ((max_val - min_val) + 1e-6)
269 | refine_score_i = refine_score_i[best_mode, torch.arange(data.num_graphs)]
270 | score_loss_refine = self.score_loss(refine_score_i, target_score_i)
271 | score_loss_refines += score_loss_refine
272 |
273 | refine_pi = refine_pis[-1]
274 |
275 | l2_norm = (torch.norm(refine_y_hat[..., :2] - y_agent, p=2, dim=-1) * reg_mask).sum(dim=-1) # [F, N]
276 | best_mode = l2_norm.argmin(dim=0)
277 | refine_y_hat_best = refine_y_hat[best_mode, torch.arange(data.num_graphs)] # n, t, 4
278 | reg_loss_refine = self.reg_loss(refine_y_hat_best[reg_mask], y_agent[reg_mask])
279 | soft_target = F.softmax((-l2_norm[:, cls_mask] / valid_steps[cls_mask])/self.cls_temperture, dim=0).t().detach()
280 | cls_loss_refine = self.cls_loss(refine_pi[cls_mask], soft_target)
281 | self.log('val_refine_reg_loss', reg_loss_refine, prog_bar=False, on_step=False, on_epoch=True, batch_size=1)
282 | self.log('val_refine_cls_loss', cls_loss_refine, prog_bar=False, on_step=False, on_epoch=True, batch_size=1)
283 | self.log('val_refine_score_loss', score_loss_refines/(self.refine_num+1), prog_bar=False, on_step=False, on_epoch=True, batch_size=1)
284 |
285 | y_hat_agent = refine_y_hat[..., : 2]
286 | fde_agent = torch.norm(y_hat_agent[:, :, -1] - y_agent[:, -1], p=2, dim=-1)
287 | best_mode_agent = fde_agent.argmin(dim=0)
288 | y_hat_best_agent = y_hat_agent[best_mode_agent, torch.arange(data.num_graphs)]
289 | self.minADE.update(y_hat_best_agent, y_agent)
290 | self.minFDE.update(y_hat_best_agent, y_agent)
291 | self.minMR.update(y_hat_best_agent, y_agent)
292 | self.log('val_minADE', self.minADE, prog_bar=False, on_step=False, on_epoch=True, batch_size=y_agent.size(0))
293 | self.log('val_minFDE', self.minFDE, prog_bar=False, on_step=False, on_epoch=True, batch_size=y_agent.size(0))
294 | self.log('val_minMR', self.minMR, prog_bar=False, on_step=False, on_epoch=True, batch_size=y_agent.size(0))
295 |
296 | def configure_optimizers(self):
297 | decay = set()
298 | no_decay = set()
299 | whitelist_weight_modules = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.MultiheadAttention, nn.LSTM, nn.GRU)
300 | blacklist_weight_modules = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.Embedding)
301 | for module_name, module in self.named_modules():
302 | for param_name, param in module.named_parameters():
303 | full_param_name = '%s.%s' % (module_name, param_name) if module_name else param_name
304 | if 'bias' in param_name:
305 | no_decay.add(full_param_name)
306 | elif 'weight' in param_name:
307 | if isinstance(module, whitelist_weight_modules):
308 | decay.add(full_param_name)
309 | elif isinstance(module, blacklist_weight_modules):
310 | no_decay.add(full_param_name)
311 | elif not ('weight' in param_name or 'bias' in param_name):
312 | no_decay.add(full_param_name)
313 | param_dict = {param_name: param for param_name, param in self.named_parameters()}
314 | inter_params = decay & no_decay
315 | union_params = decay | no_decay
316 | assert len(inter_params) == 0
317 | assert len(param_dict.keys() - union_params) == 0
318 |
319 | optim_groups = [
320 | {"params": [param_dict[param_name] for param_name in sorted(list(decay)) if 'encoder_phase1' not in param_name],
321 | "lr": self.lr,
322 | "weight_decay": self.weight_decay},
323 | {"params": [param_dict[param_name] for param_name in sorted(list(no_decay)) if 'encoder_phase1' not in param_name],
324 | "lr": self.lr,
325 | "weight_decay": 0.0},
326 | ]
327 |
328 | optimizer = torch.optim.AdamW(optim_groups)
329 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.T_max, eta_min=0.0)
330 | return [optimizer], [scheduler]
331 |
332 | @staticmethod
333 | def add_model_specific_args(parent_parser):
334 | parser = parent_parser.add_argument_group('Refine')
335 | parser.add_argument('--lr', type=float, default=5e-4)
336 | parser.add_argument('--weight_decay', type=float, default=1e-4)
337 | parser.add_argument('--T_max', type=int, default=64)
338 | parser.add_argument('--local_radius', type=int, default=150)
339 | parser.add_argument('--cls_temperture', type=int, default=1)
340 |
341 |
342 | parser.add_argument('--future_steps', type=int, default=30)
343 | parser.add_argument('--num_modes', type=int, default=6)
344 | parser.add_argument('--rotate', type=bool, default=True)
345 | parser.add_argument('--node_dim', type=int, default=2)
346 | parser.add_argument('--edge_dim', type=int, default=2)
347 | parser.add_argument('--embed_dim', type=int, required=True)
348 | parser.add_argument('--seg_num', type=int, default=2)
349 | parser.add_argument('--refine_num', type=int, required=True)
350 | parser.add_argument('--refine_radius', type=int, default=-1)
351 | parser.add_argument('--r_lo', type=int, default=2)
352 | parser.add_argument('--r_hi', type=int, default=10)
353 | return parent_parser
354 |
--------------------------------------------------------------------------------
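Note (toy numbers, not from the repository): the quality-score target used in training_step normalizes the best-mode error of each refinement iteration against the best and worst errors observed across all iterations, then clamps the result to [0, 1].

    import torch

    # minimum-over-modes summed L2 error per sample, stacked per refinement iteration
    min_vals = torch.tensor([[12.0, 4.0],   # before refinement
                             [ 8.0, 5.0],   # after iteration 1
                             [ 6.0, 3.0]])  # after iteration 2
    min_val, max_val = min_vals.min(0)[0], min_vals.max(0)[0]

    err = min_vals[1]  # error of the iteration being scored
    target_score = torch.clamp((max_val - err) / ((max_val - min_val) + 1e-6), 0, 1)
    # tensor([0.6667, 0.0000]): 1 = as good as the best iteration, 0 = as bad as the worst
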
/models/target_region.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from models import MLPDecoder, MLPDeltaDecoder, MLPDeltaDecoderPi, MLPDeltaDecoderScore
7 | from models.local_encoder import ALEncoder, ALEncoderWithAo
8 | from itertools import permutations
9 | from utils import TemporalData
10 | from utils import DistanceDropEdge
11 | from torch_geometric.utils import subgraph
12 | from itertools import product
13 | import numpy as np
14 | from utils import init_weights
15 | from torch_geometric.utils import dense_to_sparse
16 | from torch_cluster import radius
17 | from torch_cluster import radius_graph
18 | from torch_geometric.data import Batch
19 |
20 |
21 | class TargetRegion(nn.Module):
22 |
23 | def __init__(self,
24 | future_steps: int,
25 | num_modes: int,
26 | node_dim: int,
27 | edge_dim: int,
28 | embed_dim: int,
29 | refine_num: int,
30 | seg_num: int,
31 | refine_radius: int,
32 | r_lo: int,
33 | r_hi: int,
34 | **kwargs) -> None:
35 | super(TargetRegion, self).__init__()
36 | self.num_modes = num_modes
37 | self.future_steps = future_steps
38 | self.embed_dim = embed_dim
39 | self.radius = refine_radius
40 | if self.radius == -1:
41 | self._radius = [0.8, 0.8*1/2, 0.8*1/4, 0.8*1/8, 0.8*1/16]
42 | self.refine_num = refine_num
43 | self.seg_num = seg_num
44 | self.r_lo = r_lo
45 | self.r_hi = r_hi
46 |
47 | assert embed_dim == 64
48 |
49 | fc_module = []
50 |         #! HiVT backbone embeddings are 128-dim; project them down to embed_dim
51 | fc_module.append(nn.Linear(128, embed_dim))
52 | self.fc_encoder = nn.Sequential(*fc_module)
53 |
54 | fusion_module = []
55 | for i in range(self.seg_num):
56 | fusion_module.append(ALEncoderWithAo(node_dim=node_dim,
57 | edge_dim=edge_dim,
58 | embed_dim=embed_dim))
59 | self.target_al_encoder = nn.Sequential(*fusion_module)
60 |
61 | dec_module = []
62 | dec_module.append(MLPDeltaDecoder(local_channels = embed_dim,
63 | global_channels = embed_dim,
64 | future_steps = future_steps//self.seg_num, # cut to chunk
65 | num_modes = num_modes,
66 | with_cumsum=0))
67 | self.refine_decoder = nn.Sequential(*dec_module)
68 |
69 | dec_pi_module = []
70 | dec_pi_module.append(MLPDeltaDecoderPi(embed_dim=embed_dim,))
71 | self.refine_pi_decoder = nn.Sequential(*dec_pi_module)
72 |
73 | self.pos_embed = nn.Parameter(torch.zeros(self.refine_num+1, 1, embed_dim))
74 |
75 | score_module = []
76 | score_module.append(nn.GRU(input_size=embed_dim,hidden_size=embed_dim))
77 | score_module.append(MLPDeltaDecoderScore(embed_dim=embed_dim, with_last=False))
78 | self.refine_score_decoder = nn.Sequential(*score_module)
79 |
80 | self.apply(init_weights)
81 |
82 |
83 | def forward(self, data: TemporalData, y_hat, ego_embed):
84 |
85 | y_hat_init = y_hat
86 |
87 | rotate_local_modes = data.rotate_local.repeat(self.num_modes, 1, 1)
88 | data_local_origin_modes = data.positions[data['agent_index'], 19, :].repeat(self.num_modes, 1)
89 |
90 | num_ego = data.agent_index.shape[0]
91 | new_agent_index = torch.arange(data.agent_index.shape[0]*self.num_modes).to(ego_embed.device) # n*f f1 f2 ... fn
92 |
93 | mask_dst = torch.ones((num_ego, self.num_modes)).to(ego_embed.device).bool()
94 | edge_index_m2m = dense_to_sparse(mask_dst.unsqueeze(2) & mask_dst.unsqueeze(1))[0]
95 |
96 | tar_lane_positions = data.tar_lane_positions
97 | tar_lane_vectors = data.tar_lane_vectors
98 | tar_is_intersections = data.tar_is_intersections
99 | tar_turn_directions = data.tar_turn_directions
100 | tar_traffic_controls = data.tar_traffic_controls
101 |
102 | trajs = []
103 | pis = []
104 | scores = []
105 | embeds = []
106 |
107 | ego_embed = self.fc_encoder[0](ego_embed)
108 |
109 | ego_embed = ego_embed.reshape(self.num_modes*num_ego, -1)
110 | score = self.refine_score_decoder[0]((ego_embed.unsqueeze(0)))[0][-1]
111 | score = self.refine_score_decoder[1](score.reshape(self.num_modes, num_ego, -1)+self.pos_embed[:1])
112 | embeds.append(ego_embed.detach())
113 | ego_embed = ego_embed.reshape(self.num_modes, num_ego, -1)
114 | scores.append(score)
115 |
116 | for refine_iter in range(self.refine_num):
117 |
118 | if refine_iter == 0:
119 | y_hat_agent_cord = y_hat_init.clone()
120 | y_hat = torch.bmm(y_hat_init, rotate_local_modes)+data_local_origin_modes.unsqueeze(1)
121 | else:
122 | y_hat_init = y_hat_init + y_hat_delta
123 | y_hat_agent_cord = y_hat_init.clone()
124 | y_hat = torch.bmm(y_hat_init, rotate_local_modes)+data_local_origin_modes.unsqueeze(1)
125 |
126 |             # Argoverse 1 predicts 30 future timesteps
127 | if self.seg_num == 1:
128 | idx = [-1]
129 | elif self.seg_num == 2:
130 | idx = [-16, -1]
131 | elif self.seg_num == 3:
132 | idx = [-21, -11, -1]
133 | elif self.seg_num == 5:
134 | idx = [-25, -19, -13, -7, -1]
135 | elif self.seg_num == 6:
136 | idx = [-26, -21, -16, -11, -6, -1]
137 | else:
138 | assert False
139 |
140 | target_hats = [y_hat[:, id].reshape(self.num_modes, num_ego, -1) for id in idx]
141 |
142 | refine_cum_sum = []
143 |
144 | for tar_id, target_hat in enumerate(target_hats):
145 |
146 | ego_embed = ego_embed.reshape(self.num_modes*num_ego, -1)
147 |
148 | tar_index = []
149 | split_len = 0
150 | tar_lane_actor_vectors = []
151 | for i, tar_lane_point_num in enumerate(data.tar_lane_points_num): # batch
152 | num_point = tar_lane_point_num
153 | index_lo, index_hi = split_len, split_len + num_point
154 | tar_lane_positions_i = tar_lane_positions[index_lo:index_hi]
155 |
156 | tar_lane_actor_vectors_i = \
157 | tar_lane_positions_i.repeat_interleave(self.num_modes, dim=0) - target_hat[:,i].repeat(tar_lane_positions_i.size(0), 1)
158 |
159 | index_this = [i+j*num_ego for j in range(self.num_modes)]
160 | index_i = torch.cartesian_prod(torch.arange(index_lo, index_hi).long().to(ego_embed.device), new_agent_index[index_this].long())
161 |
162 | tar_index.append(index_i)
163 | tar_lane_actor_vectors.append(tar_lane_actor_vectors_i) # p*f
164 |
165 | split_len = index_hi
166 |
167 | tar_lane_actor_index = torch.cat(tar_index).t().contiguous().to(ego_embed.device)
168 |
169 | tar_lane_actor_vectors = torch.cat(tar_lane_actor_vectors).to(ego_embed.device)
170 |
171 | #! use api
172 | # pos_m = data.positions[:,seg_end-1]
173 | # pos_m = y_hat
174 | # num_batch = num_ego
175 | # batch_x = torch.tensor([i for i in range(num_batch)]).to(ego_embed.device)
176 | # batch_x = torch.cat([batch_x for i in range(self.num_modes)], dim=0)
177 | # batch_y = []
178 | # for i, n_p in enumerate(data.tar_lane_points_num):
179 | # batch_y += [i]*n_p
180 | # batch_y = torch.tensor(batch_y, dtype=torch.int64).to(batch_x.device)
181 | # # batch_y = torch.cat([batch_y + i*num_batch for i in range(self.num_modes)], dim=0)
182 |
183 | # lane_positions = data.tar_lane_positions
184 |
185 | # batch_x = batch_x.repeat(30)
186 |
187 | # edge_index_pt2m = radius(
188 | # x=pos_m.transpose(0,1).reshape(-1,2),
189 | # # x=pos_m[:,-1],
190 | # y=lane_positions,
191 | # r=10,
192 | # batch_x=batch_x if isinstance(data, Batch) else None,
193 | # batch_y=batch_y if isinstance(data, Batch) else None,
194 | # max_num_neighbors=300)
195 |
196 | # edge_index_pt2m[1] = edge_index_pt2m[1] % (pos_m.shape[0])
197 | # edge_index_pt2m = torch.unique(edge_index_pt2m, dim=1)
198 | # edge_attr_pt2m = lane_positions[edge_index_pt2m[0]] - pos_m[:,-1][edge_index_pt2m[1]]
199 | # tar_lane_actor_index = edge_index_pt2m
200 | # tar_lane_actor_vectors = edge_attr_pt2m
201 |
202 |
203 | if self.radius == -1:
204 | dis_prefix = torch.cat((torch.zeros(self.num_modes*num_ego, 1, 2).to(ego_embed.device), y_hat_agent_cord[:,:-1]), dim=1)
205 | dis = torch.norm(y_hat_agent_cord-dis_prefix,dim=-1).sum(-1)
206 |
207 | dis = dis*self._radius[refine_iter]
208 |
209 |                 dis[dis < self.r_lo] = self.r_lo
210 |                 dis[dis > self.r_hi] = self.r_hi
211 | dis_this = dis[tar_lane_actor_index[1,:]]
212 | mask = torch.norm(tar_lane_actor_vectors, p=2, dim=-1) < dis_this
213 | else:
214 | mask = torch.norm(tar_lane_actor_vectors, p=2, dim=-1) < self.radius
215 |
216 | tar_lane_actor_index = tar_lane_actor_index[:, mask]
217 | tar_lane_actor_vectors = tar_lane_actor_vectors[mask]
218 |
219 | vec_ao = data_local_origin_modes - target_hat.reshape(self.num_modes*num_ego, -1)
220 |
221 | rotate_mat_ego = data.rotate_mat[data.agent_index]
222 | rotate_mat_ego = rotate_mat_ego.repeat(self.num_modes, 1, 1)
223 |
224 | theta_now = torch.atan2(target_hat.reshape(self.num_modes*num_ego, -1)[..., 1:2] - y_hat[:,idx[tar_id]-1,1:2],
225 | target_hat.reshape(self.num_modes*num_ego, -1)[..., 0:1] - y_hat[:,idx[tar_id]-1,:1])
226 | rotate_mat_tar = torch.cat(
227 | (
228 | torch.cat((torch.cos(theta_now), -torch.sin(theta_now)), -1).unsqueeze(-2),
229 | torch.cat((torch.sin(theta_now), torch.cos(theta_now)), -1).unsqueeze(-2)
230 | ),
231 | -2
232 | )
233 |
234 | rotate_mat_ego = rotate_mat_tar.reshape(self.num_modes*num_ego, 2, 2)
235 |
236 | ego_embed = self.target_al_encoder[tar_id](x=(tar_lane_vectors, ego_embed),
237 | edge_index=tar_lane_actor_index,
238 | edge_attr=tar_lane_actor_vectors,
239 | is_intersections=tar_is_intersections,
240 | turn_directions=tar_turn_directions,
241 | traffic_controls=tar_traffic_controls,
242 | vec_ao=vec_ao,
243 | rotate_mat=rotate_mat_ego)
244 |
245 |
246 | refine_y_hat_delta = self.refine_decoder[0](ego_embed + self.pos_embed[refine_iter+1])
247 | refine_cum_sum.append(refine_y_hat_delta)
248 |
249 | ego_embed = ego_embed.reshape(self.num_modes, num_ego, -1)
250 |
251 | refine_y_hat_delta = torch.cat(refine_cum_sum, dim=-2).view(self.num_modes, num_ego, self.future_steps, 4)
252 |
253 | refine_pi = self.refine_pi_decoder[0](ego_embed + self.pos_embed[refine_iter+1:refine_iter+2])
254 | pis.append(refine_pi)
255 |
256 | ego_embed = ego_embed.reshape(self.num_modes*num_ego, -1)
257 | embeds_before = torch.stack(embeds, 0)
258 | score_input = torch.cat((embeds_before, ego_embed.unsqueeze(0)), 0)
259 | score = self.refine_score_decoder[0](score_input)[0][-1]
260 | score = self.refine_score_decoder[1](score.reshape(self.num_modes, num_ego, -1)+self.pos_embed[refine_iter+1:refine_iter+2])
261 | embeds.append(ego_embed.detach())
262 | ego_embed = ego_embed.reshape(self.num_modes, num_ego, -1)
263 | scores.append(score)
264 |
265 | ego_embed = ego_embed.detach()
266 | y_hat_delta = refine_y_hat_delta.reshape(self.num_modes*num_ego, -1, 4)[...,:2].detach()
267 |
268 | trajs.append(refine_y_hat_delta)
269 |
270 | ret_pis = pis, scores
271 | return trajs, ret_pis
272 |
--------------------------------------------------------------------------------
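Note (toy numbers): when refine_radius is -1, TargetRegion.forward derives a per-trajectory context radius from the predicted path length, shrinks it by a per-iteration factor (self._radius), and clamps it to [r_lo, r_hi]; a small sketch of that computation:

    import torch

    path_len = torch.tensor([3.0, 12.0, 40.0])  # summed step-to-step displacement per (mode, agent)
    decay = [0.8, 0.4, 0.2, 0.1, 0.05]          # self._radius
    r_lo, r_hi = 2.0, 10.0

    for it in range(2):
        r = (path_len * decay[it]).clamp(min=r_lo, max=r_hi)  # same effect as the two clamping lines in forward()
        print(it, r)
    # 0 tensor([ 2.4000,  9.6000, 10.0000])
    # 1 tensor([ 2.0000,  4.8000, 10.0000])
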
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch<2.0.0
2 | pytorch-lightning==1.5.10
3 | torch-geometric>=2.2.0
4 | torch-cluster>=1.6.0
5 | torch-scatter>=2.1.0
6 | torch-sparse>=0.6.16
7 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 |
6 | from datamodules import ArgoverseV1DataModule
7 | from models.refine import Refine
8 |
9 | import torch
10 | import os
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--data_root', type=str, required=True)
14 | parser.add_argument('--p1_root', type=str, required=True)
15 | parser.add_argument('--train_batch_size', type=int, default=4)
16 | parser.add_argument('--val_batch_size', type=int, default=4)
17 | parser.add_argument('--shuffle', type=bool, default=True)
18 | parser.add_argument('--num_workers', type=int, default=8)
19 | parser.add_argument('--seed', type=int, default=2024)
20 | parser.add_argument('--pin_memory', type=bool, default=True)
21 | parser.add_argument('--persistent_workers', type=bool, default=True)
22 | parser.add_argument('--prefetch_factor', type=int, default=4)
23 | parser.add_argument('--max_epochs', type=int, default=64)
24 | parser.add_argument('--monitor', type=str, default='val_minFDE', choices=['val_minADE', 'val_minFDE', 'val_minMR'])
25 | parser.add_argument('--save_top_k', type=int, default=5)
26 | parser.add_argument('--exp_name', type=str, required=True)
27 | parser.add_argument('--gpus', type=int, default=1)
28 | parser = Refine.add_model_specific_args(parser)
29 | args = parser.parse_args()
30 | if args.num_workers == 0:
31 | args.persistent_workers = False
32 | args.accelerator='auto'
33 | if args.gpus > 1:
34 | args.strategy="ddp_find_unused_parameters_false"
35 |
36 | pl.seed_everything(args.seed)
37 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min')
38 | #! base dir for logging
39 | base_dir="./"
40 | trainer = pl.Trainer.from_argparse_args(args, callbacks=[model_checkpoint],
41 | default_root_dir=base_dir+args.exp_name)
42 | model = Refine(**vars(args))
43 | datamodule = ArgoverseV1DataModule.from_argparse_args(args)
44 | trainer.fit(model, datamodule)
45 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | set -x
2 | # change root to your path of dataset root.
3 | data_root=../argo1_data/
4 | # change p1_root to your path of prediction outputs root.
5 | p1_root=../p1_data/
6 | # experiment name used for logging.
7 | exp=smartref_hivt_argo1
8 | # device number.
9 | ngpus=1
10 | pwd
11 |
12 | python train.py \
13 | --data_root $data_root --p1_root $p1_root --exp_name $exp \
14 | --train_batch_size 32 --val_batch_size 32 \
15 | --gpus $ngpus --embed_dim 64 --refine_num 5 --seg_num 2 \
16 | --refine_radius -1 --r_lo 2 --r_hi 10 \
17 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch_geometric.data import Data
7 |
8 |
9 | class TemporalData(Data):
10 |
11 | def __init__(self,
12 | **kwargs) -> None:
13 | super(TemporalData, self).__init__(**kwargs)
14 |
15 | def __inc__(self, key, value):
16 | return super().__inc__(key, value)
17 |
18 |
19 | class DistanceDropEdge(object):
20 |
21 | def __init__(self, max_distance: Optional[float] = None) -> None:
22 | self.max_distance = max_distance
23 |
24 | def __call__(self,
25 | edge_index: torch.Tensor,
26 | edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
27 | if self.max_distance is None:
28 | return edge_index, edge_attr
29 | row, col = edge_index
30 | mask = torch.norm(edge_attr, p=2, dim=-1) < self.max_distance
31 | edge_index = torch.stack([row[mask], col[mask]], dim=0)
32 | edge_attr = edge_attr[mask]
33 | return edge_index, edge_attr
34 |
35 |
36 | def init_weights(m: nn.Module) -> None:
37 | if isinstance(m, nn.Linear):
38 | nn.init.xavier_uniform_(m.weight)
39 | if m.bias is not None:
40 | nn.init.zeros_(m.bias)
41 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
42 | fan_in = m.in_channels / m.groups
43 | fan_out = m.out_channels / m.groups
44 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
45 | nn.init.uniform_(m.weight, -bound, bound)
46 | if m.bias is not None:
47 | nn.init.zeros_(m.bias)
48 | elif isinstance(m, nn.Embedding):
49 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
50 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
51 | nn.init.ones_(m.weight)
52 | nn.init.zeros_(m.bias)
53 | elif isinstance(m, nn.LayerNorm):
54 | nn.init.ones_(m.weight)
55 | nn.init.zeros_(m.bias)
56 | elif isinstance(m, nn.MultiheadAttention):
57 | if m.in_proj_weight is not None:
58 | fan_in = m.embed_dim
59 | fan_out = m.embed_dim
60 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
61 | nn.init.uniform_(m.in_proj_weight, -bound, bound)
62 | else:
63 | nn.init.xavier_uniform_(m.q_proj_weight)
64 | nn.init.xavier_uniform_(m.k_proj_weight)
65 | nn.init.xavier_uniform_(m.v_proj_weight)
66 | if m.in_proj_bias is not None:
67 | nn.init.zeros_(m.in_proj_bias)
68 | nn.init.xavier_uniform_(m.out_proj.weight)
69 | if m.out_proj.bias is not None:
70 | nn.init.zeros_(m.out_proj.bias)
71 | if m.bias_k is not None:
72 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
73 | if m.bias_v is not None:
74 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
75 | elif isinstance(m, nn.LSTM):
76 | for name, param in m.named_parameters():
77 | if 'weight_ih' in name:
78 | for ih in param.chunk(4, 0):
79 | nn.init.xavier_uniform_(ih)
80 | elif 'weight_hh' in name:
81 | for hh in param.chunk(4, 0):
82 | nn.init.orthogonal_(hh)
83 | elif 'weight_hr' in name:
84 | nn.init.xavier_uniform_(param)
85 | elif 'bias_ih' in name:
86 | nn.init.zeros_(param)
87 | elif 'bias_hh' in name:
88 | nn.init.zeros_(param)
89 | nn.init.ones_(param.chunk(4, 0)[1])
90 | elif isinstance(m, nn.GRU):
91 | for name, param in m.named_parameters():
92 | if 'weight_ih' in name:
93 | for ih in param.chunk(3, 0):
94 | nn.init.xavier_uniform_(ih)
95 | elif 'weight_hh' in name:
96 | for hh in param.chunk(3, 0):
97 | nn.init.orthogonal_(hh)
98 | elif 'bias_ih' in name:
99 | nn.init.zeros_(param)
100 | elif 'bias_hh' in name:
101 | nn.init.zeros_(param)
102 |
--------------------------------------------------------------------------------
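Note (illustrative only): DistanceDropEdge keeps only the edges whose relative-position attribute is shorter than max_distance; a minimal sketch with made-up edges:

    import torch
    from utils import DistanceDropEdge

    drop_edge = DistanceDropEdge(max_distance=5.0)
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 2, 0]])
    edge_attr = torch.tensor([[3.0, 0.0],    # length 3.0  -> kept
                              [4.0, 4.0],    # length ~5.7 -> dropped
                              [0.0, 1.0]])   # length 1.0  -> kept
    edge_index, edge_attr = drop_edge(edge_index, edge_attr)
    # edge_index: tensor([[0, 2], [1, 0]]); edge_attr: tensor([[3., 0.], [0., 1.]])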