├── .gitignore ├── LICENSE ├── README.md ├── assets ├── Car TOP_VIEW 375397.png ├── Car TOP_VIEW 80CBE5.png ├── Car TOP_VIEW ROBOT.png ├── intro.gif └── readme.gif ├── configs ├── pgp_gatx2_lvm_traversal.yml ├── pgp_gatx2_mtp.yml ├── preprocess_nuscenes.yml └── xscout_pgp.yml ├── datasets ├── interface.py └── nuScenes │ ├── nuScenes.py │ ├── nuScenes_graphs.py │ ├── nuScenes_raster.py │ └── nuScenes_vector.py ├── evaluate.py ├── metrics ├── covernet_loss.py ├── goal_pred_nll.py ├── metric.py ├── min_ade.py ├── min_fde.py ├── miss_rate.py ├── mtp_loss.py ├── pi_bc.py └── utils.py ├── models ├── aggregators │ ├── aggregator.py │ ├── concat.py │ ├── global_attention.py │ ├── goal_conditioned.py │ └── pgp.py ├── decoders │ ├── covernet.py │ ├── decoder.py │ ├── lvm.py │ ├── mtp.py │ ├── multipath.py │ └── utils.py ├── encoders │ ├── encoder.py │ ├── pgp_encoder.py │ ├── pgp_scout_encoder.py │ ├── polyline_subgraph.py │ ├── raster_encoder.py │ └── scout_encoder.py ├── heterograph_models.py ├── layers.py └── model.py ├── preprocess.py ├── train.py ├── train_eval ├── evaluator.py ├── initialization.py ├── preprocessor.py ├── trainer.py ├── utils.py └── visualizer.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | #.pkl 134 | *.pkl 135 | apollo_train_data.pkl 136 | apollo_test_data.pkl 137 | 138 | 139 | #checkpoints 140 | /models_checkpoints 141 | 142 | 143 | /data 144 | *.pth 145 | *.sh 146 | *workspace 147 | *.json 148 | *.txt 149 | *env.yaml 150 | *.png 151 | 152 | *.tar 153 | *.tar.gz 154 | *.avi 155 | *.zip 156 | 157 | events* 158 | .pyc 159 | .txt 160 | *.gif 161 | wandb/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 nachiket92 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Towards Explainable Multi-modal Motion Prediction using Graph Representations 2 | [![DOI](https://zenodo.org/badge/553454432.svg)](https://zenodo.org/badge/latestdoi/553454432) 3 | 4 | 5 | This repository contains code for ["Towards Explainable Motion Prediction using Heterogeneous Graph Representations"](https://arxiv.org/pdf/2212.03806.pdf) by Sandra Carrasco Limeros, Sylwia Majchrowska, Joakim Johnander, Christoffer Petersson and David Fernández Llorca, 2022. 
6 | 7 | ![](https://github.com/sancarlim/Explainable-MP/blob/main/assets/readme.gif) 8 | 9 | ```bibtex 10 | @misc{Carrasco:22b, 11 | doi = {10.48550/ARXIV.2212.03806}, 12 | url = {https://arxiv.org/abs/2212.03806}, 13 | author = {Carrasco Limeros, Sandra and Majchrowska, Sylwia 14 | and Johnander, Joakim and Petersson, Christoffer 15 | and Llorca, David Fernández}, 16 | title = {Towards Explainable Motion Prediction using Heterogeneous Graph Representations}, 17 | publisher = {arXiv}, 18 | year = {2022} 19 | } 20 | ``` 21 | Note: This repository is built on the [PGP repository](https://github.com/nachiket92/PGP/tree/main/). 22 | 23 | 24 | ## Installation 25 | 26 | 1. Clone this repository 27 | 28 | 2. Set up a new conda environment 29 | ```shell 30 | conda create --name xscout python=3.7.10 31 | ``` 32 | 33 | 3. Install dependencies 34 | ```shell 35 | conda activate xscout 36 | 37 | # nuScenes devkit 38 | pip install nuscenes-devkit 39 | 40 | # PyTorch: The code has been tested with PyTorch 1.7.1, CUDA 10.1, but should work with newer versions 41 | conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch 42 | 43 | # Additional utilities 44 | pip install ray 45 | pip install psutil 46 | pip install scipy 47 | pip install positional-encodings 48 | pip install imageio 49 | pip install tensorboard 50 | pip install dgl-cu101 51 | ``` 52 | 53 | 54 | ## Dataset 55 | 56 | 1. Download the [nuScenes dataset](https://www.nuscenes.org/download). For this project you only need the following: 57 | - Metadata for the Trainval split (v1.0) 58 | - Map expansion pack (v1.3) 59 | 60 | 2. Organize the nuScenes root directory as follows: 61 | ```plain 62 | └── nuScenes/ 63 | ├── maps/ 64 | | ├── basemaps/ 65 | | ├── expansion/ 66 | | ├── prediction/ 67 | | ├── 36092f0b03a857c6a3403e25b4b7aab3.png 68 | | ├── 37819e65e09e5547b8a3ceaefba56bb2.png 69 | | ├── 53992ee3023e5494b90c316c183be829.png 70 | | └── 93406b464a165eaba6d9de76ca09f5da.png 71 | └── v1.0-trainval 72 | ├── attribute.json 73 | ├── calibrated_sensor.json 74 | ... 75 | └── visibility.json 76 | ``` 77 | 78 | 3. Run the following script to extract pre-processed data. This speeds up training significantly. 79 | ```shell 80 | python preprocess.py -c configs/preprocess_nuscenes.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data 81 | ``` 82 | You can download the preprocessed data from [this link](https://drive.google.com/file/d/1Ovf4eX4RtejyhX-hji77MjFjOUwTIdbH/view?usp=sharing). 83 | 84 | 85 | ## Evaluation 86 | 87 | You can download the trained model weights using [this link](https://drive.google.com/file/d/1i9Afa9UhOPAYbjB9nY6D-En0z8HgoEnl/view?usp=sharing). 88 | 89 | To evaluate on the nuScenes val set, run the following script. This will generate a text file with evaluation metrics at the specified output directory. The results should match the [benchmark entry](https://eval.ai/web/challenges/challenge-page/591/leaderboard/1659) on Eval.ai. 90 | ```shell 91 | python evaluate.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights 92 | ``` 93 | 94 | ## Visualization 95 | 96 | To visualize predictions, run the following script. This will generate gifs for a set of instance tokens (track ids) from nuScenes val at the specified output directory.
97 | ```shell 98 | python visualize.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights 99 | ``` 100 | You can indicate the number of modes and the future temporal horizon to visualize with ```--num_modes``` and ```--tf```, respectively. 101 | 102 | 103 | ## Training 104 | 105 | To train the model from scratch, run 106 | ```shell 107 | python train.py -c configs/xscout_pgp.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -n 100 108 | ``` 109 | 110 | The training script will save training checkpoints and tensorboard logs in the output directory. Wandb logging is also supported. You need to specify the entity and project in the ```wandb.init``` call in ```train.py```. If you do not want to log to wandb, use the ```--nowandb``` argument. 111 | 112 | To launch tensorboard, run 113 | ```shell 114 | tensorboard --logdir=path/to/output/directory/tensorboard_logs 115 | ``` 116 | 117 | ## Robustness analysis 118 | 119 | This repository contains the code to reproduce the robustness analysis (Section IV) presented in ["Towards Trustworthy Multi-Modal Motion Prediction: Evaluation and Interpretability"]() by Sandra Carrasco, Sylwia Majchrowska, Joakim Johnander, Christoffer Petersson and David Fernández Llorca, presented at .. 2022. 120 | 121 | You can download the PGP trained model weights using [this link](https://drive.google.com/file/d/1i9Afa9UhOPAYbjB9nY6D-En0z8HgoEnl/view?usp=sharing). 122 | 123 | To evaluate on the nuScenes val set, you can set the probability of randomly masking out dynamic objects and/or lanes via the ```agent_mask_p_veh```, ```agent_mask_p_ped``` and ```lane_mask_p``` arguments in the configuration file ```configs/pgp_gatx2_lvm_traversal.yml```. You can likewise set the probability of masking out random frames of interacting agents with the ```mask_frames_p``` argument, as in the sketch below.
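For example, a minimal sketch of the relevant ```encoder_args``` entries (the probability values below are illustrative, not recommended defaults):
```yaml
encoder_args:
  # ... other encoder parameters unchanged ...
  agent_mask_p_veh: 0.2   # mask each surrounding vehicle with probability 0.2
  agent_mask_p_ped: 0.2   # mask each surrounding pedestrian with probability 0.2
  lane_mask_p: 0.1        # mask each lane node with probability 0.1
  mask_frames_p: 0.1      # mask random past frames of interacting agents
```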
124 | 125 | ```shell 126 | python evaluate.py -c configs/pgp_gatx2_lvm_traversal.yml -r path/to/nuScenes/root/directory -d path/to/directory/with/preprocessed/data -o path/to/output/directory -w path/to/trained/weights 127 | ``` 128 | 129 | -------------------------------------------------------------------------------- /assets/Car TOP_VIEW 375397.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sancarlim/Explainable-MP/6419894aa040adb9570b14493952a98c0a52f803/assets/Car TOP_VIEW 375397.png -------------------------------------------------------------------------------- /assets/Car TOP_VIEW 80CBE5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sancarlim/Explainable-MP/6419894aa040adb9570b14493952a98c0a52f803/assets/Car TOP_VIEW 80CBE5.png -------------------------------------------------------------------------------- /assets/Car TOP_VIEW ROBOT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sancarlim/Explainable-MP/6419894aa040adb9570b14493952a98c0a52f803/assets/Car TOP_VIEW ROBOT.png -------------------------------------------------------------------------------- /assets/intro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sancarlim/Explainable-MP/6419894aa040adb9570b14493952a98c0a52f803/assets/intro.gif -------------------------------------------------------------------------------- /assets/readme.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sancarlim/Explainable-MP/6419894aa040adb9570b14493952a98c0a52f803/assets/readme.gif -------------------------------------------------------------------------------- /configs/pgp_gatx2_lvm_traversal.yml: -------------------------------------------------------------------------------- 1 | # Dataset and dataloader parameters 2 | dataset: 'nuScenes' 3 | version: 'v1.0-trainval' 4 | agent_setting: 'single_agent' 5 | input_representation: 'graphs' 6 | name: 'cluster 5' 7 | 8 | train_set_args: &ds_args 9 | split: 'train' 10 | t_h: 2 11 | t_f: 6 12 | map_extent: [ -50, 50, -20, 80 ] 13 | polyline_resolution: 1 14 | polyline_length: 20 15 | traversal_horizon: 15 16 | random_flips: True 17 | 18 | val_set_args: 19 | <<: *ds_args 20 | split: 'train_val' 21 | random_flips: False 22 | 23 | test_set_args: 24 | <<: *ds_args 25 | split: 'val' 26 | random_flips: False 27 | 28 | batch_size: 64 29 | num_workers: 64 30 | 31 | encoder_type: 'pgp_encoder' 32 | encoder_args: 33 | target_agent_feat_size: 5 34 | target_agent_emb_size: 16 35 | target_agent_enc_size: 32 36 | node_feat_size: 6 37 | node_emb_size: 16 38 | node_enc_size: 32 39 | nbr_feat_size: 5 40 | nbr_emb_size: 16 41 | nbr_enc_size: 32 42 | num_gat_layers: 2 43 | agent_mask_p_veh: 0 44 | agent_mask_p_ped: 0 45 | lane_mask_p: 0. 46 | mask_frames_p: 0. 
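# Note: the four masking probabilities above (agent_mask_p_veh, agent_mask_p_ped,
# lane_mask_p, mask_frames_p) are the robustness-analysis knobs described in the
# README; they default to 0 here, so nothing is masked during standard runs.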
47 | 48 | 49 | # Aggregator parameters 50 | aggregator_type: 'pgp' 51 | aggregator_args: 52 | pre_train: True 53 | target_agent_enc_size: 32 54 | node_enc_size: 32 55 | pi_h1_size: 32 56 | pi_h2_size: 32 57 | horizon: 15 58 | num_samples: 1000 59 | emb_size: 128 60 | num_heads: 32 61 | 62 | 63 | # Decoder parameters 64 | decoder_type: 'lvm' 65 | decoder_args: 66 | num_samples: 1000 67 | op_len: 12 68 | hidden_size: 128 69 | encoding_size: 160 70 | agg_type: 'sample_specific' 71 | lv_dim: 5 72 | num_clusters: 10 73 | 74 | # Optimizer parameters 75 | optim_args: 76 | lr: 0.001 77 | scheduler_step: 10 78 | scheduler_gamma: 0.5 79 | 80 | 81 | losses: ['min_ade_k', 'pi_bc'] 82 | loss_weights: [1, 0.5] 83 | loss_args: 84 | - k: 10 85 | - dummy: 0 86 | 87 | tr_metrics: ['min_ade_k', 'miss_rate_k', 'pi_bc'] 88 | tr_metric_args: 89 | - k: 10 90 | - k: 10 91 | dist_thresh: 2 92 | - dummy: 0 93 | 94 | val_metrics: ['min_ade_k','min_ade_k', 'miss_rate_k', 'miss_rate_k', 'pi_bc'] 95 | val_metric_args: 96 | - k: 5 97 | - k: 10 98 | - k: 5 99 | dist_thresh: 2 100 | - k: 10 101 | dist_thresh: 2 102 | - dummy: 0 103 | 104 | 105 | log_freq: 100 106 | -------------------------------------------------------------------------------- /configs/pgp_gatx2_mtp.yml: -------------------------------------------------------------------------------- 1 | # Dataset and dataloader parameters 2 | dataset: 'nuScenes' 3 | version: 'v1.0-trainval' 4 | agent_setting: 'single_agent' 5 | input_representation: 'graphs' 6 | 7 | train_set_args: &ds_args 8 | split: 'train' 9 | t_h: 2 10 | t_f: 6 11 | map_extent: [ -50, 50, -20, 80 ] 12 | polyline_resolution: 1 13 | polyline_length: 20 14 | traversal_horizon: 15 15 | random_flips: True 16 | 17 | val_set_args: 18 | <<: *ds_args 19 | split: 'train_val' 20 | random_flips: False 21 | 22 | test_set_args: 23 | <<: *ds_args 24 | split: 'val' 25 | random_flips: False 26 | 27 | batch_size: 64 28 | num_workers: 64 29 | 30 | encoder_type: 'pgp_encoder' 31 | encoder_args: 32 | target_agent_feat_size: 5 33 | target_agent_emb_size: 16 34 | target_agent_enc_size: 32 35 | node_feat_size: 6 36 | node_emb_size: 16 37 | node_enc_size: 32 38 | nbr_feat_size: 5 39 | nbr_emb_size: 16 40 | nbr_enc_size: 32 41 | num_gat_layers: 2 42 | 43 | 44 | # Aggregator parameters 45 | aggregator_type: 'pgp' 46 | aggregator_args: 47 | pre_train: True 48 | target_agent_enc_size: 32 49 | node_enc_size: 32 50 | pi_h1_size: 32 51 | pi_h2_size: 32 52 | horizon: 15 53 | num_samples: 1000 54 | emb_size: 128 55 | num_heads: 32 56 | 57 | # Aggregator parameters 58 | aggregator_type: 'global_attention' 59 | aggregator_args: 60 | target_agent_enc_size: 32 61 | context_enc_size: 32 62 | emb_size: 128 63 | num_heads: 32 64 | 65 | 66 | # Decoder parameters 67 | decoder_type: 'mtp' 68 | decoder_args: 69 | num_modes: 5 70 | op_len: 12 71 | use_variance: True 72 | hidden_size: 128 73 | encoding_size: 160 74 | agg_type: 'sample_specific' 75 | lv_dim: 5 76 | num_clusters: 10 77 | 78 | # Optimizer parameters 79 | optim_args: 80 | lr: 0.001 81 | scheduler_step: 10 82 | scheduler_gamma: 0.5 83 | 84 | 85 | losses: ['min_ade_k', 'pi_bc'] 86 | loss_weights: [1, 0.5] 87 | loss_args: 88 | - k: 10 89 | - dummy: 0 90 | 91 | tr_metrics: ['min_ade_k', 'miss_rate_k', 'pi_bc'] 92 | tr_metric_args: 93 | - k: 10 94 | - k: 10 95 | dist_thresh: 2 96 | - dummy: 0 97 | 98 | val_metrics: ['min_ade_k','min_ade_k', 'miss_rate_k', 'miss_rate_k', 'pi_bc'] 99 | val_metric_args: 100 | - k: 5 101 | - k: 10 102 | - k: 5 103 | dist_thresh: 2 104 | - k: 10 105 | 
dist_thresh: 2 106 | - dummy: 0 107 | 108 | 109 | log_freq: 100 110 | -------------------------------------------------------------------------------- /configs/preprocess_nuscenes.yml: -------------------------------------------------------------------------------- 1 | dataset: 'nuScenes' 2 | version: 'v1.0-trainval' 3 | agent_setting: 'single_agent' 4 | input_representation: 'graphs' 5 | 6 | train_set_args: 7 | split: 'train' 8 | t_h: 2 9 | t_f: 6 10 | map_extent: [-50, 50, -20, 80] 11 | polyline_resolution: 1 12 | polyline_length: 20 13 | traversal_horizon: 15 14 | 15 | val_set_args: 16 | split: 'train_val' 17 | t_h: 2 18 | t_f: 6 19 | map_extent: [-50, 50, -20, 80] 20 | polyline_resolution: 1 21 | polyline_length: 20 22 | traversal_horizon: 15 23 | 24 | test_set_args: 25 | split: 'val' 26 | t_h: 2 27 | t_f: 6 28 | map_extent: [-50, 50, -20, 80] 29 | polyline_resolution: 1 30 | polyline_length: 20 31 | traversal_horizon: 15 32 | 33 | batch_size: 64 34 | num_workers: 128 35 | verbosity: True 36 | -------------------------------------------------------------------------------- /configs/xscout_pgp.yml: -------------------------------------------------------------------------------- 1 | # Dataset and dataloader parameters 2 | dataset: 'nuScenes' 3 | version: 'v1.0-trainval' 4 | agent_setting: 'single_agent' 5 | input_representation: 'graphs' 6 | name: 'HGCN 64bs (2l pre-train)' 7 | 8 | train_set_args: &ds_args 9 | split: 'train' 10 | t_h: 2 11 | t_f: 6 12 | map_extent: [ -50, 50, -20, 80 ] 13 | polyline_resolution: 1 14 | polyline_length: 20 15 | traversal_horizon: 15 16 | random_flips: True 17 | 18 | val_set_args: 19 | <<: *ds_args 20 | split: 'train_val' 21 | random_flips: False 22 | 23 | test_set_args: 24 | <<: *ds_args 25 | split: 'val' 26 | random_flips: False 27 | 28 | batch_size: 64 29 | num_workers: 128 30 | 31 | encoder_type: 'pgp_scout_encoder' 32 | encoder_args: 33 | target_agent_feat_size: 5 34 | target_agent_emb_size: 16 35 | target_agent_enc_size: 32 36 | num_heads_lanes: [1, 1, 1, 1, 1] 37 | feat_drop: 0.0 38 | attn_drop: 0.0 39 | num_layers: 2 40 | node_feat_size: 6 41 | node_emb_size: 16 42 | node_enc_size: 32 43 | node_attn_size: 32 44 | node_out_hgt_size: 32 45 | nbr_feat_size: 5 46 | nbr_emb_size: 16 47 | nbr_enc_size: 32 48 | num_gat_layers: 2 49 | num_heads: 1 50 | hg: "hgcn" 51 | agent_mask_p_veh: 0 52 | lane_mask_p: 0. 53 | mask_frames_p: 0. 
54 | 55 | 56 | # Aggregator parameters 57 | aggregator_type: 'pgp' 58 | aggregator_args: 59 | pre_train: False 60 | target_agent_enc_size: 64 61 | node_enc_size: 32 62 | pi_h1_size: 32 63 | pi_h2_size: 32 64 | horizon: 15 65 | num_samples: 1000 66 | emb_size: 128 67 | num_heads: 32 68 | 69 | 70 | # Decoder parameters 71 | decoder_type: 'lvm' 72 | decoder_args: 73 | num_samples: 1000 74 | op_len: 12 75 | hidden_size: 128 76 | encoding_size: 192 77 | agg_type: 'sample_specific' 78 | lv_dim: 5 79 | num_clusters: 10 80 | 81 | # Optimizer parameters 82 | optim_args: 83 | lr: 0.001 84 | scheduler_step: 10 85 | scheduler_gamma: 0.5 86 | 87 | 88 | losses: ['min_ade_k', 'pi_bc'] 89 | loss_weights: [1, 0.5] 90 | loss_args: 91 | - k: 10 92 | - dummy: 0 93 | 94 | tr_metrics: ['min_ade_k', 'miss_rate_k', 'pi_bc'] 95 | tr_metric_args: 96 | - k: 10 97 | - k: 10 98 | dist_thresh: 2 99 | - dummy: 0 100 | 101 | val_metrics: ['min_ade_k','min_ade_k', 'miss_rate_k', 'miss_rate_k', 'pi_bc'] 102 | val_metric_args: 103 | - k: 5 104 | - k: 10 105 | - k: 5 106 | dist_thresh: 2 107 | - k: 10 108 | dist_thresh: 2 109 | - dummy: 0 110 | 111 | 112 | log_freq: 100 113 | -------------------------------------------------------------------------------- /datasets/interface.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch.utils.data as torch_data 3 | import numpy as np 4 | from typing import Union, Dict 5 | import os 6 | 7 | 8 | class TrajectoryDataset(torch_data.Dataset): 9 | """ 10 | Base class for trajectory datasets. 11 | """ 12 | 13 | def __init__(self, mode: str, data_dir: str): 14 | """ 15 | Initialize trajectory dataset. 16 | :param mode: Mode of operation of dataset 17 | :param data_dir: Directory to store extracted pre-processed data 18 | """ 19 | if mode not in ['compute_stats', 'extract_data', 'load_data']: 20 | raise Exception('Dataset mode needs to be one of {compute_stats, extract_data or load_data}') 21 | self.mode = mode 22 | self.data_dir = data_dir 23 | if mode != 'load_data' and not os.path.isdir(self.data_dir): 24 | os.mkdir(self.data_dir) 25 | 26 | @abc.abstractmethod 27 | def __len__(self) -> int: 28 | """ 29 | Returns size of dataset 30 | """ 31 | raise NotImplementedError() 32 | 33 | def __getitem__(self, idx: int) -> Union[Dict, int]: 34 | """ 35 | Get data point, based on mode of operation of dataset. 36 | :param idx: data index 37 | """ 38 | if self.mode == 'compute_stats': 39 | return self.compute_stats(idx) 40 | elif self.mode == 'extract_data': 41 | self.extract_data(idx) 42 | return 0 43 | else: 44 | return self.load_data(idx) 45 | 46 | @abc.abstractmethod 47 | def compute_stats(self, idx: int) -> Dict: 48 | """ 49 | Function to compute dataset statistics like max surrounding agents, max nodes, max edges etc. 50 | :param idx: data index 51 | """ 52 | raise NotImplementedError() 53 | 54 | def extract_data(self, idx: int): 55 | """ 56 | Function to extract data. Bulk of the dataset functionality will be implemented by this method. 57 | :param idx: data index 58 | """ 59 | inputs = self.get_inputs(idx) 60 | ground_truth = self.get_ground_truth(idx) 61 | data = {'inputs': inputs, 'ground_truth': ground_truth} 62 | self.save_data(idx, data) 63 | 64 | @abc.abstractmethod 65 | def get_inputs(self, idx: int) -> Dict: 66 | """ 67 | Extracts model inputs. 
68 | :param idx: data index 69 | """ 70 | raise NotImplementedError() 71 | 72 | @abc.abstractmethod 73 | def get_ground_truth(self, idx: int) -> Dict: 74 | """ 75 | Extracts ground truth 'labels' for training. 76 | :param idx: data index 77 | """ 78 | raise NotImplementedError() 79 | 80 | @abc.abstractmethod 81 | def load_data(self, idx: int) -> Dict: 82 | """ 83 | Function to load extracted data. 84 | :param idx: data index 85 | :return data: Dictionary with pre-processed data 86 | """ 87 | raise NotImplementedError() 88 | 89 | @abc.abstractmethod 90 | def save_data(self, idx: int, data: Dict): 91 | """ 92 | Function to save extracted pre-processed data. 93 | :param idx: data index 94 | :param data: Dictionary with pre-processed data 95 | """ 96 | raise NotImplementedError() 97 | 98 | 99 | class SingleAgentDataset(TrajectoryDataset): 100 | """ 101 | Base class for single agent dataset. While we implicitly model all surrounding agents in the scene, predictions 102 | are made for a single target agent at a time. 103 | """ 104 | 105 | @abc.abstractmethod 106 | def get_map_representation(self, idx: int) -> Union[np.ndarray, Dict]: 107 | """ 108 | Extracts map representation 109 | :param idx: data index 110 | """ 111 | raise NotImplementedError() 112 | 113 | @abc.abstractmethod 114 | def get_surrounding_agent_representation(self, idx: int) -> Union[np.ndarray, Dict]: 115 | """ 116 | Extracts surrounding agent representation 117 | :param idx: data index 118 | """ 119 | raise NotImplementedError() 120 | 121 | @abc.abstractmethod 122 | def get_target_agent_representation(self, idx: int) -> Union[np.ndarray, Dict]: 123 | """ 124 | Extracts target agent representation 125 | :param idx: data index 126 | """ 127 | raise NotImplementedError() 128 | 129 | @abc.abstractmethod 130 | def get_target_agent_future(self, idx: int) -> Union[np.ndarray, Dict]: 131 | """ 132 | Extracts future trajectory for target agent 133 | :param idx: data index 134 | """ 135 | raise NotImplementedError() 136 | -------------------------------------------------------------------------------- /datasets/nuScenes/nuScenes.py: -------------------------------------------------------------------------------- 1 | from datasets.interface import SingleAgentDataset 2 | from nuscenes.eval.prediction.splits import get_prediction_challenge_split 3 | from nuscenes.prediction import PredictHelper 4 | import numpy as np 5 | from typing import Dict, Union 6 | import abc 7 | import os 8 | import pickle 9 | 10 | 11 | class NuScenesTrajectories(SingleAgentDataset): 12 | """ 13 | NuScenes dataset class for single agent prediction 14 | """ 15 | 16 | def __init__(self, mode: str, data_dir: str, args: Dict, helper: PredictHelper): 17 | """ 18 | Initialize predict helper, agent and scene representations 19 | :param mode: Mode of operation of dataset, one of {'compute_stats', 'extract_data', 'load_data'} 20 | :param data_dir: Directory to store extracted pre-processed data 21 | :param helper: NuScenes PredictHelper 22 | :param args: Dataset arguments 23 | """ 24 | super().__init__(mode, data_dir) 25 | self.helper = helper 26 | 27 | # nuScenes sample and instance tokens for prediction challenge 28 | self.token_list = get_prediction_challenge_split(args['split'], dataroot=helper.data.dataroot) 29 | 30 | # Past and prediction horizons 31 | self.t_h = args['t_h'] 32 | self.t_f = args['t_f'] 33 | 34 | def __len__(self): 35 | """ 36 | Size of dataset 37 | """ 38 | return len(self.token_list) 39 | 40 | def get_inputs(self, idx: int) -> Dict: 41 | """ 42 | 
Gets model inputs for nuScenes single agent prediction 43 | :param idx: data index 44 | :return inputs: Dictionary with input representations 45 | """ 46 | i_t, s_t = self.token_list[idx].split("_") 47 | map_representation = self.get_map_representation(idx) 48 | surrounding_agent_representation = self.get_surrounding_agent_representation(idx) 49 | target_agent_representation = self.get_target_agent_representation(idx) 50 | inputs = {'instance_token': i_t, 51 | 'sample_token': s_t, 52 | 'map_representation': map_representation, 53 | 'surrounding_agent_representation': surrounding_agent_representation, 54 | 'target_agent_representation': target_agent_representation} 55 | return inputs 56 | 57 | def get_ground_truth(self, idx: int) -> Dict: 58 | """ 59 | Gets ground truth labels for nuScenes single agent prediction 60 | :param idx: data index 61 | :return ground_truth: Dictionary with ground truth labels 62 | """ 63 | target_agent_future = self.get_target_agent_future(idx) 64 | ground_truth = {'traj': target_agent_future} 65 | return ground_truth 66 | 67 | def save_data(self, idx: int, data: Dict): 68 | """ 69 | Saves extracted pre-processed data 70 | :param idx: data index 71 | :param data: pre-processed data 72 | """ 73 | filename = os.path.join(self.data_dir, self.token_list[idx] + '.pickle') 74 | with open(filename, 'wb') as handle: 75 | pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL) 76 | 77 | def load_data(self, idx: int) -> Dict: 78 | """ 79 | Function to load extracted data. 80 | :param idx: data index 81 | :return data: Dictionary with batched tensors 82 | """ 83 | filename = os.path.join(self.data_dir, self.token_list[idx] + '.pickle') 84 | 85 | if not os.path.isfile(filename): 86 | raise Exception('Could not find data. Please run the dataset in extract_data mode') 87 | 88 | with open(filename, 'rb') as handle: 89 | data = pickle.load(handle) 90 | return data 91 | 92 | def get_target_agent_future(self, idx: int) -> np.ndarray: 93 | """ 94 | Extracts future trajectory for target agent 95 | :param idx: data index 96 | :return fut: future trajectory for target agent, shape: [t_f * 2, 2] 97 | """ 98 | i_t, s_t = self.token_list[idx].split("_") 99 | fut = self.helper.get_future_for_agent(i_t, s_t, seconds=self.t_f, in_agent_frame=True) 100 | 101 | return fut 102 | 103 | @abc.abstractmethod 104 | def get_target_agent_representation(self, idx: int) -> Union[np.ndarray, Dict]: 105 | """ 106 | Extracts target agent representation 107 | :param idx: data index 108 | """ 109 | raise NotImplementedError() 110 | 111 | @abc.abstractmethod 112 | def get_map_representation(self, idx: int) -> Union[np.ndarray, Dict]: 113 | """ 114 | Extracts map representation 115 | :param idx: data index 116 | """ 117 | raise NotImplementedError() 118 | 119 | @abc.abstractmethod 120 | def get_surrounding_agent_representation(self, idx: int) -> Union[np.ndarray, Dict]: 121 | """ 122 | Extracts surrounding agent representation 123 | :param idx: data index 124 | """ 125 | raise NotImplementedError() 126 | -------------------------------------------------------------------------------- /datasets/nuScenes/nuScenes_graphs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from datasets.nuScenes.nuScenes_vector import NuScenesVector 3 | from nuscenes.prediction.input_representation.static_layers import correct_yaw, get_lanes_for_agent, color_by_yaw # color_by_yaw is used by visualize_graph below 4 | from nuscenes.map_expansion.map_api import NuScenesMap 5 | from nuscenes.prediction import
PredictHelper 6 | import numpy as np 7 | from typing import Dict, Tuple, Union, List 8 | from scipy.spatial.distance import cdist 9 | 10 | 11 | class NuScenesGraphs(NuScenesVector): 12 | """ 13 | NuScenes dataset class for single agent prediction, using the graph representation from PGP for maps and agents 14 | """ 15 | 16 | def __init__(self, mode: str, data_dir: str, args: Dict, helper: PredictHelper): 17 | """ 18 | Initialize predict helper, agent and scene representations 19 | :param mode: Mode of operation of dataset, one of {'compute_stats', 'extract_data', 'load_data'} 20 | :param data_dir: Directory to store extracted pre-processed data 21 | :param helper: NuScenes PredictHelper 22 | :param args: Dataset arguments 23 | """ 24 | super().__init__(mode, data_dir, args, helper) 25 | self.traversal_horizon = args['traversal_horizon'] 26 | 27 | # Load dataset stats (max nodes, max agents etc.) 28 | if self.mode == 'extract_data': 29 | stats = self.load_stats() 30 | self.max_nbr_nodes = stats['max_nbr_nodes'] 31 | 32 | def compute_stats(self, idx: int) -> Dict[str, int]: 33 | """ 34 | Function to compute statistics for a given data point 35 | """ 36 | num_lane_nodes, max_nbr_nodes = self.get_map_representation(idx) 37 | num_vehicles, num_pedestrians = self.get_surrounding_agent_representation(idx) 38 | stats = { 39 | 'num_lane_nodes': num_lane_nodes, 40 | 'max_nbr_nodes': max_nbr_nodes, 41 | 'num_vehicles': num_vehicles, 42 | 'num_pedestrians': num_pedestrians 43 | } 44 | 45 | return stats 46 | 47 | def extract_data(self, idx: int): 48 | """ 49 | Function to extract data. Bulk of the dataset functionality will be implemented by this method. 50 | :param idx: data index 51 | """ 52 | inputs = self.get_inputs(idx) 53 | ground_truth = self.get_ground_truth(idx) 54 | node_seq_gt, evf_gt = self.get_visited_edges(idx, inputs['map_representation']) 55 | init_node = self.get_initial_node(inputs['map_representation']) 56 | 57 | ground_truth['evf_gt'] = evf_gt 58 | inputs['init_node'] = init_node 59 | inputs['node_seq_gt'] = node_seq_gt # For pretraining with ground truth node sequence 60 | data = {'inputs': inputs, 'ground_truth': ground_truth} 61 | self.save_data(idx, data) 62 | 63 | def get_inputs(self, idx: int) -> Dict: 64 | inputs = super().get_inputs(idx) 65 | a_n_masks = self.get_agent_node_masks(inputs['map_representation'], inputs['surrounding_agent_representation']) 66 | inputs['agent_node_masks'] = a_n_masks 67 | return inputs 68 | 69 | def get_ground_truth(self, idx: int) -> Dict: 70 | ground_truth = super().get_ground_truth(idx) 71 | return ground_truth 72 | 73 | 74 | def get_map_representation(self, idx: int) -> Union[Tuple[int, int], Dict]: 75 | """ 76 | Extracts map representation 77 | :param idx: data index 78 | :return: In 'compute_stats' mode, returns the number of lane nodes and the maximum number of neighbours per node. 79 | Otherwise returns a dict with lane node features, shape [max_nodes, polyline_length, feat_size], masks of the same shape (value 1 where nodes/poses are empty), and edge look-up tables/adjacency matrices. 80 | """ 81 | i_t, s_t = self.token_list[idx].split("_") 82 | map_name = self.helper.get_map_name_from_sample_token(s_t) 83 | map_api = self.maps[map_name] 84 | 85 | # Get agent representation in global co-ordinates 86 | global_pose = self.get_target_agent_global_pose(idx) 87 | 88 | # Get path candidates 89 | # paths_ids = {i_t: self.get_lanes_for_agent(global_pose[0],global_pose[1],global_pose[2],map_api) } 90 | # paths_vectors = self.get_path(paths_ids[i_t], map_api) 91 | # paths_ids, paths_vectors = self.split_lanes(paths_vectors, self.polyline_length, paths_ids) 92 | 93 | # Get lanes around
agent within map_extent 94 | lanes = self.get_lanes_around_agent(global_pose, map_api) 95 | 96 | # Get relevant polygon layers from the map_api 97 | polygons = self.get_polygons_around_agent(global_pose, map_api) 98 | 99 | # Get vectorized representation of lanes 100 | lane_node_feats, lane_ids = self.get_lane_node_feats(global_pose, lanes, polygons, map_api) 101 | 102 | # Discard lanes outside map extent 103 | lane_node_feats, lane_ids = self.discard_poses_outside_extent(lane_node_feats, lane_ids) 104 | 105 | # Get edges: 106 | e_succ = self.get_successor_edges(lane_ids, map_api) 107 | e_prox = self.get_proximal_edges(lane_node_feats, e_succ) 108 | 109 | # Concatenate flag indicating whether a node has successors to lane node feats 110 | lane_node_feats = self.add_boundary_flag(e_succ, lane_node_feats) 111 | 112 | # Add a dummy node (all zeros) if no lane nodes are found 113 | if len(lane_node_feats) == 0: 114 | lane_node_feats = [np.zeros((1, 8))] 115 | e_succ = [[]] 116 | e_prox = [[]] 117 | 118 | # While running the dataset class in 'compute_stats' mode: 119 | if self.mode == 'compute_stats': 120 | 121 | num_nbrs = [len(e_succ[i]) + len(e_prox[i]) for i in range(len(e_succ))] 122 | max_nbrs = max(num_nbrs) if len(num_nbrs) > 0 else 0 123 | num_nodes = len(lane_node_feats) 124 | 125 | return num_nodes, max_nbrs 126 | 127 | # Get edge lookup tables 128 | s_next, edge_type = self.get_edge_lookup(e_succ, e_prox) 129 | 130 | # Build adjacency matrix for heterograph - treating succ and prox edges separately and directional 131 | succ_adj_matrix, prox_adj_matrix = self.build_adj_mat_directional_with_types(s_next, edge_type) 132 | 133 | # Convert list of lane node feats to fixed size numpy array and masks 134 | lane_node_feats, lane_node_masks = self.list_to_tensor(lane_node_feats, self.max_nodes, self.polyline_length, 8) 135 | 136 | map_representation = { 137 | 'lane_node_feats': lane_node_feats, 138 | 'lane_node_masks': lane_node_masks, 139 | 's_next': s_next, 140 | 'edge_type': edge_type, 141 | 'succ_adj_matrix': succ_adj_matrix, 142 | 'prox_adj_matrix': prox_adj_matrix 143 | } 144 | 145 | return map_representation 146 | 147 | @staticmethod 148 | def build_adj_mat_directional_with_types(edges, edges_type): 149 | # Given edges [number of nodes, edges per node], build directed adjacency matrices [number of nodes, number of nodes], one for successor and one for proximal edges 150 | edges_succ_adj = np.zeros((edges.shape[0], edges.shape[0])) 151 | edges_prox_adj = np.zeros((edges.shape[0], edges.shape[0])) 152 | succ_u = np.array([]) 153 | succ_v = np.array([]) 154 | prox_u = np.array([]) 155 | prox_v = np.array([]) 156 | for i in range(edges.shape[0]): 157 | for j in range(edges.shape[1]-1): # last column holds the terminal edge 158 | if edges_type[i,j] == 1: 159 | succ_u = np.append(succ_u, i) 160 | succ_v = np.append(succ_v, j) 161 | edges_succ_adj[i,int(edges[i,j])] = 1 162 | elif edges_type[i,j] == 2: 163 | prox_u = np.append(prox_u, i) 164 | prox_v = np.append(prox_v, j) 165 | edges_prox_adj[i,int(edges[i,j])] = 1 166 | if np.count_nonzero(edges_succ_adj[i]) == 0 and edges_type[i, -1] == 3: # add a self-loop for boundary nodes with no successors 167 | succ_u = np.append(succ_u, i) 168 | succ_v = np.append(succ_v, i) 169 | edges_succ_adj[i,i] = 1 170 | return edges_succ_adj, edges_prox_adj 171 | 172 | @staticmethod 173 | def get_successor_edges(lane_ids: List[str], map_api: NuScenesMap) -> List[List[int]]: 174 | """ 175 | Returns successor edge list for each node 176 | """ 177 | e_succ = [] 178 | for node_id, lane_id in enumerate(lane_ids): 179 | e_succ_node = [] 180 | if node_id + 1 < len(lane_ids) and lane_id ==
lane_ids[node_id + 1]: 181 | e_succ_node.append(node_id + 1) 182 | else: 183 | outgoing_lane_ids = map_api.get_outgoing_lane_ids(lane_id) 184 | for outgoing_id in outgoing_lane_ids: 185 | if outgoing_id in lane_ids: 186 | e_succ_node.append(lane_ids.index(outgoing_id)) 187 | 188 | e_succ.append(e_succ_node) 189 | 190 | return e_succ 191 | 192 | @staticmethod 193 | def get_proximal_edges(lane_node_feats: List[np.ndarray], e_succ: List[List[int]], 194 | dist_thresh=4, yaw_thresh=np.pi/4) -> List[List[int]]: 195 | """ 196 | Returns proximal edge list for each node 197 | """ 198 | e_prox = [[] for _ in lane_node_feats] 199 | for src_node_id, src_node_feats in enumerate(lane_node_feats): 200 | for dest_node_id in range(src_node_id + 1, len(lane_node_feats)): 201 | if dest_node_id not in e_succ[src_node_id] and src_node_id not in e_succ[dest_node_id]: 202 | dest_node_feats = lane_node_feats[dest_node_id] 203 | pairwise_dist = cdist(src_node_feats[:, :2], dest_node_feats[:, :2]) 204 | min_dist = np.min(pairwise_dist) 205 | if min_dist <= dist_thresh: 206 | yaw_src = np.arctan2(np.mean(np.sin(src_node_feats[:, 2])), 207 | np.mean(np.cos(src_node_feats[:, 2]))) 208 | yaw_dest = np.arctan2(np.mean(np.sin(dest_node_feats[:, 2])), 209 | np.mean(np.cos(dest_node_feats[:, 2]))) 210 | yaw_diff = np.arctan2(np.sin(yaw_src-yaw_dest), np.cos(yaw_src-yaw_dest)) 211 | if np.absolute(yaw_diff) <= yaw_thresh: 212 | e_prox[src_node_id].append(dest_node_id) 213 | e_prox[dest_node_id].append(src_node_id) 214 | 215 | return e_prox 216 | 217 | @staticmethod 218 | def add_boundary_flag(e_succ: List[List[int]], lane_node_feats: List[np.ndarray]): 219 | """ 220 | Adds a binary flag to lane node features indicating whether the lane node has any successors. 221 | Serves as an indicator for boundary nodes. 222 | """ 223 | for n, lane_node_feat_array in enumerate(lane_node_feats): 224 | flag = 1 if len(e_succ[n]) == 0 else 0 225 | lane_node_feats[n] = np.concatenate((lane_node_feat_array, flag * np.ones((len(lane_node_feat_array), 1))), 226 | axis=1) 227 | 228 | return lane_node_feats 229 | 230 | def get_edge_lookup(self, e_succ: List[List[int]], e_prox: List[List[int]]): 231 | """ 232 | Returns edge look-up tables 233 | :param e_succ: Lists of successor edges for each node 234 | :param e_prox: Lists of proximal edges for each node 235 | :return: 236 | 237 | s_next: Look-up table mapping source node to destination node for each edge. Each row corresponds to 238 | a source node, with entries corresponding to destination nodes. The last entry is always a terminal edge to a goal 239 | state at that node. shape: [max_nodes, max_nbr_nodes + 1]. 240 | 241 | edge_type: Look-up table of the same shape as s_next containing integer values for edge types.
242 | {0: No edge exists, 1: successor edge, 2: proximal edge, 3: terminal edge} 243 | """ 244 | 245 | s_next = np.zeros((self.max_nodes, self.max_nbr_nodes + 1)) 246 | edge_type = np.zeros((self.max_nodes, self.max_nbr_nodes + 1), dtype=int) 247 | 248 | for src_node in range(len(e_succ)): 249 | nbr_idx = 0 250 | successors = e_succ[src_node] 251 | prox_nodes = e_prox[src_node] 252 | 253 | # Populate successor edges 254 | for successor in successors: 255 | s_next[src_node, nbr_idx] = successor 256 | edge_type[src_node, nbr_idx] = 1 257 | nbr_idx += 1 258 | 259 | # Populate proximal edges 260 | for prox_node in prox_nodes: 261 | s_next[src_node, nbr_idx] = prox_node 262 | edge_type[src_node, nbr_idx] = 2 263 | nbr_idx += 1 264 | 265 | # Populate terminal edge 266 | s_next[src_node, -1] = src_node + self.max_nodes 267 | edge_type[src_node, -1] = 3 268 | 269 | return s_next, edge_type 270 | 271 | def get_initial_node(self, lane_graph: Dict) -> np.ndarray: 272 | """ 273 | Returns initial node probabilities for initializing the graph traversal policy 274 | :param lane_graph: lane graph dictionary with lane node features and edge look-up tables 275 | """ 276 | 277 | # Unpack lane node poses 278 | node_feats = lane_graph['lane_node_feats'] 279 | node_feat_lens = np.sum(1 - lane_graph['lane_node_masks'][:, :, 0], axis=1) 280 | node_poses = [] 281 | for i, node_feat in enumerate(node_feats): 282 | if node_feat_lens[i] != 0: 283 | node_poses.append(node_feat[:int(node_feat_lens[i]), :3]) 284 | 285 | assigned_nodes = self.assign_pose_to_node(node_poses, np.asarray([0, 0, 0]), dist_thresh=3, 286 | yaw_thresh=np.pi / 4, return_multiple=True) 287 | 288 | init_node = np.zeros(self.max_nodes) 289 | init_node[assigned_nodes] = 1/len(assigned_nodes) 290 | return init_node 291 | 292 | def get_visited_edges(self, idx: int, lane_graph: Dict) -> Tuple[np.ndarray, np.ndarray]: 293 | """ 294 | Returns nodes and edges of the lane graph visited by the actual target vehicle in the future. This serves as 295 | ground truth for training the graph traversal policy pi_route. 296 | 297 | :param idx: dataset index 298 | :param lane_graph: lane graph dictionary with lane node features and edge look-up tables 299 | :return: node_seq: Sequence of visited node ids. 300 | evf: Look-up table of visited edges. 
301 | """ 302 | 303 | # Unpack lane graph dictionary 304 | node_feats = lane_graph['lane_node_feats'] 305 | s_next = lane_graph['s_next'] 306 | edge_type = lane_graph['edge_type'] 307 | 308 | node_feat_lens = np.sum(1 - lane_graph['lane_node_masks'][:, :, 0], axis=1) 309 | node_poses = [] 310 | for i, node_feat in enumerate(node_feats): 311 | if node_feat_lens[i] != 0: 312 | node_poses.append(node_feat[:int(node_feat_lens[i]), :3]) 313 | 314 | # Initialize outputs 315 | current_step = 0 316 | node_seq = np.zeros(self.traversal_horizon) 317 | evf = np.zeros_like(s_next) 318 | 319 | # Get future trajectory: 320 | i_t, s_t = self.token_list[idx].split("_") 321 | fut_xy = self.helper.get_future_for_agent(i_t, s_t, 6, True) 322 | fut_interpolated = np.zeros((fut_xy.shape[0] * 10 + 1, 2)) 323 | param_query = np.linspace(0, fut_xy.shape[0], fut_xy.shape[0] * 10 + 1) 324 | param_given = np.linspace(0, fut_xy.shape[0], fut_xy.shape[0] + 1) 325 | val_given_x = np.concatenate(([0], fut_xy[:, 0])) 326 | val_given_y = np.concatenate(([0], fut_xy[:, 1])) 327 | fut_interpolated[:, 0] = np.interp(param_query, param_given, val_given_x) 328 | fut_interpolated[:, 1] = np.interp(param_query, param_given, val_given_y) 329 | fut_xy = fut_interpolated 330 | 331 | # Compute yaw values for future: 332 | fut_yaw = np.zeros(len(fut_xy)) 333 | for n in range(1, len(fut_yaw)): 334 | fut_yaw[n] = -np.arctan2(fut_xy[n, 0] - fut_xy[n-1, 0], fut_xy[n, 1] - fut_xy[n-1, 1]) 335 | 336 | # Loop over future trajectory poses 337 | query_pose = np.asarray([fut_xy[0, 0], fut_xy[0, 1], fut_yaw[0]]) 338 | current_node = self.assign_pose_to_node(node_poses, query_pose) 339 | node_seq[current_step] = current_node 340 | for n in range(1, len(fut_xy)): 341 | query_pose = np.asarray([fut_xy[n, 0], fut_xy[n, 1], fut_yaw[n]]) 342 | dist_from_current_node = np.min(np.linalg.norm(node_poses[current_node][:, :2] - query_pose[:2], axis=1)) 343 | 344 | # If pose has deviated sufficiently from current node and is within area of interest, assign to a new node 345 | padding = self.polyline_length * self.polyline_resolution / 2 346 | if self.map_extent[0] - padding <= query_pose[0] <= self.map_extent[1] + padding and \ 347 | self.map_extent[2] - padding <= query_pose[1] <= self.map_extent[3] + padding: 348 | 349 | if dist_from_current_node >= 1.5: 350 | assigned_node = self.assign_pose_to_node(node_poses, query_pose) 351 | 352 | # Assign new node to node sequence and edge to visited edges 353 | if assigned_node != current_node: 354 | 355 | if assigned_node in s_next[current_node]: 356 | nbr_idx = np.where(s_next[current_node] == assigned_node)[0] 357 | nbr_valid = np.where(edge_type[current_node] > 0)[0] 358 | nbr_idx = np.intersect1d(nbr_idx, nbr_valid) 359 | 360 | if edge_type[current_node, nbr_idx] > 0: 361 | evf[current_node, nbr_idx] = 1 362 | 363 | current_node = assigned_node 364 | if current_step < self.traversal_horizon-1: 365 | current_step += 1 366 | node_seq[current_step] = current_node 367 | 368 | else: 369 | break 370 | 371 | # Assign goal node and edge 372 | goal_node = current_node + self.max_nodes 373 | node_seq[current_step + 1:] = goal_node 374 | evf[current_node, -1] = 1 375 | 376 | return node_seq, evf 377 | 378 | @staticmethod 379 | def assign_pose_to_node(node_poses, query_pose, dist_thresh=5, yaw_thresh=np.pi/3, return_multiple=False): 380 | """ 381 | Assigns a given agent pose to a lane node. Takes into account distance from the lane centerline as well as 382 | direction of motion. 
383 | """ 384 | dist_vals = [] 385 | yaw_diffs = [] 386 | 387 | for i in range(len(node_poses)): 388 | distances = np.linalg.norm(node_poses[i][:, :2] - query_pose[:2], axis=1) 389 | dist_vals.append(np.min(distances)) 390 | idx = np.argmin(distances) 391 | yaw_lane = node_poses[i][idx, 2] 392 | yaw_query = query_pose[2] 393 | yaw_diffs.append(np.arctan2(np.sin(yaw_lane - yaw_query), np.cos(yaw_lane - yaw_query))) 394 | 395 | idcs_yaw = np.where(np.absolute(np.asarray(yaw_diffs)) <= yaw_thresh)[0] 396 | idcs_dist = np.where(np.asarray(dist_vals) <= dist_thresh)[0] 397 | idcs = np.intersect1d(idcs_dist, idcs_yaw) 398 | 399 | if len(idcs) > 0: 400 | if return_multiple: 401 | return idcs 402 | assigned_node_id = idcs[int(np.argmin(np.asarray(dist_vals)[idcs]))] 403 | else: 404 | assigned_node_id = np.argmin(np.asarray(dist_vals)) 405 | if return_multiple: 406 | assigned_node_id = np.asarray([assigned_node_id]) 407 | 408 | return assigned_node_id 409 | 410 | @staticmethod 411 | def get_agent_node_masks(hd_map: Dict, agents: Dict, dist_thresh=10) -> Dict: 412 | """ 413 | Returns key/val masks for agent-node attention layers. All agents except those within a distance threshold of 414 | the lane node are masked. The idea is to incorporate local agent context at each lane node. 415 | """ 416 | 417 | lane_node_feats = hd_map['lane_node_feats'] 418 | lane_node_masks = hd_map['lane_node_masks'] 419 | vehicle_feats = agents['vehicles'] 420 | vehicle_masks = agents['vehicle_masks'] 421 | ped_feats = agents['pedestrians'] 422 | ped_masks = agents['pedestrian_masks'] 423 | 424 | vehicle_node_masks = np.ones((len(lane_node_feats), len(vehicle_feats))) 425 | ped_node_masks = np.ones((len(lane_node_feats), len(ped_feats))) 426 | 427 | for i, node_feat in enumerate(lane_node_feats): 428 | if (lane_node_masks[i] == 0).any(): 429 | node_pose_idcs = np.where(lane_node_masks[i][:, 0] == 0)[0] 430 | node_locs = node_feat[node_pose_idcs, :2] 431 | 432 | for j, vehicle_feat in enumerate(vehicle_feats): 433 | if (vehicle_masks[j] == 0).any(): 434 | vehicle_loc = vehicle_feat[-1, :2] 435 | dist = np.min(np.linalg.norm(node_locs - vehicle_loc, axis=1)) 436 | if dist <= dist_thresh: 437 | vehicle_node_masks[i, j] = 0 438 | 439 | for j, ped_feat in enumerate(ped_feats): 440 | if (ped_masks[j] == 0).any(): 441 | ped_loc = ped_feat[-1, :2] 442 | dist = np.min(np.linalg.norm(node_locs - ped_loc, axis=1)) 443 | if dist <= dist_thresh: 444 | ped_node_masks[i, j] = 0 445 | 446 | agent_node_masks = {'vehicles': vehicle_node_masks, 'pedestrians': ped_node_masks} 447 | return agent_node_masks 448 | 449 | def visualize_graph(self, node_feats, s_next, edge_type, evf_gt, node_seq, fut_xy): 450 | """ 451 | Function to visualize lane graph. 
452 | """ 453 | fig, ax = plt.subplots() 454 | ax.imshow(np.zeros((3, 3)), extent=self.map_extent, cmap='gist_gray') 455 | 456 | # Plot edges 457 | for src_id, src_feats in enumerate(node_feats): 458 | feat_len = np.sum(np.sum(np.absolute(src_feats), axis=1) != 0) 459 | 460 | if feat_len > 0: 461 | src_x = np.mean(src_feats[:feat_len, 0]) 462 | src_y = np.mean(src_feats[:feat_len, 1]) 463 | 464 | for idx, dest_id in enumerate(s_next[src_id]): 465 | edge_t = edge_type[src_id, idx] 466 | visited = evf_gt[src_id, idx] 467 | if 3 > edge_t > 0: 468 | 469 | dest_feats = node_feats[int(dest_id)] 470 | feat_len_dest = np.sum(np.sum(np.absolute(dest_feats), axis=1) != 0) 471 | dest_x = np.mean(dest_feats[:feat_len_dest, 0]) 472 | dest_y = np.mean(dest_feats[:feat_len_dest, 1]) 473 | d_x = dest_x - src_x 474 | d_y = dest_y - src_y 475 | 476 | line_style = '-' if edge_t == 1 else '--' 477 | width = 2 if visited else 0.01 478 | alpha = 1 if visited else 0.5 479 | 480 | plt.arrow(src_x, src_y, d_x, d_y, color='w', head_width=0.1, length_includes_head=True, 481 | linestyle=line_style, width=width, alpha=alpha) 482 | 483 | # Plot nodes 484 | for node_id, node_feat in enumerate(node_feats): 485 | feat_len = np.sum(np.sum(np.absolute(node_feat), axis=1) != 0) 486 | if feat_len > 0: 487 | visited = node_id in node_seq 488 | x = np.mean(node_feat[:feat_len, 0]) 489 | y = np.mean(node_feat[:feat_len, 1]) 490 | yaw = np.arctan2(np.mean(np.sin(node_feat[:feat_len, 2])), 491 | np.mean(np.cos(node_feat[:feat_len, 2]))) 492 | c = color_by_yaw(0, yaw) 493 | c = np.asarray(c).reshape(-1, 3) / 255 494 | s = 200 if visited else 50 495 | ax.scatter(x, y, s, c=c) 496 | 497 | plt.plot(fut_xy[:, 0], fut_xy[:, 1], color='r', lw=3) 498 | 499 | plt.show() 500 | -------------------------------------------------------------------------------- /datasets/nuScenes/nuScenes_raster.py: -------------------------------------------------------------------------------- 1 | from datasets.nuScenes.nuScenes import NuScenesTrajectories 2 | from nuscenes.prediction.input_representation.static_layers import StaticLayerRasterizer 3 | from nuscenes.prediction.input_representation.agents import AgentBoxesWithFadedHistory 4 | from nuscenes.prediction import PredictHelper 5 | import numpy as np 6 | from typing import Dict 7 | 8 | 9 | class NuScenesRaster(NuScenesTrajectories): 10 | """ 11 | NuScenes dataset class for single agent prediction, using the raster representation for maps and agents 12 | """ 13 | 14 | def __init__(self, mode: str, data_dir: str, args: Dict, helper: PredictHelper): 15 | """ 16 | Initialize predict helper, agent and scene representations 17 | :param mode: Mode of operation of dataset, one of {'compute_stats', 'extract_data', 'load_data'} 18 | :param data_dir: Directory to store extracted pre-processed data 19 | :param helper: NuScenes PredictHelper 20 | :param args: Dataset arguments 21 | """ 22 | super().__init__(mode, data_dir, args, helper) 23 | 24 | # Raster parameters 25 | self.img_size = args['img_size'] 26 | self.map_extent = args['map_extent'] 27 | 28 | # Raster map with agent boxes 29 | resolution = (self.map_extent[1] - self.map_extent[0]) / self. 
img_size[1] 30 | self.map_rasterizer = StaticLayerRasterizer(self.helper, 31 | resolution=resolution, 32 | meters_ahead=self.map_extent[3], 33 | meters_behind=-self.map_extent[2], 34 | meters_left=-self.map_extent[0], 35 | meters_right=self.map_extent[1]) 36 | 37 | self.agent_rasterizer = AgentBoxesWithFadedHistory(self.helper, seconds_of_history=self.t_h, 38 | resolution=resolution, 39 | meters_ahead=self.map_extent[3], 40 | meters_behind=-self.map_extent[2], 41 | meters_left=-self.map_extent[0], 42 | meters_right=self.map_extent[1]) 43 | 44 | def compute_stats(self, idx: int): 45 | """ 46 | Function to compute dataset statistics. Nothing to compute 47 | """ 48 | return {} 49 | 50 | def get_target_agent_representation(self, idx: int) -> np.ndarray: 51 | """ 52 | Extracts target agent representation 53 | :param idx: data index 54 | :return hist: motion state for target agent, [|velocity|, |acc|, |yaw_rate|] 55 | """ 56 | i_t, s_t = self.token_list[idx].split("_") 57 | 58 | vel = self.helper.get_velocity_for_agent(i_t, s_t) 59 | acc = self.helper.get_acceleration_for_agent(i_t, s_t) 60 | yaw_rate = self.helper.get_heading_change_rate_for_agent(i_t, s_t) 61 | 62 | motion_state = np.asarray([vel, acc, yaw_rate]) 63 | for i, val in enumerate(motion_state): 64 | if np.isnan(val): 65 | motion_state[i] = 0 66 | 67 | return motion_state 68 | 69 | def get_map_representation(self, idx: int) -> np.ndarray: 70 | """ 71 | Extracts map representation 72 | :param idx: data index 73 | :return img: RGB raster image with static map elements, shape: [3, img_size[0], img_size[1]] 74 | """ 75 | i_t, s_t = self.token_list[idx].split("_") 76 | img = self.map_rasterizer.make_representation(i_t, s_t) 77 | img = np.moveaxis(img, -1, 0) 78 | img = img.astype(float) / 255 79 | return img 80 | 81 | def get_surrounding_agent_representation(self, idx: int) -> np.ndarray: 82 | """ 83 | Extracts surrounding agent representation 84 | :param idx: data index 85 | :return img: Raster image with faded bounding boxes representing surrounding agents, 86 | shape: [3, img_size[0], img_size[1]] 87 | """ 88 | i_t, s_t = self.token_list[idx].split("_") 89 | img = self.agent_rasterizer.make_representation(i_t, s_t) 90 | img = np.moveaxis(img, -1, 0) 91 | img = img.astype(float) / 255 92 | return img 93 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from train_eval.evaluator import Evaluator 4 | import os 5 | 6 | # Parse arguments 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("-c", "--config", help="Config file with dataset parameters", required=True) 9 | parser.add_argument("-r", "--data_root", help="Root directory with data", required=True) 10 | parser.add_argument("-d", "--data_dir", help="Directory to extract data", required=True) 11 | parser.add_argument("-o", "--output_dir", help="Directory to save results", required=True) 12 | parser.add_argument("-w", "--checkpoint", help="Path to pre-trained or intermediate checkpoint", required=True) 13 | args = parser.parse_args() 14 | 15 | 16 | # Make directories 17 | if not os.path.isdir(args.output_dir): 18 | os.mkdir(args.output_dir) 19 | if not os.path.isdir(os.path.join(args.output_dir, 'results')): 20 | os.mkdir(os.path.join(args.output_dir, 'results')) 21 | 22 | 23 | # Load config 24 | with open(args.config, 'r') as yaml_file: 25 | cfg = yaml.safe_load(yaml_file) 26 | 27 | 28 | # Evaluate 29 | 
evaluator = Evaluator(cfg, args.data_root, args.data_dir, args.checkpoint) 30 | evaluator.evaluate(output_dir=args.output_dir) 31 | -------------------------------------------------------------------------------- /metrics/covernet_loss.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | from metrics.utils import min_mse 5 | 6 | 7 | class CoverNetLoss(Metric): 8 | """ 9 | Purely computes the classification component of the MTP loss. 10 | """ 11 | 12 | def __init__(self): 13 | """ 14 | Initialize CoverNetLoss 15 | """ 16 | self.name = 'mtp_loss' 17 | 18 | def compute(self, predictions: Dict, ground_truth: Union[Dict, torch.Tensor]) -> torch.Tensor: 19 | """ 20 | Compute MTP loss 21 | :param predictions: Dictionary with 'traj': predicted trajectories and 'probs': mode probabilities 22 | :param ground_truth: Either a tensor with ground truth trajectories or a dictionary 23 | :return: 24 | """ 25 | 26 | # Unpack arguments 27 | traj = predictions['traj'] 28 | probs = predictions['probs'] 29 | traj_gt = ground_truth['traj'] if type(ground_truth) == dict else ground_truth 30 | 31 | # Useful variables 32 | batch_size = traj.shape[0] 33 | sequence_length = traj.shape[2] 34 | 35 | # Masks for variable length ground truth trajectories 36 | masks = ground_truth['masks'] if type(ground_truth) == dict and 'masks' in ground_truth.keys() \ 37 | else torch.zeros(batch_size, sequence_length).to(traj.device) 38 | 39 | # Obtain mode with minimum MSE with respect to ground truth: 40 | errs, inds = min_mse(traj, traj_gt, masks) 41 | 42 | # Calculate NLL loss for trajectories corresponding to selected outputs (assuming model uses log_softmax): 43 | loss = - torch.squeeze(probs.gather(1, inds.unsqueeze(1))) 44 | loss = torch.mean(loss) 45 | 46 | return loss 47 | -------------------------------------------------------------------------------- /metrics/goal_pred_nll.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | 5 | 6 | class GoalPredictionNLL(Metric): 7 | """ 8 | Negative log likelihood loss for ground truth goal nodes under predicted goal log-probabilities. 9 | """ 10 | def __init__(self, args: Dict): 11 | self.name = 'goal_pred_nll' 12 | 13 | def compute(self, predictions: Dict, ground_truth: Union[torch.Tensor, Dict]) -> torch.Tensor: 14 | """ 15 | Compute goal prediction NLL loss. 16 | 17 | :param predictions: Dictionary with 'goal_log_probs': log probabilities over nodes for goal prediction 18 | :param ground_truth: Dictionary with 'evf_gt': Look up table with visited edges. Only the goal transition edges 19 | will be used by the loss. 
20 | """ 21 | # Unpack arguments 22 | goal_log_probs = predictions['goal_log_probs'] 23 | gt_goals = ground_truth['evf_gt'][:, :, -1].bool() 24 | 25 | loss = -torch.sum(goal_log_probs[gt_goals]) / goal_log_probs.shape[0] 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /metrics/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import abc 3 | from typing import Dict, Union 4 | 5 | 6 | class Metric: 7 | """ 8 | Base class for prediction metric/loss function 9 | """ 10 | @abc.abstractmethod 11 | def __init__(self): 12 | raise NotImplementedError() 13 | 14 | @abc.abstractmethod 15 | def compute(self, predictions: Union[torch.Tensor, Dict], ground_truth: Union[torch.Tensor, Dict]) -> torch.Tensor: 16 | """ 17 | Main function that computes the metric 18 | :param predictions: Predictions generated by the model 19 | :param ground_truth: Ground truth labels 20 | :return metric: Tensor with computed value of metric. 21 | """ 22 | raise NotImplementedError() 23 | -------------------------------------------------------------------------------- /metrics/min_ade.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | from metrics.utils import min_ade 5 | 6 | 7 | class MinADEK(Metric): 8 | """ 9 | Minimum average displacement error for the top K trajectories. 10 | """ 11 | def __init__(self, args: Dict): 12 | self.k = args['k'] 13 | self.name = 'min_ade_' + str(self.k) 14 | 15 | def compute(self, predictions: Dict, ground_truth: Union[Dict, torch.Tensor]) -> torch.Tensor: 16 | """ 17 | Compute MinADEK 18 | :param predictions: Dictionary with 'traj': predicted trajectories and 'probs': mode probabilities 19 | :param ground_truth: Either a tensor with ground truth trajectories or a dictionary 20 | :return: 21 | """ 22 | # Unpack arguments 23 | traj = predictions['traj'] 24 | probs = predictions['probs'] 25 | traj_gt = ground_truth['traj'] if type(ground_truth) == dict else ground_truth 26 | 27 | # Useful params 28 | batch_size = probs.shape[0] 29 | num_pred_modes = traj.shape[1] 30 | sequence_length = traj.shape[2] 31 | 32 | # Masks for variable length ground truth trajectories 33 | masks = ground_truth['masks'] if type(ground_truth) == dict and 'masks' in ground_truth.keys() \ 34 | else torch.zeros(batch_size, sequence_length).to(traj.device) 35 | 36 | min_k = min(self.k, num_pred_modes) 37 | 38 | _, inds_topk = torch.topk(probs, min_k, dim=1) 39 | batch_inds = torch.arange(batch_size).unsqueeze(1).repeat(1, min_k) 40 | traj_topk = traj[batch_inds, inds_topk] 41 | 42 | errs, _ = min_ade(traj_topk, traj_gt, masks) 43 | 44 | return torch.mean(errs) 45 | -------------------------------------------------------------------------------- /metrics/min_fde.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | from metrics.utils import min_fde 5 | 6 | 7 | class MinFDEK(Metric): 8 | """ 9 | Minimum final displacement error for the top K trajectories. 
10 | """ 11 | 12 | def __init__(self, args: Dict): 13 | self.k = args['k'] 14 | self.name = 'min_fde_' + str(self.k) 15 | 16 | def compute(self, predictions: Dict, ground_truth: Union[Dict, torch.Tensor]) -> torch.Tensor: 17 | """ 18 | Compute MinFDEK 19 | :param predictions: Dictionary with 'traj': predicted trajectories and 'probs': mode probabilities 20 | :param ground_truth: Either a tensor with ground truth trajectories or a dictionary 21 | :return: 22 | """ 23 | # Unpack arguments 24 | traj = predictions['traj'] 25 | probs = predictions['probs'] 26 | traj_gt = ground_truth['traj'] if type(ground_truth) == dict else ground_truth 27 | 28 | # Useful params 29 | batch_size = probs.shape[0] 30 | num_pred_modes = traj.shape[1] 31 | sequence_length = traj.shape[2] 32 | 33 | # Masks for variable length ground truth trajectories 34 | masks = ground_truth['masks'] if type(ground_truth) == dict and 'masks' in ground_truth.keys() \ 35 | else torch.zeros(batch_size, sequence_length).to(traj.device) 36 | 37 | min_k = min(self.k, num_pred_modes) 38 | 39 | _, inds_topk = torch.topk(probs, min_k, dim=1) 40 | batch_inds = torch.arange(batch_size).unsqueeze(1).repeat(1, min_k) 41 | traj_topk = traj[batch_inds, inds_topk] 42 | 43 | errs, _ = min_fde(traj_topk, traj_gt, masks) 44 | 45 | return torch.mean(errs) 46 | -------------------------------------------------------------------------------- /metrics/miss_rate.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | from metrics.utils import miss_rate 5 | 6 | 7 | class MissRateK(Metric): 8 | """ 9 | Miss rate for the top K trajectories. 10 | """ 11 | 12 | def __init__(self, args: Dict): 13 | self.k = args['k'] 14 | self.dist_thresh = args['dist_thresh'] 15 | self.name = 'miss_rate_' + str(self.k) 16 | 17 | def compute(self, predictions: Dict, ground_truth: Union[Dict, torch.Tensor]) -> torch.Tensor: 18 | """ 19 | Compute miss rate 20 | :param predictions: Dictionary with 'traj': predicted trajectories and 'probs': mode probabilities 21 | :param ground_truth: Either a tensor with ground truth trajectories or a dictionary 22 | :return: 23 | """ 24 | # Unpack arguments 25 | traj = predictions['traj'] 26 | probs = predictions['probs'] 27 | traj_gt = ground_truth['traj'] if type(ground_truth) == dict else ground_truth 28 | 29 | # Useful params 30 | batch_size = probs.shape[0] 31 | num_pred_modes = traj.shape[1] 32 | sequence_length = traj.shape[2] 33 | 34 | # Masks for variable length ground truth trajectories 35 | masks = ground_truth['masks'] if type(ground_truth) == dict and 'masks' in ground_truth.keys() \ 36 | else torch.zeros(batch_size, sequence_length).to(traj.device) 37 | 38 | min_k = min(self.k, num_pred_modes) 39 | 40 | _, inds_topk = torch.topk(probs, min_k, dim=1) 41 | batch_inds = torch.arange(batch_size).unsqueeze(1).repeat(1, min_k) 42 | traj_topk = traj[batch_inds, inds_topk] 43 | 44 | return miss_rate(traj_topk, traj_gt, masks, dist_thresh=self.dist_thresh) 45 | -------------------------------------------------------------------------------- /metrics/mtp_loss.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | from metrics.utils import min_ade, traj_nll 5 | 6 | 7 | class MTPLoss(Metric): 8 | """ 9 | MTP loss modified to include variances. Uses MSE for mode selection. 
Can also be used with 10 | Multipath outputs, with residuals added to anchors. 11 | """ 12 | 13 | def __init__(self, args: Dict = None): 14 | """ 15 | Initialize MTP loss 16 | :param args: Dictionary with the following (optional) keys 17 | use_variance: bool, whether or not to use variances for computing regression component of loss, 18 | default: False 19 | alpha: float, relative weight assigned to classification component, compared to regression component 20 | of loss, default: 1; beta: float, relative weight assigned to the regression component, default: 1 21 | """ 22 | self.use_variance = args['use_variance'] if args is not None and 'use_variance' in args.keys() else False 23 | self.alpha = args['alpha'] if args is not None and 'alpha' in args.keys() else 1 24 | self.beta = args['beta'] if args is not None and 'beta' in args.keys() else 1 25 | self.name = 'mtp_loss' 26 | 27 | def compute(self, predictions: Dict, ground_truth: Union[Dict, torch.Tensor]) -> torch.Tensor: 28 | """ 29 | Compute MTP loss 30 | :param predictions: Dictionary with 'traj': predicted trajectories and 'probs': mode (log) probabilities 31 | :param ground_truth: Either a tensor with ground truth trajectories or a dictionary 32 | :return: 33 | """ 34 | 35 | # Unpack arguments 36 | traj = predictions['traj'] 37 | log_probs = predictions['probs'] 38 | traj_gt = ground_truth['traj'] if type(ground_truth) == dict else ground_truth 39 | 40 | # Useful variables 41 | batch_size = traj.shape[0] 42 | sequence_length = traj.shape[2] 43 | pred_params = 5 if self.use_variance else 2 44 | 45 | # Masks for variable length ground truth trajectories 46 | masks = ground_truth['masks'] if type(ground_truth) == dict and 'masks' in ground_truth.keys() \ 47 | else torch.zeros(batch_size, sequence_length).to(traj.device) 48 | 49 | # Obtain mode with minimum ADE with respect to ground truth: 50 | errs, inds = min_ade(traj, traj_gt, masks) 51 | inds_rep = inds.repeat(sequence_length, pred_params, 1, 1).permute(3, 2, 0, 1) 52 | 53 | # Calculate MSE or NLL loss for trajectories corresponding to selected outputs: 54 | traj_best = traj.gather(1, inds_rep).squeeze(dim=1) 55 | 56 | if self.use_variance: 57 | l_reg = traj_nll(traj_best, traj_gt, masks) 58 | else: 59 | l_reg = errs 60 | 61 | # Compute classification loss 62 | l_class = - torch.squeeze(log_probs.gather(1, inds.unsqueeze(1))) 63 | 64 | loss = self.beta * l_reg + self.alpha * l_class 65 | loss = torch.mean(loss) 66 | 67 | return loss 68 | -------------------------------------------------------------------------------- /metrics/pi_bc.py: -------------------------------------------------------------------------------- 1 | from metrics.metric import Metric 2 | from typing import Dict, Union 3 | import torch 4 | 5 | 6 | class PiBehaviorCloning(Metric): 7 | """ 8 | Behavior cloning loss for training the graph traversal policy. 9 | """ 10 | def __init__(self, args: Dict): 11 | self.name = 'pi_bc' 12 | 13 | def compute(self, predictions: Dict, ground_truth: Union[torch.Tensor, Dict]) -> torch.Tensor: 14 | """ 15 | Compute negative log likelihood of ground truth traversed edges under learned policy.
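In symbols (matching the implementation below): loss = -sum(pi[evf_gt]) / batch_size, i.e. the log probabilities of all edges visited in the ground truth traversal are summed and the result is averaged over the batch.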
16 | 17 | :param predictions: Dictionary with 'pi': policy for lane graph traversal (log probabilities) 18 | :param ground_truth: Dictionary with 'evf_gt': Look up table with visited edges 19 | """ 20 | # Unpack arguments 21 | pi = predictions['pi'] 22 | evf_gt = ground_truth['evf_gt'] 23 | 24 | loss = -torch.sum(pi[evf_gt.bool()]) / pi.shape[0] 25 | 26 | return loss 27 | -------------------------------------------------------------------------------- /metrics/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | import math 4 | 5 | 6 | def mse(traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Computes MSE for a set of trajectories with respect to ground truth. 9 | 10 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 11 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 12 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 13 | :return: errs: errors, shape [batch_size, num_modes] 14 | """ 15 | 16 | num_modes = traj.shape[1] 17 | 18 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 19 | masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1) 20 | err = traj_gt_rpt - traj[:, :, :, 0:2] 21 | err = torch.pow(err, exponent=2) 22 | err = torch.sum(err, dim=3) 23 | err = torch.sum(err * (1 - masks_rpt), dim=2) / torch.sum((1 - masks_rpt), dim=2) 24 | return err 25 | 26 | 27 | def max_dist(traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor) -> torch.Tensor: 28 | """ 29 | Computes max distance of a set of trajectories with respect to ground truth trajectory. 30 | 31 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 32 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 33 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 34 | :return dist: shape [batch_size, num_modes] 35 | """ 36 | num_modes = traj.shape[1] 37 | 38 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 39 | masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1) 40 | dist = traj_gt_rpt - traj[:, :, :, 0:2] 41 | dist = torch.pow(dist, exponent=2) 42 | dist = torch.sum(dist, dim=3) 43 | dist = torch.pow(dist, exponent=0.5) 44 | dist[masks_rpt.bool()] = -math.inf 45 | dist, _ = torch.max(dist, dim=2) 46 | 47 | return dist 48 | 49 | 50 | def min_mse(traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 51 | """ 52 | Computes MSE for the best trajectory in a set, with respect to ground truth 53 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 54 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 55 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 56 | :return errs, inds: errors and indices for modes with min error, shape [batch_size] 57 | """ 58 | 59 | num_modes = traj.shape[1] 60 | 61 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 62 | masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1) 63 | err = traj_gt_rpt - traj[:, :, :, 0:2] 64 | err = torch.pow(err, exponent=2) 65 | err = torch.sum(err, dim=3) 66 | err = torch.sum(err * (1 - masks_rpt), dim=2) / torch.sum((1 - masks_rpt), dim=2) 67 | err, inds = torch.min(err, dim=1) 68 | 69 | return err, inds 70 | 71 | 72 | def min_ade(traj: torch.Tensor, traj_gt:
torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 73 | """ 74 | Computes average displacement error for the best trajectory in a set, with respect to ground truth 75 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 76 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 77 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 78 | :return errs, inds: errors and indices for modes with min error, shape [batch_size] 79 | """ 80 | num_modes = traj.shape[1] 81 | 82 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 83 | masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1) 84 | err = traj_gt_rpt - traj[:, :, :, 0:2] 85 | err = torch.pow(err, exponent=2) 86 | err = torch.sum(err, dim=3) 87 | err = torch.pow(err, exponent=0.5) 88 | err = torch.sum(err * (1 - masks_rpt), dim=2) / torch.sum((1 - masks_rpt), dim=2) 89 | err, inds = torch.min(err, dim=1) 90 | 91 | return err, inds 92 | 93 | 94 | def min_fde(traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 95 | """ 96 | Computes final displacement error for the best trajectory in a set, with respect to ground truth 97 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 98 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 99 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 100 | :return errs, inds: errors and indices for modes with min error, shape [batch_size] 101 | """ 102 | num_modes = traj.shape[1] 103 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 104 | lengths = torch.sum(1-masks, dim=1).long() 105 | inds = lengths.unsqueeze(1).unsqueeze(2).unsqueeze(3).repeat(1, num_modes, 1, 2) - 1 106 | 107 | traj_last = torch.gather(traj[..., :2], dim=2, index=inds).squeeze(2) 108 | traj_gt_last = torch.gather(traj_gt_rpt, dim=2, index=inds).squeeze(2) 109 | 110 | err = traj_gt_last - traj_last[..., 0:2] 111 | err = torch.pow(err, exponent=2) 112 | err = torch.sum(err, dim=2) 113 | err = torch.pow(err, exponent=0.5) 114 | err, inds = torch.min(err, dim=1) 115 | 116 | return err, inds 117 | 118 | 119 | def miss_rate(traj: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor, dist_thresh: float = 2) -> torch.Tensor: 120 | """ 121 | Computes miss rate for mini batch of trajectories, with respect to ground truth and given distance threshold 122 | :param traj: predictions, shape [batch_size, num_modes, sequence_length, 2] 123 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 124 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 125 | :param dist_thresh: distance threshold for computing miss rate. 126 | :return m_r: miss rate for the mini batch; a sample counts as a miss if even the best mode's max deviation from ground truth exceeds dist_thresh 127 | """ 128 | num_modes = traj.shape[1] 129 | 130 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 131 | masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1) 132 | dist = traj_gt_rpt - traj[:, :, :, 0:2] 133 | dist = torch.pow(dist, exponent=2) 134 | dist = torch.sum(dist, dim=3) 135 | dist = torch.pow(dist, exponent=0.5) 136 | dist[masks_rpt.bool()] = -math.inf 137 | dist, _ = torch.max(dist, dim=2) 138 | dist, _ = torch.min(dist, dim=1) 139 | m_r = torch.sum(torch.as_tensor(dist > dist_thresh)) / len(dist) 140 | 141 | return m_r 142 | 143 | 144 | # TODO: DEBUG THIS FUNCTION (?)
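# A note on the closed form below (a reading of the code, not stated in the repo): with
# ohr = (1 - rho^2)^(-1/2), the expression matches the bivariate Gaussian negative log
# likelihood, log(2*pi*sigma_x*sigma_y*sqrt(1 - rho^2)) + quadratic / (2*(1 - rho^2)),
# provided sig_x and sig_y are read as inverse standard deviations (they multiply, rather
# than divide, the residuals). The constant 1.8379 is log(2*pi) up to rounding.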
145 | def traj_nll(pred_dist: torch.Tensor, traj_gt: torch.Tensor, masks: torch.Tensor): 146 | """ 147 | Computes negative log likelihood of ground truth trajectory under a predictive distribution with a single mode, 148 | with a bivariate Gaussian distribution predicted at each time in the prediction horizon 149 | 150 | :param pred_dist: parameters of a bivariate Gaussian distribution, shape [batch_size, sequence_length, 5] 151 | :param traj_gt: ground truth trajectory, shape [batch_size, sequence_length, 2] 152 | :param masks: masks for varying length ground truth, shape [batch_size, sequence_length] 153 | :return: 154 | """ 155 | mu_x = pred_dist[:, :, 0] 156 | mu_y = pred_dist[:, :, 1] 157 | x = traj_gt[:, :, 0] 158 | y = traj_gt[:, :, 1] 159 | 160 | sig_x = pred_dist[:, :, 2] 161 | sig_y = pred_dist[:, :, 3] 162 | rho = pred_dist[:, :, 4] 163 | ohr = torch.pow(1 - torch.pow(rho, 2), -0.5) 164 | 165 | nll = 0.5 * torch.pow(ohr, 2) * \ 166 | (torch.pow(sig_x, 2) * torch.pow(x - mu_x, 2) + 167 | torch.pow(sig_y, 2) * torch.pow(y - mu_y, 2) - 168 | 2 * rho * torch.pow(sig_x, 1) * torch.pow(sig_y, 1) * (x - mu_x) * (y - mu_y))\ 169 | - torch.log(sig_x * sig_y * ohr) + 1.8379 170 | 171 | nll[nll.isnan()] = 0 172 | nll[nll.isinf()] = 0 173 | 174 | nll = torch.sum(nll * (1 - masks), dim=1) / torch.sum((1 - masks), dim=1) 175 | # Note: Normalizing with torch.sum((1 - masks), dim=1) makes values somewhat comparable for trajectories of 176 | # different lengths 177 | 178 | return nll 179 | -------------------------------------------------------------------------------- /models/aggregators/aggregator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import abc 4 | from typing import Dict, Union 5 | 6 | 7 | # Initialize device: 8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class PredictionAggregator(nn.Module): 12 | """ 13 | Base class for context aggregators for single agent prediction. 14 | Aggregates a set of context (map, surrounding agent) encodings and outputs either a single aggregated context vector 15 | or 'K' selectively aggregated context vectors for multimodal prediction. 16 | """ 17 | 18 | def __init__(self): 19 | super().__init__() 20 | 21 | @abc.abstractmethod 22 | def forward(self, encodings: Dict) -> Union[Dict, torch.Tensor]: 23 | """ 24 | Forward pass for prediction aggregator 25 | :param encodings: Dictionary with target agent and context encodings 26 | :return agg_encoding: Aggregated context encoding 27 | """ 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /models/aggregators/concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.aggregators.aggregator import PredictionAggregator 3 | from typing import Dict 4 | 5 | 6 | class Concat(PredictionAggregator): 7 | """ 8 | Concatenates target agent encoding and all context encodings. 9 | Set of context encodings needs to be the same size, ideally with a well-defined order. 
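For example (illustrative sizes): a target agent encoding of shape [batch_size, 32] concatenated with a combined context encoding of shape [batch_size, N, 32] yields an output of shape [batch_size, 32 + N * 32], since the context set is flattened before concatenation.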
10 | """ 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, encodings: Dict) -> torch.Tensor: 15 | """ 16 | Forward pass for Concat aggregator 17 | """ 18 | target_agent_enc = encodings['target_agent_encoding'] 19 | context_enc = encodings['context_encoding'] 20 | batch_size = target_agent_enc.shape[0] 21 | 22 | if context_enc['combined'] is not None: 23 | context_vec = context_enc['combined'].reshape(batch_size, -1) 24 | else: 25 | map_vec = context_enc['map'].reshape(batch_size, -1) if context_enc['map'] else torch.empty(batch_size, 0) 26 | vehicle_vec = context_enc['vehicles'].reshape(batch_size, -1) if context_enc['map'] \ 27 | else torch.empty(batch_size, 0) 28 | ped_vec = context_enc['pedestrians'].reshape(batch_size, -1) if context_enc['pedestrians']\ 29 | else torch.empty(batch_size, 0) 30 | context_vec = torch.cat((map_vec, vehicle_vec, ped_vec), dim=1) 31 | 32 | op = torch.cat((target_agent_enc, context_vec), dim=1) 33 | return op 34 | -------------------------------------------------------------------------------- /models/aggregators/global_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.aggregators.aggregator import PredictionAggregator 4 | from typing import Dict, Tuple 5 | 6 | 7 | class GlobalAttention(PredictionAggregator): 8 | """ 9 | Aggregate context encoding using scaled dot product attention. Query obtained using target agent encoding, 10 | Keys and values obtained using map and surrounding agent encodings. 11 | """ 12 | 13 | def __init__(self, args: Dict): 14 | 15 | """ 16 | args to include 17 | 18 | enc_size: int Dimension of encodings generated by encoder 19 | emb_size: int Size of embeddings used for queries, keys and values 20 | num_heads: int Number of attention heads 21 | 22 | """ 23 | super().__init__() 24 | self.query_emb = nn.Linear(args['target_agent_enc_size'], args['emb_size']) 25 | self.key_emb = nn.Linear(args['context_enc_size'], args['emb_size']) 26 | self.val_emb = nn.Linear(args['context_enc_size'], args['emb_size']) 27 | self.mha = nn.MultiheadAttention(args['emb_size'], args['num_heads']) 28 | 29 | def forward(self, encodings: Dict) -> torch.Tensor: 30 | """ 31 | Forward pass for attention aggregator 32 | """ 33 | target_agent_enc = encodings['target_agent_encoding'] 34 | context_enc = encodings['context_encoding'] 35 | if context_enc['combined'] is not None: 36 | combined_enc, combined_masks = context_enc['combined'], context_enc['combined_masks'].bool() 37 | else: 38 | combined_enc, combined_masks = self.get_combined_encodings(context_enc) 39 | 40 | query = self.query_emb(target_agent_enc).unsqueeze(0) 41 | keys = self.key_emb(combined_enc).permute(1, 0, 2) 42 | vals = self.val_emb(combined_enc).permute(1, 0, 2) 43 | op, _ = self.mha(query, keys, vals, key_padding_mask=combined_masks) 44 | op = op.squeeze(0) 45 | op = torch.cat((target_agent_enc, op), dim=-1) 46 | 47 | return op 48 | 49 | @staticmethod 50 | def get_combined_encodings(context_enc: Dict) -> Tuple[torch.Tensor, torch.Tensor]: 51 | """ 52 | Creates a combined set of map and surrounding agent encodings to be aggregated using attention. 
53 | """ 54 | encodings = [] 55 | masks = [] 56 | if 'map' in context_enc: 57 | encodings.append(context_enc['map']) 58 | masks.append(context_enc['map_masks']) 59 | if 'vehicles' in context_enc: 60 | encodings.append(context_enc['vehicles']) 61 | masks.append(context_enc['vehicle_masks']) 62 | if 'pedestrians' in context_enc: 63 | encodings.append(context_enc['pedestrians']) 64 | masks.append(context_enc['pedestrian_masks']) 65 | combined_enc = torch.cat(encodings, dim=1) 66 | combined_masks = torch.cat(masks, dim=1).bool() 67 | return combined_enc, combined_masks 68 | -------------------------------------------------------------------------------- /models/aggregators/goal_conditioned.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.aggregators.global_attention import GlobalAttention 4 | from typing import Dict 5 | from torch.distributions import Categorical 6 | 7 | 8 | # Initialize device: 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class GoalConditioned(GlobalAttention): 13 | """ 14 | Goal conditioned aggregator with following functionality. 15 | 1) Predicts goal probabilities over lane nodes 16 | 2) Samples goals 17 | 3) Outputs goal conditioned encodings for N samples to pass on to the trajectory decoder 18 | """ 19 | 20 | def __init__(self, args): 21 | """ 22 | args to include 23 | 24 | for aggregating map and agent context 25 | enc_size: int Dimension of encodings generated by encoder 26 | emb_size: int Size of embeddings used for queries, keys and values 27 | num_heads: int Number of attention heads 28 | 29 | for goal prediction 30 | 'pre_train': bool, whether the model is being pre-trained using ground truth goals. 
31 | 'context_enc_size': int, size of node encoding 32 | 'target_agent_enc_size': int, size of target agent encoding 33 | 'goal_h1_size': int, size of first layer of goal prediction header 34 | 'goal_h2_size': int, size of second layer of goal prediction header 35 | 'num_samples': int, number of goals to sample 36 | """ 37 | super(GoalConditioned, self).__init__(args) 38 | 39 | # Goal prediction header 40 | self.goal_h1 = nn.Linear(args['context_enc_size'] + args['target_agent_enc_size'], args['goal_h1_size']) 41 | self.goal_h2 = nn.Linear(args['goal_h1_size'], args['goal_h2_size']) 42 | self.goal_op = nn.Linear(args['goal_h2_size'], 1) 43 | self.num_samples = args['num_samples'] 44 | self.leaky_relu = nn.LeakyReLU() 45 | self.log_softmax = nn.LogSoftmax(dim=1) 46 | 47 | # Pretraining 48 | self.pre_train = args['pre_train'] 49 | 50 | def forward(self, encodings: Dict) -> Dict: 51 | """ 52 | Forward pass for goal conditioned aggregator 53 | :param encodings: dictionary with encoder outputs 54 | :return: outputs, dictionary with 55 | 'agg_encoding': aggregated encodings 56 | 'goal_log_probs': log probabilities over nodes corresponding to predicted goals 57 | """ 58 | 59 | # Unpack encodings: 60 | target_agent_encoding = encodings['target_agent_encoding'] 61 | node_encodings = encodings['context_encoding']['combined'] 62 | node_masks = encodings['context_encoding']['combined_masks'] 63 | 64 | # Predict goal log-probabilities 65 | goal_log_probs = self.compute_goal_probs(target_agent_encoding, node_encodings, node_masks) 66 | 67 | # If pretraining model, use ground truth goals 68 | if self.pre_train and self.training: 69 | max_nodes = node_masks.shape[1] 70 | goals = encodings['node_seq_gt'][:, -1].unsqueeze(1).repeat(1, self.num_samples).long() - max_nodes 71 | else: 72 | # If fine-tuning or validating, sample goals 73 | goals = Categorical(torch.exp(goal_log_probs).unsqueeze(1).repeat(1, self.num_samples, 1)).sample() 74 | 75 | # Aggregate context 76 | agg_enc = super(GoalConditioned, self).forward(encodings) 77 | 78 | # Repeat context vector for number of samples and append goal encodings 79 | agg_enc = agg_enc.unsqueeze(1).repeat(1, self.num_samples, 1) 80 | batch_indices = torch.arange(agg_enc.shape[0]).unsqueeze(1).repeat(1, self.num_samples) 81 | goal_encodings = node_encodings[batch_indices, goals] 82 | agg_enc = torch.cat((agg_enc, goal_encodings), dim=2) 83 | 84 | # Return outputs 85 | outputs = {'agg_encoding': agg_enc, 'goal_log_probs': goal_log_probs} 86 | 87 | return outputs 88 | 89 | def compute_goal_probs(self, target_agent_encoding, node_encodings, node_masks): 90 | """ 91 | Forward pass for goal prediction header 92 | :param target_agent_encoding: tensor encoding the target agent's past motion 93 | :param node_encodings: tensor of node encodings provided by the encoder 94 | :param node_masks: masks indicating whether a node exists for a given index in the tensor 95 | :return: 96 | """ 97 | # Useful variables 98 | max_nodes = node_encodings.shape[1] 99 | target_agent_enc_size = target_agent_encoding.shape[-1] 100 | node_enc_size = node_encodings.shape[-1] 101 | 102 | # Concatenate node encodings with target agent encoding 103 | target_agent_encoding = target_agent_encoding.unsqueeze(1).repeat(1, max_nodes, 1) 104 | enc = torch.cat((target_agent_encoding, node_encodings), dim=2) 105 | 106 | # Form a single batch of encodings 107 | masks_goal = ~node_masks.unsqueeze(-1).bool() 108 | enc_batched = torch.masked_select(enc, masks_goal).reshape(-1, target_agent_enc_size + 
node_enc_size) 109 | 110 | # Compute goal log probabilities 111 | goal_ops_ = self.goal_op(self.leaky_relu(self.goal_h2(self.leaky_relu(self.goal_h1(enc_batched))))) 112 | goal_ops = torch.zeros_like(masks_goal).float() 113 | goal_ops = goal_ops.masked_scatter_(masks_goal, goal_ops_).squeeze(-1) 114 | goal_log_probs = self.log_softmax(goal_ops + torch.log(1-node_masks)) 115 | 116 | return goal_log_probs 117 | -------------------------------------------------------------------------------- /models/aggregators/pgp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.aggregators.aggregator import PredictionAggregator 4 | from typing import Dict 5 | from torch.distributions import Categorical 6 | from positional_encodings.torch_encodings import PositionalEncoding1D 7 | 8 | 9 | # Initialize device: 10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | class PGP(PredictionAggregator): 14 | """ 15 | Policy header + selective aggregator from "Multimodal trajectory prediction conditioned on lane graph traversals" 16 | 1) Outputs edge probabilities corresponding to pi_route 17 | 2) Samples pi_route to output traversed paths 18 | 3) Selectively aggregates context along traversed paths 19 | """ 20 | 21 | def __init__(self, args: Dict): 22 | """ 23 | args to include 24 | 'pre_train': bool, whether the model is being pre-trained using ground truth node sequence. 25 | 'node_enc_size': int, size of node encoding 26 | 'target_agent_enc_size': int, size of target agent encoding 27 | 'pi_h1_size': int, size of first layer of policy header 28 | 'pi_h2_size': int, size of second layer of policy header 29 | 'emb_size': int, embedding size for attention layer for aggregating node encodings 30 | 'num_heads: int, number of attention heads 31 | 'num_samples': int, number of sampled traversals (and encodings) to output 32 | """ 33 | 34 | super().__init__() 35 | self.pre_train = args['pre_train'] 36 | 37 | # Policy header 38 | self.pi_h1 = nn.Linear(2 * args['node_enc_size'] + args['target_agent_enc_size'] + 2, args['pi_h1_size']) 39 | self.pi_h2 = nn.Linear(args['pi_h1_size'], args['pi_h2_size']) 40 | self.pi_op = nn.Linear(args['pi_h2_size'], 1) 41 | self.pi_h1_goal = nn.Linear(args['node_enc_size'] + args['target_agent_enc_size'], args['pi_h1_size']) 42 | self.pi_h2_goal = nn.Linear(args['pi_h1_size'], args['pi_h2_size']) 43 | self.pi_op_goal = nn.Linear(args['pi_h2_size'], 1) 44 | self.leaky_relu = nn.LeakyReLU() 45 | self.log_softmax = nn.LogSoftmax(dim=2) 46 | 47 | # For sampling policy 48 | self.horizon = args['horizon'] 49 | self.num_samples = args['num_samples'] 50 | 51 | # Attention based aggregator 52 | self.pos_enc = PositionalEncoding1D(args['node_enc_size']) 53 | self.query_emb = nn.Linear(args['target_agent_enc_size'], args['emb_size']) 54 | self.key_emb = nn.Linear(args['node_enc_size'], args['emb_size']) 55 | self.val_emb = nn.Linear(args['node_enc_size'], args['emb_size']) 56 | self.mha = nn.MultiheadAttention(args['emb_size'], args['num_heads']) 57 | 58 | def forward(self, encodings: Dict) -> Dict: 59 | """ 60 | Forward pass for PGP aggregator 61 | :param encodings: dictionary with encoder outputs 62 | :return: outputs: dictionary with 63 | 'agg_encoding': aggregated encodings along sampled traversals 64 | 'pi': discrete policy (probabilities over outgoing edges) for graph traversal 65 | """ 66 | 67 | # Unpack encodings: 68 | target_agent_encoding = 
encodings['target_agent_encoding'] 69 | node_encodings = encodings['context_encoding']['combined'] 70 | node_masks = encodings['context_encoding']['combined_masks'] 71 | s_next = encodings['s_next'] 72 | edge_type = encodings['edge_type'] 73 | if 'att' in encodings: 74 | att = encodings['att'] 75 | else: 76 | att = None 77 | 78 | # Compute pi (log probs) 79 | pi = self.compute_policy(target_agent_encoding, node_encodings, node_masks, s_next, edge_type) # (batch_size, num_nodes, num_edges) [64 164 15] 80 | 81 | # If pretraining model, use ground truth node sequences 82 | if self.pre_train and self.training: 83 | sampled_traversals = encodings['node_seq_gt'].unsqueeze(1).repeat(1, self.num_samples, 1).long() 84 | else: 85 | # Sample pi 86 | init_node = encodings['init_node'] # [64, 164] 87 | sampled_traversals = self.sample_policy(torch.exp(pi), s_next, init_node) # (batch_size, num_samples, horizon) [64 1000 15] 88 | 89 | # Selectively aggregate context along traversed paths 90 | agg_enc = self.aggregate(sampled_traversals, node_encodings, target_agent_encoding) 91 | 92 | outputs = {'agg_encoding': agg_enc, 'pi': pi, 'att': att} 93 | return outputs 94 | 95 | def aggregate(self, sampled_traversals, node_encodings, target_agent_encoding) -> torch.Tensor: 96 | 97 | # Useful variables: 98 | batch_size = node_encodings.shape[0] 99 | max_nodes = node_encodings.shape[1] 100 | 101 | # Get unique traversals and form consolidated batch: 102 | unique_traversals = [torch.unique(i, dim=0, return_counts=True) for i in sampled_traversals] 103 | traversals_batched = torch.cat([i[0] for i in unique_traversals], dim=0) 104 | counts_batched = torch.cat([i[1] for i in unique_traversals], dim=0) 105 | batch_idcs = torch.cat([n*torch.ones(len(i[1])).long() for n, i in enumerate(unique_traversals)]) 106 | batch_idcs = batch_idcs.unsqueeze(1).repeat(1, self.horizon) 107 | 108 | # Dummy encodings for goal nodes 109 | dummy_enc = torch.zeros_like(node_encodings) 110 | node_encodings = torch.cat((node_encodings, dummy_enc), dim=1) 111 | 112 | # Gather node encodings along traversed paths 113 | node_enc_selected = node_encodings[batch_idcs, traversals_batched] 114 | 115 | # Add positional encodings: 116 | pos_enc = self.pos_enc(torch.zeros_like(node_enc_selected)) 117 | node_enc_selected += pos_enc 118 | 119 | # Multi-head attention 120 | target_agent_enc_batched = target_agent_encoding[batch_idcs[:, 0]] 121 | query = self.query_emb(target_agent_enc_batched).unsqueeze(0) 122 | keys = self.key_emb(node_enc_selected).permute(1, 0, 2) 123 | vals = self.val_emb(node_enc_selected).permute(1, 0, 2) 124 | key_padding_mask = torch.as_tensor(traversals_batched >= max_nodes) 125 | att_op, _ = self.mha(query, keys, vals, key_padding_mask) 126 | 127 | # Repeat based on counts 128 | att_op = att_op.squeeze(0).repeat_interleave(counts_batched, dim=0).view(batch_size, self.num_samples, -1) 129 | 130 | # Concatenate target agent encoding 131 | agg_enc = torch.cat((target_agent_encoding.unsqueeze(1).repeat(1, self.num_samples, 1), att_op), dim=-1) 132 | 133 | return agg_enc 134 | 135 | def sample_policy(self, pi, s_next, init_node) -> torch.Tensor: 136 | """ 137 | Sample graph traversals using discrete policy. 
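Concretely (matching the code below): an initial node is sampled from the init_node distribution, and for each subsequent step up to the horizon an outgoing edge is sampled from Categorical(pi) and followed through the s_next look-up table. Goal states are given dummy self-transition edges, so completed traversals repeat their final node until the horizon is filled.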
138 | :param pi: tensor with probabilities corresponding to the policy 139 | :param s_next: look-up table for next node for a given source node and edge 140 | :param init_node: initial node to start the policy at 141 | :return: 142 | """ 143 | with torch.no_grad(): 144 | 145 | # Useful variables: 146 | batch_size = pi.shape[0] 147 | max_nodes = pi.shape[1] 148 | batch_idcs = torch.arange(batch_size, device=device).unsqueeze(1).repeat(1, self.num_samples).view(-1) # 1000*batch_idx 149 | 150 | # Initialize output 151 | sampled_traversals = torch.zeros(batch_size, self.num_samples, self.horizon, device=device).long() 152 | 153 | # Set up dummy self transitions for goal states: 154 | pi_dummy = torch.zeros_like(pi) 155 | pi_dummy[:, :, -1] = 1 156 | s_next_dummy = torch.zeros_like(s_next) 157 | s_next_dummy[:, :, -1] = max_nodes + torch.arange(max_nodes).unsqueeze(0).repeat(batch_size, 1) 158 | pi = torch.cat((pi, pi_dummy), dim=1) 159 | s_next = torch.cat((s_next, s_next_dummy), dim=1) 160 | 161 | # Sample initial node: 162 | pi_s = init_node.unsqueeze(1).repeat(1, self.num_samples, 1).view(-1, max_nodes) # [64000 164] [batch*samples, max_nodes] 163 | s = Categorical(pi_s).sample() # 64000 164 | sampled_traversals[:, :, 0] = s.reshape(batch_size, self.num_samples) 165 | 166 | # Sample traversed paths for a fixed horizon 167 | for n in range(1, self.horizon): 168 | 169 | # Gather policy at appropriate indices: 170 | pi_s = pi[batch_idcs, s] 171 | 172 | # Sample edges 173 | a = Categorical(pi_s).sample() # 64000 174 | 175 | # Look-up next node 176 | s = s_next[batch_idcs, s, a].long() 177 | 178 | # Add node indices to sampled traversals 179 | sampled_traversals[:, :, n] = s.reshape(batch_size, self.num_samples) 180 | 181 | return sampled_traversals 182 | 183 | def compute_policy(self, target_agent_encoding, node_encodings, node_masks, s_next, edge_type) -> torch.Tensor: 184 | """ 185 | Forward pass for policy header 186 | :param target_agent_encoding: tensor encoding the target agent's past motion 187 | :param node_encodings: tensor of node encodings provided by the encoder 188 | :param node_masks: masks indicating whether a node exists for a given index in the tensor 189 | :param s_next: look-up table for next node for a given source node and edge 190 | :param edge_type: look-up table with edge types 191 | :return pi: tensor with probabilities corresponding to the policy 192 | """ 193 | # Useful variables: 194 | batch_size = node_encodings.shape[0] 195 | max_nodes = node_encodings.shape[1] 196 | max_nbrs = s_next.shape[2] - 1 197 | node_enc_size = node_encodings.shape[2] 198 | target_agent_enc_size = target_agent_encoding.shape[1] 199 | 200 | # Gather source node encodigns, destination node encodings, edge encodings and target agent encodings. 
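# A shape walkthrough of the gather below (shapes follow the variables defined above):
# src_node_enc [batch_size, max_nodes, max_nbrs, node_enc_size] broadcasts each source node
# over its outgoing edge slots; dst_node_enc has the same shape and is fetched through the
# s_next look-up table; edge_enc is a 2-dim indicator over {successor, proximal} edge types.
# Concatenated with the repeated target agent encoding, this gives one feature vector per
# (node, edge) pair, which the policy MLP scores to produce pi.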
201 | src_node_enc = node_encodings.unsqueeze(2).repeat(1, 1, max_nbrs, 1) # [B, max nodes, max nbrs, node enc size] 202 | dst_idcs = s_next[:, :, :-1].reshape(batch_size, -1).long() # [B, max nodes * max nbrs] 203 | batch_idcs = torch.arange(batch_size).unsqueeze(1).repeat(1, max_nodes * max_nbrs) # [B, max nodes * max nbrs] 204 | dst_node_enc = node_encodings[batch_idcs, dst_idcs].reshape(batch_size, max_nodes, max_nbrs, node_enc_size) 205 | target_agent_enc = target_agent_encoding.unsqueeze(1).unsqueeze(2).repeat(1, max_nodes, max_nbrs, 1) # [B, max nodes, max nbrs, enc size] 206 | edge_enc = torch.cat((torch.as_tensor(edge_type[:, :, :-1] == 1, device=device).unsqueeze(3).float(), 207 | torch.as_tensor(edge_type[:, :, :-1] == 2, device=device).unsqueeze(3).float()), dim=3) # cat(succ, prox) 208 | enc = torch.cat((target_agent_enc, src_node_enc, dst_node_enc, edge_enc), dim=3) 209 | enc_goal = torch.cat((target_agent_enc[:, :, 0, :], src_node_enc[:, :, 0, :]), dim=2) 210 | 211 | # Form a single batch of encodings 212 | masks = torch.sum(edge_enc, dim=3, keepdim=True).bool() 213 | masks_goal = ~node_masks.unsqueeze(-1).bool() 214 | enc_batched = torch.masked_select(enc, masks).reshape(-1, target_agent_enc_size + 2*node_enc_size + 2) 215 | enc_goal_batched = torch.masked_select(enc_goal, masks_goal).reshape(-1, target_agent_enc_size + node_enc_size) 216 | 217 | # Compute scores for pi_route 218 | pi_ = self.pi_op(self.leaky_relu(self.pi_h2(self.leaky_relu(self.pi_h1(enc_batched))))) 219 | pi = torch.zeros_like(masks).float() 220 | pi = pi.masked_scatter_(masks, pi_).squeeze(-1) 221 | pi_goal_ = self.pi_op_goal(self.leaky_relu(self.pi_h2_goal(self.leaky_relu(self.pi_h1_goal(enc_goal_batched))))) 222 | pi_goal = torch.zeros_like(masks_goal).float() 223 | pi_goal = pi_goal.masked_scatter_(masks_goal, pi_goal_) 224 | 225 | # Normalize to give log probabilities 226 | pi = torch.cat((pi, pi_goal), dim=-1) 227 | op_masks = torch.log(torch.as_tensor(edge_type != 0).float()) 228 | pi = self.log_softmax(pi + op_masks) 229 | 230 | return pi 231 | -------------------------------------------------------------------------------- /models/decoders/covernet.py: -------------------------------------------------------------------------------- 1 | from models.decoders.decoder import PredictionDecoder 2 | from models.decoders.utils import k_means_anchors 3 | import torch 4 | import torch.nn as nn 5 | from datasets.interface import SingleAgentDataset 6 | from typing import Dict 7 | 8 | 9 | class CoverNet(PredictionDecoder): 10 | 11 | def __init__(self, args): 12 | """ 13 | Prediction decoder for CoverNet 14 | 15 | args to include: 16 | num_modes: int number of modes K 17 | op_len: int prediction horizon 18 | hidden_size: int hidden layer size 19 | encoding_size: int size of context encoding 20 | """ 21 | 22 | super().__init__() 23 | 24 | self.agg_type = args['agg_type'] 25 | self.num_modes = args['num_modes'] 26 | self.hidden = nn.Linear(args['encoding_size'], args['hidden_size']) 27 | self.op_len = args['op_len'] 28 | self.prob_op = nn.Linear(args['hidden_size'], self.num_modes) 29 | self.leaky_relu = nn.LeakyReLU(0.01) 30 | self.log_softmax = nn.LogSoftmax(dim=1) 31 | self.anchors = nn.Parameter(torch.zeros(self.num_modes, self.op_len, 2), requires_grad=False) 32 | 33 | def generate_anchors(self, ds: SingleAgentDataset): 34 | """ 35 | Function to initialize anchors. Extracts fixed trajectory set with k-means. Dynamic trajectory sets 36 | have not been implemented. 
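The fixed set is obtained by running k-means (see k_means_anchors in models/decoders/utils.py) on flattened ground truth trajectories from the train set; each cluster mean becomes one anchor trajectory.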
37 | :param ds: train dataset for single agent trajectory prediction 38 | """ 39 | self.anchors = nn.Parameter(k_means_anchors(self.num_modes, ds)) 40 | 41 | def forward(self, agg_encoding: torch.Tensor) -> Dict: 42 | """ 43 | Forward pass for CoverNet 44 | :param agg_encoding: aggregated context encoding 45 | :return predictions: dictionary with 'traj': K predicted trajectories and 46 | 'probs': K corresponding probabilities 47 | """ 48 | h = self.leaky_relu(self.hidden(agg_encoding)) 49 | batch_size = h.shape[0] 50 | probs = self.log_softmax(self.prob_op(h)) 51 | probs = probs.squeeze(dim=-1) 52 | traj = self.anchors.unsqueeze(0).repeat(batch_size, 1, 1, 1) 53 | predictions = {'traj': traj, 'probs': probs} 54 | 55 | return predictions 56 | -------------------------------------------------------------------------------- /models/decoders/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import abc 4 | from typing import Union, Dict 5 | 6 | 7 | # Initialize device: 8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class PredictionDecoder(nn.Module): 12 | """ 13 | Base class for decoders for single agent prediction. 14 | Outputs K trajectories and/or their probabilities 15 | """ 16 | 17 | def __init__(self): 18 | super().__init__() 19 | 20 | @abc.abstractmethod 21 | def forward(self, agg_encoding: Union[torch.Tensor, Dict]) -> Union[torch.Tensor, Dict]: 22 | """ 23 | Forward pass for prediction decoder 24 | :param agg_encoding: Aggregated context encoding 25 | :return outputs: K Predicted trajectories and/or their probabilities/scores 26 | """ 27 | raise NotImplementedError() 28 | -------------------------------------------------------------------------------- /models/decoders/lvm.py: -------------------------------------------------------------------------------- 1 | from models.decoders.decoder import PredictionDecoder 2 | import torch 3 | import torch.nn as nn 4 | from typing import Dict, Union 5 | from models.decoders.utils import cluster_traj 6 | 7 | 8 | # Initialize device: 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class LVM(PredictionDecoder): 13 | 14 | def __init__(self, args): 15 | """ 16 | Latent variable conditioned decoder. 17 | 18 | args to include: 19 | agg_type: 'combined' or 'sample_specific'. Whether we have a single aggregated context vector or sample-specific 20 | num_samples: int Number of trajectories to sample 21 | op_len: int Length of predicted trajectories 22 | lv_dim: int Dimension of latent variable 23 | encoding_size: int Dimension of encoded scene + agent context 24 | hidden_size: int Size of output mlp hidden layer 25 | num_clusters: int Number of final clustered trajectories to output 26 | 27 | """ 28 | super().__init__() 29 | self.agg_type = args['agg_type'] 30 | self.num_samples = args['num_samples'] 31 | self.op_len = args['op_len'] 32 | self.lv_dim = args['lv_dim'] 33 | self.hidden = nn.Linear(args['encoding_size'] + args['lv_dim'], args['hidden_size']) 34 | self.op_traj = nn.Linear(args['hidden_size'], args['op_len'] * 2) 35 | self.leaky_relu = nn.LeakyReLU() 36 | self.num_clusters = args['num_clusters'] 37 | 38 | def forward(self, inputs: Union[Dict, torch.Tensor]) -> Dict: 39 | """ 40 | Forward pass for latent variable model. 
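At a high level: the aggregated encoding is repeated once per sample (or provided per sample), a latent vector z ~ N(0, I) of size lv_dim is drawn and concatenated to it, an MLP decodes num_samples trajectories, and cluster_traj reduces them to num_clusters output modes with associated scores.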
41 | 42 | :param inputs: aggregated context encoding, 43 | shape for combined encoding: [batch_size, encoding_size] 44 | shape if sample specific encoding: [batch_size, num_samples, encoding_size] 45 | :return: predictions 46 | """ 47 | 48 | if type(inputs) is torch.Tensor: 49 | agg_encoding = inputs 50 | else: 51 | agg_encoding = inputs['agg_encoding'] # e.g. [64, 100, 160] = [batch_size, num_samples, encoding_size] 52 | 53 | if self.agg_type == 'combined': 54 | agg_encoding = agg_encoding.unsqueeze(1).repeat(1, self.num_samples, 1) 55 | else: 56 | if len(agg_encoding.shape) != 3 or agg_encoding.shape[1] != self.num_samples: 57 | raise Exception('Expected ' + str(self.num_samples) + ' encodings for each train/val sample') 58 | 59 | # Sample latent variable and concatenate with aggregated encoding 60 | batch_size = agg_encoding.shape[0] 61 | z = torch.randn(batch_size, self.num_samples, self.lv_dim, device=device) 62 | agg_encoding = torch.cat((agg_encoding, z), dim=2) 63 | h = self.leaky_relu(self.hidden(agg_encoding)) 64 | 65 | # Output trajectories 66 | traj = self.op_traj(h) 67 | traj = traj.reshape(batch_size, self.num_samples, self.op_len, 2) 68 | 69 | # Cluster 70 | traj_clustered, probs = cluster_traj(self.num_clusters, traj) 71 | 72 | predictions = {'traj': traj_clustered, 'probs': probs} # extra keys such as 'att' are copied over below, which avoids indexing a plain tensor input 73 | 74 | if type(inputs) is dict: 75 | for key, val in inputs.items(): 76 | if key != 'agg_encoding': 77 | predictions[key] = val 78 | 79 | return predictions 80 | -------------------------------------------------------------------------------- /models/decoders/mtp.py: -------------------------------------------------------------------------------- 1 | from models.decoders.decoder import PredictionDecoder 2 | from models.decoders.utils import bivariate_gaussian_activation 3 | import torch 4 | import torch.nn as nn 5 | from typing import Dict 6 | 7 | 8 | class MTP(PredictionDecoder): 9 | 10 | def __init__(self, args): 11 | """ 12 | Prediction decoder for MTP 13 | 14 | args to include: 15 | num_modes: int number of modes K 16 | op_len: int prediction horizon 17 | hidden_size: int hidden layer size 18 | encoding_size: int size of context encoding 19 | use_variance: Whether to output variance params along with mean predicted locations 20 | """ 21 | 22 | super().__init__() 23 | 24 | self.agg_type = args['agg_type'] 25 | self.num_modes = args['num_modes'] 26 | self.op_len = args['op_len'] 27 | self.use_variance = args['use_variance'] 28 | self.op_dim = 5 if self.use_variance else 2 29 | 30 | self.hidden = nn.Linear(args['encoding_size'], args['hidden_size']) 31 | self.traj_op = nn.Linear(args['hidden_size'], args['op_len'] * self.op_dim * self.num_modes) 32 | self.prob_op = nn.Linear(args['hidden_size'], self.num_modes) 33 | 34 | self.leaky_relu = nn.LeakyReLU(0.01) 35 | self.log_softmax = nn.LogSoftmax(dim=1) 36 | 37 | def forward(self, agg_encoding: torch.Tensor) -> Dict: 38 | """ 39 | Forward pass for MTP 40 | :param agg_encoding: aggregated context encoding 41 | :return predictions: dictionary with 'traj': K predicted trajectories and 42 | 'probs': K corresponding probabilities 43 | """ 44 | h = self.leaky_relu(self.hidden(agg_encoding)) 45 | batch_size = h.shape[0] 46 | traj = self.traj_op(h) 47 | probs = self.log_softmax(self.prob_op(h)) 48 | traj = traj.reshape(batch_size, self.num_modes, self.op_len, self.op_dim) 49 | probs = probs.squeeze(dim=-1) 50 | traj = bivariate_gaussian_activation(traj) if self.use_variance else traj 51 | 52 | predictions = {'traj': traj, 'probs': probs} 53 | 54 | return predictions 55 |
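# A minimal usage sketch for the MTP decoder above (argument values are illustrative,
# not taken from the repo's configs):
#
#   import torch
#   from models.decoders.mtp import MTP
#
#   decoder = MTP({'agg_type': 'combined', 'num_modes': 10, 'op_len': 12,
#                  'hidden_size': 128, 'encoding_size': 160, 'use_variance': False})
#   agg_encoding = torch.randn(8, 160)  # [batch_size, encoding_size]
#   out = decoder(agg_encoding)
#   assert out['traj'].shape == (8, 10, 12, 2) and out['probs'].shape == (8, 10)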
-------------------------------------------------------------------------------- /models/decoders/multipath.py: -------------------------------------------------------------------------------- 1 | from models.decoders.mtp import MTP 2 | from models.decoders.utils import k_means_anchors 3 | import torch 4 | import torch.nn as nn 5 | from datasets.interface import SingleAgentDataset 6 | from typing import Dict 7 | 8 | 9 | class Multipath(MTP): 10 | 11 | def __init__(self, args): 12 | """ 13 | Prediction decoder for Multipath. Almost identical to MTP, but predicts residuals with respect to anchors, 14 | include the same arguments 15 | """ 16 | 17 | super().__init__(args) 18 | self.anchors = nn.Parameter(torch.zeros(self.num_modes, self.op_len, 2), requires_grad=False) 19 | 20 | def generate_anchors(self, ds: SingleAgentDataset): 21 | """ 22 | Function to initialize anchors 23 | :param ds: train dataset for single agent trajectory prediction 24 | """ 25 | 26 | self.anchors = nn.Parameter(k_means_anchors(self.num_modes, ds)) 27 | 28 | def forward(self, agg_encoding: torch.Tensor) -> Dict: 29 | """ 30 | Forward pass for Multipath 31 | :param agg_encoding: aggregated context encoding 32 | :return predictions: dictionary with 'traj': K predicted trajectories and 33 | 'probs': K corresponding probabilities 34 | """ 35 | 36 | predictions = super().forward(agg_encoding) 37 | predictions['traj'][..., :2] += self.anchors.unsqueeze(0) 38 | 39 | return predictions 40 | -------------------------------------------------------------------------------- /models/decoders/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets.interface import SingleAgentDataset 3 | import numpy as np 4 | from sklearn.cluster import KMeans 5 | import psutil 6 | import ray 7 | from scipy.spatial.distance import cdist 8 | import numpy 9 | 10 | # Initialize device: 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | # Initialize ray: 15 | num_cpus = psutil.cpu_count(logical=False) 16 | ray.init(num_cpus=num_cpus, log_to_driver=False) 17 | 18 | 19 | def k_means_anchors(k, ds: SingleAgentDataset): 20 | """ 21 | Extracts anchors for multipath/covernet using k-means on train set trajectories 22 | """ 23 | prototype_traj = ds[0]['ground_truth']['traj'] 24 | traj_len = prototype_traj.shape[0] 25 | traj_dim = prototype_traj.shape[1] 26 | ds_size = len(ds) 27 | trajectories = np.zeros((ds_size, traj_len, traj_dim)) 28 | for i, data in enumerate(ds): 29 | trajectories[i] = data['ground_truth']['traj'] 30 | clustering = KMeans(n_clusters=k).fit(trajectories.reshape((ds_size, -1))) 31 | anchors = np.zeros((k, traj_len, traj_dim)) 32 | for i in range(k): 33 | anchors[i] = np.mean(trajectories[clustering.labels_ == i], axis=0) 34 | anchors = torch.from_numpy(anchors).float().to(device) 35 | return anchors 36 | 37 | 38 | def bivariate_gaussian_activation(ip: torch.Tensor) -> torch.Tensor: 39 | """ 40 | Activation function to output parameters of bivariate Gaussian distribution 41 | """ 42 | mu_x = ip[..., 0:1] 43 | mu_y = ip[..., 1:2] 44 | sig_x = ip[..., 2:3] 45 | sig_y = ip[..., 3:4] 46 | rho = ip[..., 4:5] 47 | sig_x = torch.exp(sig_x) 48 | sig_y = torch.exp(sig_y) 49 | rho = torch.tanh(rho) 50 | out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim=-1) 51 | return out 52 | 53 | 54 | @ray.remote 55 | def cluster_and_rank(k: int, data: np.ndarray): 56 | """ 57 | Combines the clustering and ranking steps so that ray.remote gets called 
just once 58 | """ 59 | 60 | def cluster(n_clusters: int, x: np.ndarray): 61 | """ 62 | Cluster using Scikit learn 63 | """ 64 | clustering_op = KMeans(n_clusters=n_clusters, n_init=1, max_iter=100, init='random').fit(x) 65 | return clustering_op.labels_, clustering_op.cluster_centers_ 66 | 67 | def rank_clusters(cluster_counts, cluster_centers): 68 | """ 69 | Rank the K clustered trajectories using Ward's criterion. Start with K cluster centers and cluster counts. 70 | Find the two clusters to merge based on Ward's criterion. Smaller of the two will get assigned rank K. 71 | Merge the two clusters. Repeat process to assign ranks K-1, K-2, ..., 2. 72 | """ 73 | 74 | num_clusters = len(cluster_counts) 75 | cluster_ids = np.arange(num_clusters) 76 | ranks = np.ones(num_clusters) 77 | 78 | for i in range(num_clusters, 0, -1): 79 | # Compute Ward distances: 80 | centroid_dists = cdist(cluster_centers, cluster_centers) 81 | n1 = cluster_counts.reshape(1, -1).repeat(len(cluster_counts), axis=0) 82 | n2 = n1.transpose() 83 | wts = n1 * n2 / (n1 + n2) 84 | dists = wts * centroid_dists + np.diag(np.inf * np.ones(len(cluster_counts))) 85 | 86 | # Get clusters with min Ward distance and select cluster with fewer counts 87 | c1, c2 = np.unravel_index(dists.argmin(), dists.shape) 88 | c = c1 if cluster_counts[c1] <= cluster_counts[c2] else c2 89 | c_ = c2 if cluster_counts[c1] <= cluster_counts[c2] else c1 90 | 91 | # Assign rank i to selected cluster 92 | ranks[cluster_ids[c]] = i 93 | 94 | # Merge clusters and update identity of merged cluster 95 | cluster_centers[c_] = (cluster_counts[c_] * cluster_centers[c_] + cluster_counts[c] * cluster_centers[c]) /\ 96 | (cluster_counts[c_] + cluster_counts[c]) 97 | cluster_counts[c_] += cluster_counts[c] 98 | 99 | # Discard merged cluster 100 | cluster_ids = np.delete(cluster_ids, c) 101 | cluster_centers = np.delete(cluster_centers, c, axis=0) 102 | cluster_counts = np.delete(cluster_counts, c) 103 | 104 | return ranks 105 | 106 | cluster_lbls, cluster_ctrs = cluster(k, data) 107 | cluster_cnts = np.unique(cluster_lbls, return_counts=True)[1] 108 | cluster_ranks = rank_clusters(cluster_cnts.copy(), cluster_ctrs.copy()) 109 | return {'lbls': cluster_lbls, 'ranks': cluster_ranks, 'counts': cluster_cnts} 110 | 111 | 112 | def cluster_traj(k: int, traj: torch.Tensor): 113 | """ 114 | clusters sampled trajectories to output K modes. 
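Pipeline (as implemented below): trajectories are down-sampled 3x along the time dimension, clustered per batch element with k-means in parallel ray workers, cluster members are averaged to give the K mode trajectories, and each mode is scored as the reciprocal of its Ward criterion rank (normalized to sum to one).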
115 | :param k: number of clusters 116 | :param traj: set of sampled trajectories, shape [batch_size, num_samples, traj_len, 2] 117 | :return: traj_clustered: set of clustered trajectories, shape [batch_size, k, traj_len, 2] 118 | scores: scores for clustered trajectories (basically 1/rank), shape [batch_size, k] 119 | """ 120 | 121 | # Initialize output tensors 122 | batch_size = traj.shape[0] 123 | num_samples = traj.shape[1] 124 | traj_len = traj.shape[2] 125 | 126 | # Down-sample traj along time dimension for faster clustering 127 | data = traj[:, :, 0::3, :] 128 | data = data.reshape(batch_size, num_samples, -1).detach().cpu().numpy() 129 | 130 | # Cluster and rank 131 | cluster_ops = ray.get([cluster_and_rank.remote(k, data_slice) for data_slice in data]) 132 | cluster_lbls = numpy.array([cluster_op['lbls'] for cluster_op in cluster_ops]) 133 | cluster_counts = numpy.array([cluster_op['counts'] for cluster_op in cluster_ops]) 134 | cluster_ranks = numpy.array([cluster_op['ranks'] for cluster_op in cluster_ops]) 135 | 136 | # Compute mean (clustered) traj and scores 137 | lbls = torch.as_tensor(cluster_lbls, device=device).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, traj_len, 2).long() 138 | traj_summed = torch.zeros(batch_size, k, traj_len, 2, device=device).scatter_add(1, lbls, traj) 139 | cnt_tensor = torch.as_tensor(cluster_counts, device=device).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, traj_len, 2) 140 | traj_clustered = traj_summed / cnt_tensor 141 | scores = 1 / torch.as_tensor(cluster_ranks, device=device) 142 | scores = scores / torch.sum(scores, dim=1)[0] 143 | 144 | return traj_clustered, scores 145 | -------------------------------------------------------------------------------- /models/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import abc 3 | from typing import Dict 4 | 5 | 6 | class PredictionEncoder(nn.Module): 7 | """ 8 | Base class for encoders for single agent prediction. 9 | """ 10 | def __init__(self): 11 | super().__init__() 12 | 13 | @abc.abstractmethod 14 | def forward(self, inputs: Dict) -> Dict: 15 | """ 16 | Abstract method for forward pass. Returns dictionary of encodings. Should typically include 17 | 1) target agent encoding, 2) context encoding: encodes map and surrounding agents. 18 | 19 | Context encodings will typically be a set of features (agents or parts of the map), 20 | with shape: [batch_size, set_size, feature_dim], 21 | sometimes along with masks for some set elements to account for varying set sizes 22 | 23 | :param inputs: Dictionary with 24 | 'target_agent_representation': target agent history 25 | 'surrounding_agent_representation': surrounding agent history 26 | 'map_representation': HD map representation 27 | :return encodings: Dictionary with input encodings 28 | """ 29 | raise NotImplementedError() 30 | -------------------------------------------------------------------------------- /models/encoders/pgp_encoder.py: -------------------------------------------------------------------------------- 1 | from models.encoders.encoder import PredictionEncoder 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence 5 | from typing import Dict 6 | import numpy as np 7 | 8 | # Initialize device: 9 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class PGPEncoder(PredictionEncoder): 13 | 14 | def __init__(self, args: Dict): 15 | """ 16 | GRU based encoder from PGP. 
Lane node features and agent histories are encoded using GRUs.
17 | Additionally, agent-node attention layers infuse each node encoding with nearby agent context.
18 | Finally, GAT layers aggregate local context at each node.
19 | 
20 | args to include:
21 | 
22 | target_agent_feat_size: int Size of target agent features
23 | target_agent_emb_size: int Size of target agent embedding
24 | target_agent_enc_size: int Size of hidden state of target agent GRU encoder
25 | 
26 | node_feat_size: int Size of lane node features
27 | node_emb_size: int Size of lane node embedding
28 | node_enc_size: int Size of hidden state of lane node GRU encoder
29 | 
30 | nbr_feat_size: int Size of neighboring agent features
31 | nbr_emb_size: int Size of neighboring agent embeddings
32 | nbr_enc_size: int Size of hidden state of neighboring agent GRU encoders
33 | 
34 | num_gat_layers: int Number of GAT layers to use.
35 | """
36 | 
37 | super().__init__()
38 | 
39 | # Target agent encoder
40 | self.target_agent_emb = nn.Linear(args['target_agent_feat_size'], args['target_agent_emb_size'])
41 | self.target_agent_enc = nn.GRU(args['target_agent_emb_size'], args['target_agent_enc_size'], batch_first=True)
42 | 
43 | # Node encoders
44 | self.node_emb = nn.Linear(args['node_feat_size'], args['node_emb_size'])
45 | self.node_encoder = nn.GRU(args['node_emb_size'], args['node_enc_size'], batch_first=True)
46 | self.lane_mask_prob = args['lane_mask_p']
47 | 
48 | # Surrounding agent encoder
49 | self.nbr_emb = nn.Linear(args['nbr_feat_size'] + 1, args['nbr_emb_size'])
50 | self.nbr_enc = nn.GRU(args['nbr_emb_size'], args['nbr_enc_size'], batch_first=True)
51 | self.agent_mask_prob_v = args['agent_mask_p_veh']
52 | self.agent_mask_prob_p = args['agent_mask_p_ped']
53 | self.mask_frames_prob = 1 - args['mask_frames_p']
54 | 
55 | # Agent-node attention
56 | self.query_emb = nn.Linear(args['node_enc_size'], args['node_enc_size'])
57 | self.key_emb = nn.Linear(args['nbr_enc_size'], args['node_enc_size'])
58 | self.val_emb = nn.Linear(args['nbr_enc_size'], args['node_enc_size'])
59 | self.a_n_att = nn.MultiheadAttention(args['node_enc_size'], num_heads=1)
60 | self.mix = nn.Linear(args['node_enc_size']*2, args['node_enc_size'])
61 | 
62 | # Non-linearities
63 | self.leaky_relu = nn.LeakyReLU()
64 | 
65 | # GAT layers
66 | self.gat = nn.ModuleList([GAT(args['node_enc_size'], args['node_enc_size'])
67 | for _ in range(args['num_gat_layers'])])
68 | 
69 | def forward(self, inputs: Dict) -> Dict:
70 | """
71 | Forward pass for PGP encoder
72 | :param inputs: Dictionary with
73 | target_agent_representation: torch.Tensor, shape [batch_size, t_h, target_agent_feat_size]
74 | map_representation: Dict with
75 | 'lane_node_feats': torch.Tensor, shape [batch_size, max_nodes, max_poses, node_feat_size]
76 | 'lane_node_masks': torch.Tensor, shape [batch_size, max_nodes, max_poses, node_feat_size]
77 | 
78 | (Optional)
79 | 's_next': Edge look-up table pointing to destination node from source node
80 | 'edge_type': Look-up table with edge type
81 | 
82 | surrounding_agent_representation: Dict with
83 | 'vehicles': torch.Tensor, shape [batch_size, max_vehicles, t_h, nbr_feat_size]
84 | 'vehicle_masks': torch.Tensor, shape [batch_size, max_vehicles, t_h, nbr_feat_size]
85 | 'pedestrians': torch.Tensor, shape [batch_size, max_peds, t_h, nbr_feat_size]
86 | 'pedestrian_masks': torch.Tensor, shape [batch_size, max_peds, t_h, nbr_feat_size]
87 | agent_node_masks: Dict with
88 | 'vehicles': torch.Tensor, shape [batch_size, max_nodes, max_vehicles]
89 | 
'pedestrians': torch.Tensor, shape [batch_size, max_nodes, max_pedestrians]
90 | 
91 | Optionally may also include the following if edges are defined for graph traversal
92 | 'init_node': Initial node in the lane graph based on track history.
93 | 'node_seq_gt': Ground truth node sequence for pre-training
94 | 
95 | :return: encodings: Dictionary with target agent encoding and lane-graph context encoding
96 | """
97 | 
98 | # Encode target agent
99 | target_agent_feats = inputs['target_agent_representation']
100 | target_agent_embedding = self.leaky_relu(self.target_agent_emb(target_agent_feats))
101 | _, target_agent_enc = self.target_agent_enc(target_agent_embedding)
102 | target_agent_enc = target_agent_enc.squeeze(0)  # [B, target_agent_enc_size]
103 | 
104 | # Encode lane nodes
105 | lane_node_feats = inputs['map_representation']['lane_node_feats']
106 | lane_node_masks = inputs['map_representation']['lane_node_masks']
107 | # Randomly mask out entire lane nodes with probability lane_mask_p (1 means mask out)
108 | mask_out = torch.bernoulli(torch.ones((lane_node_masks.shape[:2])) * self.lane_mask_prob).unsqueeze(-1).repeat(1,1,lane_node_masks.shape[-2]).unsqueeze(-1).repeat(1,1,1,lane_node_masks.shape[-1]).to(lane_node_masks.device)
109 | lane_node_masks = ~lane_node_masks.bool() & ~mask_out.bool()  # note: from here on, True marks a valid (kept) entry
110 | inputs['map_representation']['lane_node_masks'] = lane_node_masks  # for visualization purposes
111 | lane_node_embedding = self.leaky_relu(self.node_emb(lane_node_feats))
112 | lane_node_enc = self.variable_size_gru_encode(lane_node_embedding, lane_node_masks, self.node_encoder)  # [B, max_nodes, node_enc_size]
113 | 
114 | # Encode surrounding agents
115 | nbr_vehicle_feats = inputs['surrounding_agent_representation']['vehicles']
116 | nbr_vehicle_feats = torch.cat((nbr_vehicle_feats, torch.zeros_like(nbr_vehicle_feats[:, :, :, 0:1])), dim=-1)  # agent-type flag: 0 for vehicles
117 | nbr_vehicle_masks = inputs['surrounding_agent_representation']['vehicle_masks']
118 | 
119 | 
120 | ###### Mask out only some frames of vehicles within a 20 m radius of the target agent
121 | if self.mask_frames_prob < 1:
122 | target_adj_matrix = inputs['surrounding_agent_representation']['adj_matrix'][:,0,1:nbr_vehicle_feats.shape[1]+1]
123 | target_adj_matrix = target_adj_matrix.unsqueeze(-1).repeat(1,1,nbr_vehicle_feats.shape[2])
124 | # Randomly drop individual frames of nearby agents (rate controlled by mask_frames_p)
125 | target_adj_matrix *= torch.bernoulli(target_adj_matrix * self.mask_frames_prob)
126 | nbr_vehicle_masks = nbr_vehicle_masks + target_adj_matrix.unsqueeze(-1).repeat(1,1,1,nbr_vehicle_masks.shape[-1])
127 | inputs['agent_node_masks']['vehicles'] = inputs['agent_node_masks']['vehicles'].int() | nbr_vehicle_masks[:,:,:,0].any(-1).int().unsqueeze(1).repeat(1, lane_node_masks.shape[1], 1).int()
128 | ##############
129 | 
130 | 
131 | # Randomly mask out entire vehicles with probability agent_mask_p_veh (1 means mask out)
132 | mask_out = torch.bernoulli(torch.ones((nbr_vehicle_masks.shape[:2])) * self.agent_mask_prob_v).unsqueeze(-1).repeat(1,1,nbr_vehicle_masks.shape[-2]).unsqueeze(-1).repeat(1,1,1,nbr_vehicle_masks.shape[-1]).to(nbr_vehicle_masks.device)
133 | nbr_vehicle_masks = ~nbr_vehicle_masks.bool() & ~mask_out.bool()
134 | # Update agent_node_masks with mask_out
135 | inputs['agent_node_masks']['vehicles'] = inputs['agent_node_masks']['vehicles'].int() | mask_out[:,:,0,0].unsqueeze(1).repeat(1, lane_node_masks.shape[1], 1).int()
136 | nbr_vehicle_embedding = self.leaky_relu(self.nbr_emb(nbr_vehicle_feats))
137 | nbr_vehicle_enc = self.variable_size_gru_encode(nbr_vehicle_embedding, nbr_vehicle_masks, self.nbr_enc)  # [B, max_vehicles, nbr_enc_size]
138 | nbr_ped_feats = inputs['surrounding_agent_representation']['pedestrians']
139 | nbr_ped_feats =
torch.cat((nbr_ped_feats, torch.ones_like(nbr_ped_feats[:, :, :, 0:1])), dim=-1)  # agent-type flag: 1 for pedestrians
140 | nbr_ped_masks = inputs['surrounding_agent_representation']['pedestrian_masks']
141 | # Randomly mask out entire pedestrians with probability agent_mask_p_ped (1 means mask out)
142 | mask_out = torch.bernoulli(torch.ones((nbr_ped_masks.shape[:2])) * self.agent_mask_prob_p).unsqueeze(-1).repeat(1,1,nbr_ped_masks.shape[-2]).unsqueeze(-1).repeat(1,1,1,nbr_ped_masks.shape[-1]).to(nbr_ped_masks.device)
143 | nbr_ped_masks = ~nbr_ped_masks.bool() & ~mask_out.bool()
144 | inputs['agent_node_masks']['pedestrians'] = inputs['agent_node_masks']['pedestrians'].int() | mask_out[:,:,0,0].unsqueeze(1).repeat(1, lane_node_masks.shape[1], 1).int()
145 | nbr_ped_embedding = self.leaky_relu(self.nbr_emb(nbr_ped_feats))
146 | nbr_ped_enc = self.variable_size_gru_encode(nbr_ped_embedding, nbr_ped_masks, self.nbr_enc)
147 | 
148 | # Agent-node attention
149 | nbr_encodings = torch.cat((nbr_vehicle_enc, nbr_ped_enc), dim=1)
150 | queries = self.query_emb(lane_node_enc).permute(1, 0, 2)
151 | keys = self.key_emb(nbr_encodings).permute(1, 0, 2)
152 | vals = self.val_emb(nbr_encodings).permute(1, 0, 2)
153 | attn_masks = torch.cat((inputs['agent_node_masks']['vehicles'].float(),
154 | inputs['agent_node_masks']['pedestrians'].float()), dim=2)
155 | att_op, _ = self.a_n_att(queries, keys, vals, attn_mask=attn_masks)
156 | att_op = att_op.permute(1, 0, 2)
157 | 
158 | # Concatenate with original node encodings and fuse with a linear layer
159 | lane_node_enc = self.leaky_relu(self.mix(torch.cat((lane_node_enc, att_op), dim=2)))
160 | 
161 | # GAT layers (residual connections)
162 | adj_mat = self.build_adj_mat(inputs['map_representation']['s_next'], inputs['map_representation']['edge_type'])
163 | for gat_layer in self.gat:
164 | lane_node_enc += gat_layer(lane_node_enc, adj_mat)
165 | 
166 | # Lane node masks
167 | lane_node_masks = ~lane_node_masks[:, :, :, 0].bool()
168 | lane_node_masks = lane_node_masks.any(dim=2)
169 | lane_node_masks = ~lane_node_masks
170 | lane_node_masks = lane_node_masks.float()  # 0 if node exists
171 | 
172 | # Return encodings
173 | encodings = {'target_agent_encoding': target_agent_enc,
174 | 'context_encoding': {'combined': lane_node_enc,
175 | 'combined_masks': lane_node_masks,
176 | 'map': None,
177 | 'vehicles': None,
178 | 'pedestrians': None,
179 | 'map_masks': None,
180 | 'vehicle_masks': None,
181 | 'pedestrian_masks': None
182 | },
183 | }
184 | 
185 | # Pass on initial nodes and edge structure to aggregator if included in inputs
186 | if 'init_node' in inputs:
187 | encodings['init_node'] = inputs['init_node']
188 | encodings['node_seq_gt'] = inputs['node_seq_gt']
189 | encodings['s_next'] = inputs['map_representation']['s_next']
190 | encodings['edge_type'] = inputs['map_representation']['edge_type']
191 | 
192 | return encodings
193 | 
194 | @staticmethod
195 | def variable_size_gru_encode(feat_embedding: torch.Tensor, masks: torch.Tensor, gru: nn.GRU) -> torch.Tensor:
196 | """
197 | Returns GRU encoding for a batch of inputs where each sample in the batch is a set of a variable number
198 | of sequences, of variable lengths.
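Valid sequences from all set elements are gathered into one large batch, packed, encoded in a
single GRU call, and the final hidden states are scattered back to their original positions;
empty slots stay zero. A shape-check sketch (illustrative sizes; True marks valid steps here,
since the forward pass inverts the dataset's padding masks before calling this):

    gru = nn.GRU(16, 32, batch_first=True).to(device)
    feats = torch.rand(4, 10, 20, 16, device=device)        # [batch, max_num, seq_len, emb_size]
    masks = torch.ones(4, 10, 20, 6, device=device).bool()  # all steps valid
    out = PGPEncoder.variable_size_gru_encode(feats, masks, gru)  # -> [4, 10, 32]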
199 | """
200 | 
201 | # Form a large batch of all sequences in the batch
202 | # masks_for_batching = ~masks[:, :, :, 0].bool()  # [B, max_num, seq_len, feat_size]
203 | masks_for_batching = masks[:, :, :, 0].any(dim=-1).unsqueeze(2).unsqueeze(3)  # [B, max_num, 1, 1]
204 | feat_embedding_batched = torch.masked_select(feat_embedding, masks_for_batching)
205 | feat_embedding_batched = feat_embedding_batched.view(-1, feat_embedding.shape[2], feat_embedding.shape[3])  # [num_seqs, seq_len, emb_size]
206 | 
207 | # Pack padded sequences
208 | seq_lens = torch.sum(masks[:, :, :, 0].long(), dim=-1)  # [B, max_num]
209 | seq_lens_batched = seq_lens[seq_lens != 0].cpu()  # keep only the non-empty sequences
210 | if len(seq_lens_batched) != 0:
211 | feat_embedding_packed = pack_padded_sequence(feat_embedding_batched, seq_lens_batched,
212 | batch_first=True, enforce_sorted=False)
213 | 
214 | # Encode
215 | _, encoding_batched = gru(feat_embedding_packed)
216 | encoding_batched = encoding_batched.squeeze(0)
217 | 
218 | # Scatter back to appropriate batch index
219 | masks_for_scattering = masks_for_batching.squeeze(3).repeat(1, 1, encoding_batched.shape[-1])
220 | encoding = torch.zeros(masks_for_scattering.shape, device=device)
221 | encoding = encoding.masked_scatter(masks_for_scattering, encoding_batched)
222 | 
223 | else:
224 | batch_size = feat_embedding.shape[0]
225 | max_num = feat_embedding.shape[1]
226 | hidden_state_size = gru.hidden_size
227 | encoding = torch.zeros((batch_size, max_num, hidden_state_size), device=device)
228 | 
229 | return encoding
230 | 
231 | @staticmethod
232 | def build_adj_mat(s_next, edge_type):
233 | """
234 | Builds adjacency matrix for GAT layers.
235 | """
236 | batch_size = s_next.shape[0]
237 | max_nodes = s_next.shape[1]
238 | max_edges = s_next.shape[2]
239 | adj_mat = torch.diag(torch.ones(max_nodes, device=device)).unsqueeze(0).repeat(batch_size, 1, 1).bool()
240 | 
241 | dummy_vals = torch.arange(max_nodes, device=device).unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, max_edges)
242 | dummy_vals = dummy_vals.float()
243 | s_next[edge_type == 0] = dummy_vals[edge_type == 0]  # route invalid edges back to the source node (self-loops)
244 | batch_indices = torch.arange(batch_size).unsqueeze(1).unsqueeze(2).repeat(1, max_nodes, max_edges)
245 | src_indices = torch.arange(max_nodes).unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, max_edges)
246 | adj_mat[batch_indices[:, :, :-1], src_indices[:, :, :-1], s_next[:, :, :-1].long()] = True
247 | adj_mat = adj_mat | torch.transpose(adj_mat, 1, 2)
248 | 
249 | return adj_mat
250 | 
251 | 
252 | class GAT(nn.Module):
253 | """
254 | GAT layer for aggregating local context at each lane node. Uses scaled dot-product attention via PyTorch's
255 | multihead attention module.
256 | """
257 | def __init__(self, in_channels, out_channels):
258 | """
259 | Initialize GAT layer.
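A quick shape check (illustrative sketch; a self-loop-only adjacency, tensors on the module-level device):

    gat = GAT(32, 32).to(device)
    x = torch.rand(2, 5, 32, device=device)  # [batch, max_nodes, enc_size]
    adj = torch.eye(5, device=device).bool().unsqueeze(0).repeat(2, 1, 1)
    out = gat(x, adj)  # -> [2, 5, 32]
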
260 | :param in_channels: size of node encodings 261 | :param out_channels: size of aggregated node encodings 262 | """ 263 | super().__init__() 264 | self.query_emb = nn.Linear(in_channels, out_channels) 265 | self.key_emb = nn.Linear(in_channels, out_channels) 266 | self.val_emb = nn.Linear(in_channels, out_channels) 267 | self.att = nn.MultiheadAttention(out_channels, 1) 268 | 269 | def forward(self, node_encodings, adj_mat): 270 | """ 271 | Forward pass for GAT layer 272 | :param node_encodings: Tensor of node encodings, shape [batch_size, max_nodes, node_enc_size] 273 | :param adj_mat: Bool tensor, adjacency matrix for edges, shape [batch_size, max_nodes, max_nodes] 274 | :return: 275 | """ 276 | queries = self.query_emb(node_encodings.permute(1, 0, 2)) 277 | keys = self.key_emb(node_encodings.permute(1, 0, 2)) 278 | vals = self.val_emb(node_encodings.permute(1, 0, 2)) 279 | att_op, _ = self.att(queries, keys, vals, attn_mask=~adj_mat) 280 | 281 | return att_op.permute(1, 0, 2) 282 | -------------------------------------------------------------------------------- /models/encoders/polyline_subgraph.py: -------------------------------------------------------------------------------- 1 | from models.encoders.encoder import PredictionEncoder 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from typing import Dict, Tuple 6 | 7 | 8 | # TODO (WiP): Test with different datasets, visualize results. 9 | class PolylineSubgraphs(PredictionEncoder): 10 | 11 | def __init__(self, args: Dict): 12 | """ 13 | Polyline subgraph encoder from VectorNet (Gao et al., CVPR 2020). 14 | Has N encoder layers. Each layer encodes every feature in a polyline using an MLP with shared 15 | weights, followed by a permutation invariant aggregation operator (element-wise max used in the paper). 16 | Aggregated vector is concatenated with each independent feature encoding. 17 | Layer is repeated N times. Final encodings are passed through the permutation invariant 18 | aggregation operator to give polyline encodings. 19 | 20 | args to include 21 | 'num_layers': int Number of repeated encoder layers 22 | 'mlp_size': int Width of MLP hidden layer 23 | 'lane_feat_size': int Lane feature dimension 24 | 'agent_feat_size': int Agent feature dimension 25 | 26 | """ 27 | super().__init__() 28 | self.num_layers = args['num_layers'] 29 | self.mlp_size = args['mlp_size'] 30 | self.lane_feat_size = args['lane_feat_size'] 31 | self.agent_feat_size = args['agent_feat_size'] 32 | 33 | # Encoder layers 34 | 35 | """ 36 | Note: I'm not completely sure if VectorNet uses different MLPs for agents, map polylines and map polygons. 37 | The paper doesn't seem to mention this clearly. However, agents and map polylines will typically have different 38 | attribute features. At least the first linear layer has to be different. 39 | Shouldn't affect the global attention aggregator. All final feats will have the same dimensions. 
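For intuition: with mlp_size = 64, one layer maps [B, P, L, F] to [B, P, L, 64] with the shared MLP,
max-pools over L into a [B, P, 64] polyline summary, and concatenates that summary back onto every
element, yielding [B, P, L, 128] as input to the next layer.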
40 | """ 41 | 42 | lane_encoders = [nn.Linear(self.lane_feat_size + 2, self.mlp_size)] 43 | for n in range(1, self.num_layers): 44 | lane_encoders.append(nn.Linear(self.mlp_size * 2, self.mlp_size)) 45 | self.lane_encoders = nn.ModuleList(lane_encoders) 46 | 47 | target_agent_encoders = [nn.Linear(self.agent_feat_size + 2, self.mlp_size)] 48 | for n in range(1, self.num_layers): 49 | target_agent_encoders.append(nn.Linear(self.mlp_size * 2, self.mlp_size)) 50 | self.target_agent_encoders = nn.ModuleList(target_agent_encoders) 51 | 52 | surrounding_vehicle_encoders = [nn.Linear(self.agent_feat_size + 2, self.mlp_size)] 53 | for n in range(1, self.num_layers): 54 | surrounding_vehicle_encoders.append(nn.Linear(self.mlp_size * 2, self.mlp_size)) 55 | self.surrounding_vehicle_encoders = nn.ModuleList(surrounding_vehicle_encoders) 56 | 57 | surrounding_ped_encoders = [nn.Linear(self.agent_feat_size + 2, self.mlp_size)] 58 | for n in range(1, self.num_layers): 59 | surrounding_ped_encoders.append(nn.Linear(self.mlp_size * 2, self.mlp_size)) 60 | self.surrounding_ped_encoders = nn.ModuleList(surrounding_ped_encoders) 61 | 62 | # Layer norm and relu 63 | self.layer_norm = nn.LayerNorm(self.mlp_size) 64 | self.relu = nn.ReLU() 65 | 66 | def forward(self, inputs: Dict) -> Dict: 67 | 68 | target_agent_feats = inputs['target_agent_representation'] 69 | lane_feats = inputs['map_representation']['lane_node_feats'] 70 | lane_masks = inputs['map_representation']['lane_node_masks'] 71 | vehicle_feats = inputs['surrounding_agent_representation']['vehicles'] 72 | vehicle_masks = inputs['surrounding_agent_representation']['vehicle_masks'] 73 | ped_feats = inputs['surrounding_agent_representation']['pedestrians'] 74 | ped_masks = inputs['surrounding_agent_representation']['pedestrian_masks'] 75 | 76 | # Encode target agent 77 | target_agent_feats = self.convert2vectornet_feat_format(target_agent_feats.unsqueeze(1)) 78 | target_agent_enc, _ = self.encode(self.target_agent_encoders, target_agent_feats, 79 | torch.zeros_like(target_agent_feats)) 80 | target_agent_enc = target_agent_enc.squeeze(1) 81 | 82 | # Encode lanes 83 | lane_feats = self.convert2vectornet_feat_format(lane_feats) 84 | lane_masks = lane_masks[:, :, :-1, :] 85 | lane_enc, lane_masks = self.encode(self.lane_encoders, lane_feats, lane_masks) 86 | 87 | # Encode surrounding agents 88 | vehicle_feats = self.convert2vectornet_feat_format(vehicle_feats) 89 | vehicle_masks = vehicle_masks[:, :, :-1, :] 90 | vehicle_enc, vehicle_masks = self.encode(self.surrounding_vehicle_encoders, vehicle_feats, vehicle_masks) 91 | ped_feats = self.convert2vectornet_feat_format(ped_feats) 92 | ped_masks = ped_masks[:, :, :-1, :] 93 | ped_enc, ped_masks = self.encode(self.surrounding_ped_encoders, ped_feats, ped_masks) 94 | 95 | # Return encodings 96 | encodings = {'target_agent_encoding': target_agent_enc, 97 | 'context_encoding': {'combined': None, 98 | 'combined_masks': None, 99 | 'map': lane_enc, 100 | 'vehicles': vehicle_enc, 101 | 'pedestrians': ped_enc, 102 | 'map_masks': lane_masks, 103 | 'vehicle_masks': vehicle_masks, 104 | 'pedestrian_masks': ped_masks 105 | }, 106 | } 107 | 108 | return encodings 109 | 110 | def encode(self, encoder_layers: nn.ModuleList, input_feats: torch.Tensor, 111 | masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 112 | """ 113 | Applies encoding layers to a given set of input feats 114 | """ 115 | masks = masks[..., 0] 116 | masks[masks == 1] = -math.inf 117 | 118 | encodings = input_feats 119 | for n in 
range(len(encoder_layers)):
120 | encodings = self.relu(self.layer_norm(encoder_layers[n](encodings)))
121 | encodings = encodings + masks.unsqueeze(-1)
122 | agg_enc, _ = torch.max(encodings, dim=2)
123 | encodings = torch.cat((encodings, agg_enc.unsqueeze(2).repeat(1, 1, encodings.shape[2], 1)), dim=3)
124 | encodings[encodings == -math.inf] = 0
125 | 
126 | agg_encoding, _ = torch.max(encodings, dim=2)
127 | masks[masks == -math.inf] = 1
128 | 
129 | return agg_encoding, masks[..., 0]
130 | 
131 | @staticmethod
132 | def convert2vectornet_feat_format(feats: torch.Tensor) -> torch.Tensor:
133 | """
134 | Helper function to convert a tensor of node features to the VectorNet format.
135 | By default the datasets return node features of the format [x, y, attribute feats...].
136 | VectorNet uses the following format: [x, y, x_next, y_next, attribute_feats]
137 | :param feats: Tensor of feats, shape [batch_size, max_polylines, max_len, feat_dim]
138 | :return: Tensor of updated feats, shape [batch_size, max_polylines, max_len, feat_dim + 2]
139 | """
140 | xy = feats[:, :, :-1, :2]
141 | xy_next = feats[:, :, 1:, :2]
142 | attr = feats[:, :, :-1, 2:]
143 | feats = torch.cat((xy, xy_next, attr), dim=3)
144 | return feats
145 | 
--------------------------------------------------------------------------------
/models/encoders/raster_encoder.py:
--------------------------------------------------------------------------------
1 | from models.encoders.encoder import PredictionEncoder
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import resnet18, resnet34, resnet50
5 | from positional_encodings.torch_encodings import PositionalEncodingPermute2D
6 | from typing import Dict
7 | 
8 | 
9 | class RasterEncoder(PredictionEncoder):
10 | 
11 | def __init__(self, args: Dict):
12 | """
13 | CNN encoder for raster representation of HD maps and surrounding agent trajectories.
14 | 
15 | args to include
16 | 'backbone': str CNN backbone to use (resnet18, resnet34 or resnet50)
17 | 'input_channels': int Size of scene features at each grid cell
18 | 'use_positional_encoding': bool Whether or not to add positional encodings to final set of features
19 | 'target_agent_feat_size': int Size of target agent state
20 | """
21 | 
22 | super().__init__()
23 | 
24 | # Anything more seems like overkill
25 | resnet_backbones = {'resnet18': resnet18,
26 | 'resnet34': resnet34,
27 | 'resnet50': resnet50}
28 | 
29 | # Initialize backbone:
30 | resnet_model = resnet_backbones[args['backbone']](pretrained=False)
31 | conv1_new = nn.Conv2d(args['input_channels'], 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
32 | modules = list(resnet_model.children())[:-2]
33 | modules[0] = conv1_new
34 | self.backbone = nn.Sequential(*modules)
35 | 
36 | # Positional encodings:
37 | num_channels = 2048 if args['backbone'] == 'resnet50' else 512  # resnet50 has a 2048-dim final feature map
38 | self.use_pos_enc = args['use_positional_encoding']
39 | if self.use_pos_enc:
40 | self.pos_enc = PositionalEncodingPermute2D(num_channels)
41 | 
42 | # Linear layer to embed target agent representation.
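# (The CNN backbone only sees the rasterized map and agents; the target agent's
# kinematic state bypasses the CNN and is embedded separately for the aggregator.)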
43 | self.target_agent_encoder = nn.Linear(args['target_agent_feat_size'], args['target_agent_enc_size'])
44 | self.relu = nn.ReLU()
45 | 
46 | def forward(self, inputs: Dict) -> Dict:
47 | 
48 | """
49 | Forward pass for raster encoder
50 | :param inputs: Dictionary with
51 | target_agent_representation: torch.Tensor with target agent state, shape [batch_size, target_agent_feat_size]
52 | surrounding_agent_representation: Rasterized BEV representation, shape [batch_size, 3, H, W]
53 | map_representation: Rasterized BEV representation, shape [batch_size, 3, H, W]
54 | :return encodings: Dictionary with
55 | 'target_agent_encoding': torch.Tensor of shape [batch_size, target_agent_enc_size],
56 | 'context_encoding': torch.Tensor of shape [batch_size, N, backbone_feat_dim]
57 | """
58 | 
59 | # Unpack inputs:
60 | target_agent_representation = inputs['target_agent_representation']
61 | surrounding_agent_representation = inputs['surrounding_agent_representation']
62 | map_representation = inputs['map_representation']
63 | 
64 | # Apply Conv layers
65 | rasterized_input = torch.cat((map_representation, surrounding_agent_representation), dim=1)
66 | context_encoding = self.backbone(rasterized_input)
67 | 
68 | # Add positional encoding
69 | if self.use_pos_enc:
70 | context_encoding = context_encoding + self.pos_enc(context_encoding)
71 | 
72 | # Reshape to form a set of features
73 | context_encoding = context_encoding.view(context_encoding.shape[0], context_encoding.shape[1], -1)
74 | context_encoding = context_encoding.permute(0, 2, 1)
75 | 
76 | # Target agent encoding
77 | target_agent_enc = self.relu(self.target_agent_encoder(target_agent_representation))
78 | 
79 | # Return encodings
80 | encodings = {'target_agent_encoding': target_agent_enc,
81 | 'context_encoding': {'combined': context_encoding,
82 | 'combined_masks': torch.zeros_like(context_encoding[..., 0]),
83 | 'map': None,
84 | 'vehicles': None,
85 | 'pedestrians': None,
86 | 'map_masks': None,
87 | 'vehicle_masks': None,
88 | 'pedestrian_masks': None
89 | },
90 | }
91 | return encodings
92 | 
--------------------------------------------------------------------------------
/models/heterograph_models.py:
--------------------------------------------------------------------------------
1 | from models.layers import HeteroRGCNLayer, HGTLayer, ieHGCNConv
2 | from abc import ABCMeta
3 | import dgl
4 | from dgl.ops import edge_softmax
5 | from dgl.nn import TypedLinear
6 | 
7 | import dgl.function as Fn
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch
11 | 
12 | 
13 | class BaseModel(nn.Module, metaclass=ABCMeta):
14 | @classmethod
15 | def build_model_from_args(cls, args, hg):
16 | r"""
17 | Build the model instance from args and hg.
18 | Every subclass inheriting it should override this method.
19 | """
20 | raise NotImplementedError("Models must implement the build_model_from_args method")
21 | 
22 | def __init__(self):
23 | super(BaseModel, self).__init__()
24 | 
25 | def forward(self, *args):
26 | r"""
27 | The model plays the role of an encoder, so the forward pass encodes the original features into new features.
28 | Parameters
29 | -----------
30 | hg : dgl.DGLHeteroGraph
31 | the heterogeneous graph
32 | h_dict : dict[str, th.Tensor]
33 | the dict of heterogeneous features
34 | Returns
35 | -------
36 | out_dict : dict[str, th.Tensor]
37 | A dict of encoded features. In general, it should output embeddings for all nodes.
38 | It is also acceptable to output only the embeddings of the target nodes that participate in the loss calculation.
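A subclass override typically looks like (illustrative sketch; `self.layers` is hypothetical):

    def forward(self, hg, h_dict):
        for layer in self.layers:
            h_dict = layer(hg, h_dict)
        return h_dict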
39 | """
40 | raise NotImplementedError
41 | 
42 | def extra_loss(self):
43 | r"""
44 | Some models may want to use an L2 norm that is not applied to all parameters.
45 | Returns
46 | -------
47 | th.Tensor
48 | """
49 | raise NotImplementedError
50 | 
51 | def h2dict(self, h, hdict):
52 | pre = 0
53 | out_dict = {}
54 | for i, value in hdict.items():
55 | out_dict[i] = h[pre:value.shape[0]+pre]
56 | pre += value.shape[0]
57 | return out_dict
58 | 
59 | def get_emb(self):
60 | r"""
61 | Return the embedding of a model for further analysis.
62 | Returns
63 | -------
64 | numpy.array
65 | """
66 | raise NotImplementedError
67 | 
68 | class HGT(nn.Module):
69 | def __init__(self, node_dict, edge_dict, n_inp, n_hid, n_out, n_layers, n_heads, use_norm = True):
70 | super(HGT, self).__init__()
71 | self.node_dict = node_dict
72 | self.edge_dict = edge_dict
73 | self.gcs = nn.ModuleList()
74 | self.n_inp = n_inp
75 | self.n_hid = n_hid
76 | self.n_out = n_out
77 | self.n_layers = n_layers
78 | self.adapt_ws = nn.ModuleList()
79 | for t in range(len(node_dict)):
80 | self.adapt_ws.append(nn.Linear(n_inp, n_hid))
81 | for i in range(n_layers):
82 | self.gcs.append(HGTLayer(n_hid, n_hid, node_dict, edge_dict, n_heads[i], use_norm = use_norm))
83 | self.out = nn.Linear(n_hid, n_out)
84 | 
85 | def forward(self, G, out_key):
86 | with G.local_scope():
87 | h = {}
88 | for ntype in G.ntypes:
89 | n_id = self.node_dict[ntype]
90 | h[ntype] = F.gelu(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
91 | for i in range(self.n_layers):
92 | h = self.gcs[i](G, h, out_key)
93 | for key in out_key:
94 | h[key] = self.out(h[key])
95 | return [h[key] for key in out_key]
96 | 
97 | class SimpleHGN(BaseModel):
98 | r"""
99 | This is the SimpleHGN model from `Are we really making much progress? Revisiting, benchmarking, and
100 | refining heterogeneous graph neural networks
101 | `__
102 | The model extends the original graph attention mechanism in GAT by including edge type information in the attention calculation.
103 | Calculating the coefficient:
104 | 
105 | .. math::
106 | \alpha_{ij} = \frac{exp(LeakyReLU(a^T[Wh_i||Wh_j||W_r r_{\psi(\langle i,j\rangle)}]))}{\Sigma_{k\in\mathcal{E}}{exp(LeakyReLU(a^T[Wh_i||Wh_k||W_r r_{\psi(\langle i,k\rangle)}]))}} \quad (1)
107 | 
108 | Residual connection including Node residual:
109 | 
110 | .. math::
111 | h_i^{(l)} = \sigma(\Sigma_{j\in \mathcal{N}_i} {\alpha_{ij}^{(l)}W^{(l)}h_j^{(l-1)}} + h_i^{(l-1)}) \quad (2)
112 | 
113 | and Edge residual:
114 | 
115 | .. math::
116 | \alpha_{ij}^{(l)} = (1-\beta)\alpha_{ij}^{(l)}+\beta\alpha_{ij}^{(l-1)} \quad (3)
117 | 
118 | Multi-heads:
119 | 
120 | .. math::
121 | h^{(l+1)}_j = \parallel^M_{m = 1}h^{(l + 1, m)}_j \quad (4)
122 | 
123 | Residual:
124 | 
125 | ..
math:: 126 | h^{(l+1)}_j = h^{(l)}_j + \parallel^M_{m = 1}h^{(l + 1, m)}_j \quad (5) 127 | 128 | Parameters 129 | ---------- 130 | edge_dim: int 131 | the edge dimension 132 | num_etypes: int 133 | the number of the edge type 134 | in_dim: int 135 | the input dimension 136 | hidden_dim: int 137 | the output dimension 138 | num_classes: int 139 | the number of the output classes 140 | num_layers: int 141 | the number of layers we used in the computing 142 | heads: list 143 | the list of the number of heads in each layer 144 | feat_drop: float 145 | the feature drop rate 146 | negative_slope: float 147 | the negative slope used in the LeakyReLU 148 | residual: boolean 149 | if we need the residual operation 150 | beta: float 151 | the hyperparameter used in edge residual 152 | """ 153 | @classmethod 154 | def build_model_from_args(cls, args, hg): 155 | heads = [args.num_heads] * args.num_layers + [1] 156 | return cls(args.edge_dim, 157 | len(hg.etypes), 158 | [args.hidden_dim], 159 | args.h_dim, 160 | args.out_dim, 161 | args.num_layers, 162 | heads, 163 | args.feats_drop_rate, 164 | args.slope, 165 | True, 166 | args.beta 167 | ) 168 | 169 | def __init__(self, edge_dim, num_etypes, in_dim, hidden_dim, num_classes, 170 | num_layers, heads, feat_drop=0.0, negative_slope=0.2, residual=True, activation=F.elu, beta=0.0): 171 | super(SimpleHGN, self).__init__() 172 | self.num_layers = num_layers 173 | self.hgn_layers = nn.ModuleList() 174 | self.activation = F.elu 175 | 176 | # input projection (no residual) 177 | self.hgn_layers.append( 178 | SimpleHGNConv( 179 | edge_dim, 180 | in_dim, 181 | hidden_dim, 182 | heads[0], 183 | num_etypes, 184 | feat_drop, 185 | negative_slope, 186 | False, 187 | self.activation, 188 | beta=beta, 189 | ) 190 | ) 191 | # hidden layers 192 | for l in range(1, num_layers - 1): # noqa E741 193 | # due to multi-head, the in_dim = hidden_dim * num_heads 194 | self.hgn_layers.append( 195 | SimpleHGNConv( 196 | edge_dim, 197 | hidden_dim * heads[l - 1], 198 | hidden_dim, 199 | heads[l], 200 | num_etypes, 201 | feat_drop, 202 | negative_slope, 203 | residual, 204 | self.activation, 205 | beta=beta, 206 | ) 207 | ) 208 | # output projection 209 | self.hgn_layers.append( 210 | SimpleHGNConv( 211 | edge_dim, 212 | hidden_dim * heads[-2], 213 | num_classes, 214 | heads[-1], 215 | num_etypes, 216 | feat_drop, 217 | negative_slope, 218 | residual, 219 | None, 220 | beta=beta, 221 | ) 222 | ) 223 | 224 | def forward(self, hg, h_dict): 225 | """ 226 | The forward part of the SimpleHGN. 227 | 228 | Parameters 229 | ---------- 230 | hg : object 231 | the dgl heterogeneous graph 232 | h_dict: dict 233 | the feature dict of different node types 234 | 235 | Returns 236 | ------- 237 | dict 238 | The embeddings after the output projection. 239 | """ 240 | with hg.local_scope(): 241 | hg.ndata['h'] = h_dict 242 | g = dgl.to_homogeneous(hg, ndata = 'h') 243 | h = g.ndata['h'] 244 | for l in range(self.num_layers): 245 | h = self.hgn_layers[l](g, h, g.ndata['_TYPE'], g.edata['_TYPE'], True) 246 | h = h.flatten(1) 247 | 248 | h_dict = {} 249 | for index, ntype in enumerate(hg.ntypes): 250 | h_dict[ntype] = h[torch.where(g.ndata['_TYPE'] == index)] 251 | # g.ndata['h'] = h 252 | # hg = dgl.to_heterogeneous(g, hg.ntypes, hg.etypes) 253 | # h_dict = hg.ndata['h'] 254 | 255 | return h_dict['l'], h_dict['v'] 256 | 257 | class SimpleHGNConv(nn.Module): 258 | r""" 259 | The SimpleHGN convolution layer. 
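It operates on the homogeneous view of a heterogeneous graph: node and edge type ids are passed in
explicitly, and type information enters the attention logits through learned edge-type embeddings.
A minimal usage sketch (illustrative sizes; assumes g is a homogeneous DGL graph carrying '_TYPE'
node/edge data, as produced by dgl.to_homogeneous):

    conv = SimpleHGNConv(edge_dim=8, in_dim=16, out_dim=16, num_heads=2, num_etypes=3)
    h_out = conv(g, h, g.ndata['_TYPE'], g.edata['_TYPE'])  # h: [num_nodes, 16] -> h_out: [num_nodes, 32]
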
260 | Parameters
261 | ----------
262 | edge_dim: int
263 | the edge dimension
266 | in_dim: int
267 | the input dimension
268 | out_dim: int
269 | the output dimension
270 | num_heads: int
271 | the number of heads
272 | num_etypes: int
273 | the number of edge types
274 | feat_drop: float
275 | the feature drop rate
276 | negative_slope: float
277 | the negative slope used in the LeakyReLU
278 | residual: boolean
279 | whether to apply the residual connection
280 | activation: callable
281 | the activation function
282 | beta: float
283 | the hyperparameter used in the edge residual
284 | """
285 | def __init__(self, edge_dim, in_dim, out_dim, num_heads, num_etypes, feat_drop=0.0,
286 | negative_slope=0.2, residual=True, activation=F.elu, beta=0.0):
287 | super(SimpleHGNConv, self).__init__()
288 | self.edge_dim = edge_dim
289 | self.in_dim = in_dim
290 | self.out_dim = out_dim
291 | self.num_heads = num_heads
292 | self.num_etypes = num_etypes
293 | 
294 | self.edge_emb = nn.Parameter(torch.empty(size=(num_etypes, edge_dim)))
295 | 
296 | self.W = nn.Parameter(torch.FloatTensor(
297 | in_dim, out_dim * num_heads))
298 | self.W_r = TypedLinear(edge_dim, edge_dim * num_heads, num_etypes)
299 | 
300 | self.a_l = nn.Parameter(torch.empty(size=(1, num_heads, out_dim)))
301 | self.a_r = nn.Parameter(torch.empty(size=(1, num_heads, out_dim)))
302 | self.a_e = nn.Parameter(torch.empty(size=(1, num_heads, edge_dim)))
303 | 
304 | nn.init.xavier_uniform_(self.edge_emb, gain=1.414)
305 | nn.init.xavier_uniform_(self.W, gain=1.414)
306 | nn.init.xavier_uniform_(self.a_l.data, gain=1.414)
307 | nn.init.xavier_uniform_(self.a_r.data, gain=1.414)
308 | nn.init.xavier_uniform_(self.a_e.data, gain=1.414)
309 | 
310 | self.feat_drop = nn.Dropout(feat_drop)
311 | self.leakyrelu = nn.LeakyReLU(negative_slope)
312 | self.activation = activation
313 | 
314 | if residual:
315 | self.residual = nn.Linear(in_dim, out_dim * num_heads)
316 | else:
317 | self.register_buffer("residual", None)
318 | 
319 | self.beta = beta
320 | 
321 | def forward(self, g, h, ntype, etype, presorted = False):
322 | """
323 | The forward part of the SimpleHGNConv.
324 | Parameters
325 | ----------
326 | g : object
327 | the dgl homogeneous graph
328 | h: tensor
329 | the original features of the graph
330 | ntype: tensor
331 | the node type of the graph
332 | etype: tensor
333 | the edge type of the graph
334 | presorted: boolean
335 | if the ntype and etype are preordered, default: ``False``
336 | 
337 | Returns
338 | -------
339 | tensor
340 | The embeddings after aggregation.
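Note: the computed attention is cached in g.edata['alpha']; when the next layer runs on the same
graph, that cached value supplies the previous layer's alpha for the edge residual of Eq. (3).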
341 | """ 342 | emb = self.feat_drop(h) 343 | emb = torch.matmul(emb, self.W).view(-1, self.num_heads, self.out_dim) 344 | emb[torch.isnan(emb)] = 0.0 345 | 346 | edge_emb = self.W_r(self.edge_emb[etype], etype, presorted).view(-1, 347 | self.num_heads, self.edge_dim) 348 | 349 | row = g.edges()[0] 350 | col = g.edges()[1] 351 | 352 | h_l = (self.a_l * emb).sum(dim=-1)[row.to(torch.long)] 353 | h_r = (self.a_r * emb).sum(dim=-1)[col.to(torch.long)] 354 | h_e = (self.a_e * edge_emb).sum(dim=-1) 355 | 356 | edge_attention = self.leakyrelu(h_l + h_r + h_e) 357 | edge_attention = edge_softmax(g, edge_attention) 358 | 359 | if 'alpha' in g.edata.keys(): 360 | res_attn = g.edata['alpha'] 361 | edge_attention = edge_attention * \ 362 | (1 - self.beta) + res_attn * self.beta 363 | if self.num_heads == 1: 364 | edge_attention = edge_attention[:, 0] 365 | edge_attention = edge_attention.unsqueeze(1) 366 | 367 | with g.local_scope(): 368 | emb = emb.permute(0, 2, 1).contiguous() 369 | g.edata['alpha'] = edge_attention 370 | g.srcdata['emb'] = emb 371 | g.update_all(Fn.u_mul_e('emb', 'alpha', 'm'), 372 | Fn.sum('m', 'emb')) 373 | # g.apply_edges(Fn.u_mul_e('emb', 'alpha', 'm')) 374 | h_output = g.ndata['emb'].view(-1, self.out_dim * self.num_heads) 375 | # h_prime = [] 376 | # for i in range(self.num_heads): 377 | # g.edata['alpha'] = edge_attention[:, i] 378 | # g.srcdata.update({'emb': emb[i]}) 379 | # g.update_all(Fn.u_mul_e('emb', 'alpha', 'm'), 380 | # Fn.sum('m', 'emb')) 381 | # h_prime.append(g.ndata['emb']) 382 | # h_output = torch.cat(h_prime, dim=1) 383 | 384 | g.edata['alpha'] = edge_attention 385 | if self.residual: 386 | res = self.residual(h) 387 | h_output += res 388 | if self.activation is not None: 389 | h_output = self.activation(h_output) 390 | 391 | return h_output 392 | 393 | 394 | 395 | class ieHGCN(BaseModel): 396 | r""" 397 | ie-HGCN from paper `Interpretable and Efficient Heterogeneous Graph Convolutional Network 398 | `__. 399 | `Source Code Link `_ 400 | 401 | The core part of ie-HGCN, the calculating flow of projection, object-level aggregation and type-level aggregation in 402 | a specific type block. 403 | Projection 404 | 405 | .. math:: 406 | Y^{Self-\Omega }=H^{\Omega} \cdot W^{Self-\Omega} \quad (1)-1 407 | Y^{\Gamma - \Omega}=H^{\Gamma} \cdot W^{\Gamma - \Omega} , \Gamma \in N_{\Omega} \quad (1)-2 408 | Object-level Aggregation 409 | 410 | .. math:: 411 | Z^{ Self - \Omega } = Y^{ Self - \Omega}=H^{\Omega} \cdot W^{Self - \Omega} \quad (2)-1 412 | Z^{\Gamma - \Omega}=\hat{A}^{\Omega-\Gamma} \cdot Y^{\Gamma - \Omega} = \hat{A}^{\Omega-\Gamma} \cdot H^{\Gamma} \cdot W^{\Gamma - \Omega} \quad (2)-2 413 | Type-level Aggregation 414 | 415 | .. math:: 416 | Q^{\Omega}=Z^{Self-\Omega} \cdot W_q^{\Omega} \quad (3)-1 417 | K^{Self-\Omega}=Z^{Self -\Omega} \cdot W_{k}^{\Omega} \quad (3)-2 418 | K^{\Gamma - \Omega}=Z^{\Gamma - \Omega} \cdot W_{k}^{\Omega}, \quad \Gamma \in N_{\Omega} \quad (3)-3 419 | .. math:: 420 | e^{Self-\Omega}={ELU} ([K^{ Self-\Omega} \| Q^{\Omega}] \cdot w_{a}^{\Omega}) \quad (4)-1 421 | e^{\Gamma - \Omega}={ELU} ([K^{\Gamma - \Omega} \| Q^{\Omega}] \cdot w_{a}^{\Omega}), \Gamma \in N_{\Omega} \quad (4)-2 422 | .. math:: 423 | [a^{Self-\Omega}\|a^{1 - \Omega}\| \ldots . a^{\Gamma - \Omega}\|\ldots\| a^{|N_{\Omega}| - \Omega}] \\ 424 | = {softmax}([e^{Self - \Omega}\|e^{1 - \Omega}\| \ldots\|e^{\Gamma - \Omega}\| \ldots \| e^{|N_{\Omega}| - \Omega}]) \quad (5) 425 | .. 
math:: 426 | H_{i,:}^{\Omega \prime}=\sigma(a_{i}^{Self-\Omega} \cdot Z_{i,:}^{Self-\Omega}+\sum_{\Gamma \in N_{\Omega}} a_{i}^{\Gamma - \Omega} \cdot Z_{i,:}^{\Gamma - \Omega}) \quad (6) 427 | 428 | Parameters 429 | ---------- 430 | num_layers: int 431 | the number of layers 432 | in_dim: int 433 | the input dimension 434 | hidden_dim: int 435 | the hidden dimension 436 | out_dim: int 437 | the output dimension 438 | attn_dim: int 439 | the dimension of attention vector 440 | ntypes: list 441 | the node type of a heterogeneous graph 442 | etypes: list 443 | the edge type of a heterogeneous graph 444 | """ 445 | @classmethod 446 | def build_model_from_args(cls, args, hg:dgl.DGLGraph): 447 | return cls(args.num_layers, 448 | args.in_dim, 449 | args.hidden_dim, 450 | args.out_dim, 451 | args.attn_dim, 452 | hg.ntypes, 453 | hg.etypes 454 | ) 455 | 456 | def __init__(self, num_layers, in_dim, hidden_dim, out_dim, attn_dim, ntypes, etypes): 457 | super(ieHGCN, self).__init__() 458 | self.num_layers = num_layers 459 | self.activation = F.elu 460 | self.hgcn_layers = nn.ModuleList() 461 | 462 | self.hgcn_layers.append( 463 | ieHGCNConv( 464 | in_dim, 465 | hidden_dim, 466 | attn_dim, 467 | ntypes, 468 | etypes, 469 | self.activation, 470 | ) 471 | ) 472 | 473 | for i in range(1, num_layers - 1): 474 | self.hgcn_layers.append( 475 | ieHGCNConv( 476 | hidden_dim, 477 | hidden_dim, 478 | attn_dim, 479 | ntypes, 480 | etypes, 481 | self.activation 482 | ) 483 | ) 484 | 485 | self.hgcn_layers.append( 486 | ieHGCNConv( 487 | hidden_dim, 488 | out_dim, 489 | attn_dim, 490 | ntypes, 491 | etypes, 492 | None, 493 | ) 494 | ) 495 | 496 | def forward(self, hg, h_dict): 497 | """ 498 | The forward part of the ieHGCN. 499 | 500 | Parameters 501 | ---------- 502 | hg : object 503 | the dgl heterogeneous graph 504 | h_dict: dict 505 | the feature dict of different node types 506 | 507 | Returns 508 | ------- 509 | dict 510 | The embeddings after the output projection. 
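Note: only the 'l' and 'v' node embeddings (the lane and vehicle node types used in this repo's
heterographs) are returned, together with the last layer's type-level attention, which is what
exposes the per-type contributions for explainability.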
511 | """ 512 | with hg.local_scope(): 513 | for l in range(self.num_layers): 514 | h_dict, attention = self.hgcn_layers[l](hg, h_dict) 515 | 516 | return h_dict['l'], h_dict['v'], attention 517 | 518 | 519 | class HeteroRGCN(nn.Module): 520 | def __init__(self, G, in_size, hidden_size, out_size): 521 | super(HeteroRGCN, self).__init__() 522 | # create layers 523 | self.layer1 = HeteroRGCNLayer(in_size, hidden_size, G.etypes) 524 | self.layer2 = HeteroRGCNLayer(hidden_size, out_size, G.etypes) 525 | 526 | def forward(self, G, out_key): 527 | input_dict = {ntype : G.nodes[ntype].data['inp'] for ntype in G.ntypes} 528 | h_dict = self.layer1(G, input_dict) 529 | h_dict = {k : F.leaky_relu(h) for k, h in h_dict.items()} 530 | h_dict = self.layer2(G, h_dict) 531 | # get appropriate logits 532 | return h_dict[out_key] 533 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import models.encoders.encoder as enc 4 | import models.aggregators.aggregator as agg 5 | import models.decoders.decoder as dec 6 | from typing import Dict, Union 7 | 8 | 9 | class PredictionModel(nn.Module): 10 | """ 11 | Single-agent prediction model 12 | """ 13 | def __init__(self, encoder: enc.PredictionEncoder, 14 | aggregator: agg.PredictionAggregator, 15 | decoder: dec.PredictionDecoder): 16 | """ 17 | Initializes model for single-agent trajectory prediction 18 | """ 19 | super().__init__() 20 | self.encoder = encoder 21 | self.aggregator = aggregator 22 | self.decoder = decoder 23 | 24 | def forward(self, inputs: Dict) -> Union[torch.Tensor, Dict]: 25 | """ 26 | Forward pass for prediction model 27 | :param inputs: Dictionary with 28 | 'target_agent_representation': target agent history 29 | 'surrounding_agent_representation': surrounding agent history 30 | 'map_representation': HD map representation 31 | :return outputs: K Predicted trajectories and/or their probabilities 32 | """ 33 | encodings = self.encoder(inputs) 34 | agg_encoding = self.aggregator(encodings) 35 | outputs = self.decoder(agg_encoding) 36 | 37 | return outputs 38 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from train_eval.preprocessor import preprocess_data 3 | import yaml 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("-c", "--config", help="Config file with dataset parameters", required=True) 8 | parser.add_argument("-r", "--data_root", help="Root directory with data", required=True) 9 | parser.add_argument("-d", "--data_dir", help="Directory to extract data", required=True) 10 | args = parser.parse_args() 11 | 12 | # Read config file 13 | with open(args.config, 'r') as yaml_file: 14 | cfg = yaml.safe_load(yaml_file) 15 | 16 | preprocess_data(cfg, args.data_root, args.data_dir) 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | from train_eval.trainer import Trainer 4 | from torch.utils.tensorboard import SummaryWriter 5 | import os 6 | import wandb 7 | 8 | # Parse arguments 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("-c", "--config", help="Config file with dataset parameters", required=True) 11 | 
parser.add_argument("-r", "--data_root", help="Root directory with data", required=True) 12 | parser.add_argument("-d", "--data_dir", help="Directory to extract data", required=True) 13 | parser.add_argument("-o", "--output_dir", help="Directory to save checkpoints and logs", required=True) 14 | parser.add_argument("-n", "--num_epochs", help="Number of epochs to run training for", required=True) 15 | parser.add_argument("-w", "--checkpoint", help="Path to pre-trained or intermediate checkpoint", required=False) 16 | parser.add_argument('--nowandb', action='store_true', help='use this flag to DISABLE wandb logging') 17 | parser.add_argument('--sweep', action='store_true', help='use this flag to indicate that this is a sweep run') 18 | parser.add_argument('--aggregator_args.num_heads', type=int, default=16, help='number of heads for aggregator') 19 | parser.add_argument('--aggregator_args.pre_train', type=bool, default=True) 20 | parser.add_argument('--encoder_args.target_agent_enc_size', type=int, default=128) 21 | parser.add_argument('--encoder_args.target_agent_emb_size', type=int, default=64) 22 | parser.add_argument('--encoder_args.num_heads_lanes', type=int, default=1) 23 | parser.add_argument('--encoder_args.feat_drop', type=float, default=0.) 24 | parser.add_argument('--encoder_args.attn_drop', type=float, default=0.) 25 | parser.add_argument('--encoder_args.num_layers', type=int, default=3) 26 | parser.add_argument('--encoder_args.node_hgt_size', type=int, default=32) 27 | parser.add_argument('--encoder_args.hg', type=str, default="simple") 28 | parser.add_argument('--optim_args.scheduler_step', type=int, default=10) 29 | parser.add_argument('--optim_args.lr', type=float, default=0.001) 30 | parser.add_argument('--batch_size', type=int, default=16) 31 | 32 | 33 | args = parser.parse_args() 34 | 35 | # Load config 36 | with open(args.config, 'r') as yaml_file: 37 | cfg = yaml.safe_load(yaml_file) 38 | 39 | # Initialize wandb loger 40 | wandb_logger = None 41 | if not args.nowandb: 42 | wandb_logger = wandb.init(job_type="training", entity='entity', project='xmp', 43 | config=cfg, sync_tensorboard=True) 44 | wandb_logger.name=wandb.run.name 45 | if args.sweep: 46 | enc_args = {key.split('.')[-1]: value for key, value in vars(args).items() if 'encoder' in key.lower()} 47 | agg_args = {key.split('.')[-1]: value for key, value in vars(args).items() if 'aggregator' in key.lower()} 48 | optim_args = {key.split('.')[-1]: value for key, value in vars(args).items() if 'optim' in key.lower()} 49 | cfg['encoder_args'].update(enc_args) 50 | cfg['aggregator_args'].update(agg_args) 51 | cfg['optim_args'].update(optim_args) 52 | cfg.update({'batch_size': args.batch_size}) 53 | cfg['encoder_args'].update({'num_heads_lanes': [enc_args['num_heads_lanes']]*enc_args['num_layers']}) 54 | cfg['encoder_args'].update({'node_emb_size': enc_args['target_agent_emb_size']}) 55 | cfg['encoder_args'].update({'nbr_emb_size': enc_args['target_agent_emb_size']}) 56 | cfg['encoder_args'].update({'node_enc_size': enc_args['target_agent_enc_size']}) 57 | cfg['encoder_args'].update({'nbr_enc_size': enc_args['target_agent_enc_size']}) 58 | cfg['encoder_args'].update({'node_out_hgt_size': enc_args['target_agent_enc_size']}) 59 | cfg['aggregator_args'].update({'target_agent_enc_size': enc_args['target_agent_enc_size']*2}) 60 | cfg['aggregator_args'].update({'node_enc_size': enc_args['target_agent_enc_size']}) 61 | cfg['aggregator_args'].update({'pi_h1_size': enc_args['target_agent_enc_size']}) 62 | 
cfg['aggregator_args'].update({'pi_h2_size': enc_args['target_agent_enc_size']}) 63 | cfg['aggregator_args'].update({'emb_size': enc_args['target_agent_enc_size']*4}) 64 | cfg['decoder_args'].update({'encoding_size': enc_args['target_agent_enc_size']*6}) 65 | wandb.config.update(cfg, allow_val_change=True) 66 | args.output_dir = os.path.join(args.output_dir, wandb.run.name) 67 | 68 | 69 | # Make directories 70 | if not os.path.isdir(args.output_dir): 71 | os.mkdir(args.output_dir) 72 | if not os.path.isdir(os.path.join(args.output_dir, 'checkpoints')): 73 | os.mkdir(os.path.join(args.output_dir, 'checkpoints')) 74 | if not os.path.isdir(os.path.join(args.output_dir, 'tensorboard_logs')): 75 | os.mkdir(os.path.join(args.output_dir, 'tensorboard_logs')) 76 | 77 | # Initialize tensorboard writer 78 | writer = SummaryWriter(log_dir=os.path.join(args.output_dir, 'tensorboard_logs')) 79 | 80 | # Train 81 | trainer = Trainer(cfg, args.data_root, args.data_dir, checkpoint_path=args.checkpoint, writer=writer, wandb_writer=wandb_logger) 82 | trainer.train(num_epochs=int(args.num_epochs), output_dir=args.output_dir) 83 | 84 | 85 | # Close tensorboard writer 86 | writer.close() 87 | -------------------------------------------------------------------------------- /train_eval/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as torch_data 2 | from typing import Dict 3 | from train_eval.initialization import initialize_prediction_model, initialize_metric,\ 4 | initialize_dataset, get_specific_args 5 | import torch 6 | import os 7 | import train_eval.utils as u 8 | import numpy as np 9 | from nuscenes.prediction.helper import convert_local_coords_to_global 10 | from nuscenes.eval.prediction.data_classes import Prediction 11 | import json 12 | from train_eval.utils import collate_fn_dgl_hetero 13 | 14 | 15 | # Initialize device: 16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | 19 | class Evaluator: 20 | """ 21 | Class for evaluating trained models 22 | """ 23 | def __init__(self, cfg: Dict, data_root: str, data_dir: str, checkpoint_path: str): 24 | """ 25 | Initialize evaluator object 26 | :param cfg: Configuration parameters 27 | :param data_root: Root directory with data 28 | :param data_dir: Directory with extracted, pre-processed data 29 | :param checkpoint_path: Path to checkpoint with trained weights 30 | """ 31 | 32 | # Initialize dataset 33 | ds_type = cfg['dataset'] + '_' + cfg['agent_setting'] + '_' + cfg['input_representation'] 34 | spec_args = get_specific_args(cfg['dataset'], data_root, cfg['version'] if 'version' in cfg.keys() else None)[0] 35 | test_set = initialize_dataset(ds_type, ['load_data', data_dir, cfg['test_set_args']] + spec_args) 36 | 37 | # Initialize dataloader 38 | if 'scout' in cfg['encoder_type']: 39 | collate_fn = collate_fn_dgl_hetero 40 | else: 41 | collate_fn = None 42 | self.dl = torch_data.DataLoader(test_set, cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'], collate_fn=collate_fn) 43 | 44 | # Initialize model 45 | self.model = initialize_prediction_model(cfg['encoder_type'], cfg['aggregator_type'], cfg['decoder_type'], 46 | cfg['encoder_args'], cfg['aggregator_args'], cfg['decoder_args']) 47 | self.model = self.model.float().to(device) 48 | self.model.eval() 49 | 50 | # Load checkpoint 51 | checkpoint = torch.load(checkpoint_path) 52 | self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) 53 | 54 | # Initialize metrics 55 | 
self.metrics = [initialize_metric(cfg['val_metrics'][i], cfg['val_metric_args'][i]) 56 | for i in range(len(cfg['val_metrics']))] 57 | 58 | def evaluate(self, output_dir: str): 59 | """ 60 | Main function to evaluate trained model 61 | :param output_dir: Output directory to store results 62 | """ 63 | 64 | # Initialize aggregate metrics 65 | agg_metrics = self.initialize_aggregate_metrics() 66 | 67 | with torch.no_grad(): 68 | for i, data in enumerate(self.dl): 69 | 70 | # Load data 71 | data = u.send_to_device(u.convert_double_to_float(data)) 72 | 73 | # Forward pass 74 | predictions = self.model(data['inputs']) 75 | 76 | # Aggregate metrics 77 | agg_metrics = self.aggregate_metrics(agg_metrics, predictions, data['ground_truth']) 78 | 79 | self.print_progress(i) 80 | 81 | # compute and print average metrics 82 | self.print_progress(len(self.dl)) 83 | with open(os.path.join(output_dir, 'results', "results.txt"), "w") as out_file: 84 | for metric in self.metrics: 85 | avg_metric = agg_metrics[metric.name]/agg_metrics['sample_count'] 86 | output = metric.name + ': ' + format(avg_metric, '0.2f') 87 | print(output) 88 | out_file.write(output + '\n') 89 | 90 | def initialize_aggregate_metrics(self): 91 | """ 92 | Initialize aggregate metrics for test set. 93 | """ 94 | agg_metrics = {'sample_count': 0} 95 | for metric in self.metrics: 96 | agg_metrics[metric.name] = 0 97 | 98 | return agg_metrics 99 | 100 | def aggregate_metrics(self, agg_metrics: Dict, model_outputs: Dict, ground_truth: Dict): 101 | """ 102 | Aggregates metrics for evaluation 103 | """ 104 | minibatch_metrics = {} 105 | for metric in self.metrics: 106 | minibatch_metrics[metric.name] = metric.compute(model_outputs, ground_truth).item() 107 | 108 | batch_size = ground_truth['traj'].shape[0] 109 | agg_metrics['sample_count'] += batch_size 110 | 111 | for metric in self.metrics: 112 | agg_metrics[metric.name] += minibatch_metrics[metric.name] * batch_size 113 | 114 | return agg_metrics 115 | 116 | def print_progress(self, minibatch_count: int): 117 | """ 118 | Prints progress bar 119 | """ 120 | epoch_progress = minibatch_count / len(self.dl) * 100 121 | print('\rEvaluating:', end=" ") 122 | progress_bar = '[' 123 | for i in range(20): 124 | if i < epoch_progress // 5: 125 | progress_bar += '=' 126 | else: 127 | progress_bar += ' ' 128 | progress_bar += ']' 129 | print(progress_bar, format(epoch_progress, '0.2f'), '%', end="\n" if epoch_progress == 100 else " ") 130 | 131 | def generate_nuscenes_benchmark_submission(self, output_dir: str): 132 | """ 133 | Sets up list of Prediction objects for the nuScenes benchmark. 
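Each serialized Prediction bundles (instance_token, sample_token, K global-frame trajectories,
K mode probabilities); the resulting evalai_submission.json is the format expected by the
nuScenes prediction challenge evaluation server.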
134 | """
135 | 
136 | # NuScenes prediction helper
137 | helper = self.dl.dataset.helper
138 | 
139 | # List of predictions
140 | preds = []
141 | 
142 | with torch.no_grad():
143 | for i, data in enumerate(self.dl):
144 | 
145 | # Load data
146 | data = u.send_to_device(u.convert_double_to_float(data))
147 | 
148 | # Forward pass
149 | predictions = self.model(data['inputs'])
150 | traj = predictions['traj']
151 | probs = predictions['probs']
152 | 
153 | # Load instance and sample tokens for batch
154 | instance_tokens = data['inputs']['instance_token']
155 | sample_tokens = data['inputs']['sample_token']
156 | 
157 | # Create prediction object and add to list of predictions
158 | for n in range(traj.shape[0]):
159 | 
160 | traj_local = traj[n].detach().cpu().numpy()
161 | probs_n = probs[n].detach().cpu().numpy()
162 | starting_annotation = helper.get_sample_annotation(instance_tokens[n], sample_tokens[n])
163 | traj_global = np.zeros_like(traj_local)
164 | for m in range(traj_local.shape[0]):
165 | traj_global[m] = convert_local_coords_to_global(traj_local[m],
166 | starting_annotation['translation'],
167 | starting_annotation['rotation'])
168 | 
169 | preds.append(Prediction(instance=instance_tokens[n], sample=sample_tokens[n],
170 | prediction=traj_global, probabilities=probs_n).serialize())
171 | 
172 | # Print progress bar
173 | self.print_progress(i)
174 | 
175 | # Save predictions to json file
176 | with open(os.path.join(output_dir, 'results', "evalai_submission.json"), "w") as submission_file:
177 | json.dump(preds, submission_file)
178 | self.print_progress(len(self.dl))
--------------------------------------------------------------------------------
/train_eval/initialization.py:
--------------------------------------------------------------------------------
1 | # Import datasets
3 | from nuscenes import NuScenes
4 | from nuscenes.prediction import PredictHelper
5 | from datasets.interface import TrajectoryDataset
6 | from datasets.nuScenes.nuScenes_raster import NuScenesRaster
7 | from datasets.nuScenes.nuScenes_vector import NuScenesVector
8 | from datasets.nuScenes.nuScenes_graphs import NuScenesGraphs
9 | 
10 | # Import models
11 | from models.model import PredictionModel
12 | from models.encoders.raster_encoder import RasterEncoder
13 | from models.encoders.polyline_subgraph import PolylineSubgraphs
14 | from models.encoders.pgp_encoder import PGPEncoder
15 | from models.encoders.scout_encoder import SCOUTEncoder
16 | from models.encoders.pgp_scout_encoder import PGP_SCOUTEncoder
17 | from models.aggregators.concat import Concat
18 | from models.aggregators.global_attention import GlobalAttention
19 | from models.aggregators.goal_conditioned import GoalConditioned
20 | from models.aggregators.pgp import PGP
21 | from models.decoders.mtp import MTP
22 | from models.decoders.multipath import Multipath
23 | from models.decoders.covernet import CoverNet
24 | from models.decoders.lvm import LVM
25 | 
26 | # Import metrics
27 | from metrics.mtp_loss import MTPLoss
28 | from metrics.min_ade import MinADEK
29 | from metrics.min_fde import MinFDEK
30 | from metrics.miss_rate import MissRateK
31 | from metrics.covernet_loss import CoverNetLoss
32 | from metrics.pi_bc import PiBehaviorCloning
33 | from metrics.goal_pred_nll import GoalPredictionNLL
34 | 
35 | from typing import List, Dict, Union
36 | 
37 | 
38 | # Datasets
39 | def initialize_dataset(dataset_type: str, args: List) -> TrajectoryDataset:
40 | """
41 | Helper function to initialize appropriate dataset by dataset
42 |     """
43 |     # TODO: Add more datasets as implemented
44 |     dataset_classes = {#'nuScenes_single_agent_raster': NuScenesRaster,
45 |                        'nuScenes_single_agent_vector': NuScenesVector,
46 |                        'nuScenes_single_agent_graphs': NuScenesGraphs,
47 |                        }
48 |     return dataset_classes[dataset_type](*args)
49 | 
50 | 
51 | def get_specific_args(dataset_name: str, data_root: str, version: str = None) -> List:
52 |     """
53 |     Helper function to get dataset specific arguments.
54 |     """
55 |     # TODO: Add more datasets as implemented
56 |     specific_args = []
57 |     if dataset_name == 'nuScenes':
58 |         ns = NuScenes(version, dataroot=data_root)
59 |         pred_helper = PredictHelper(ns)
60 |         specific_args.append([pred_helper])
61 |         specific_args.append([ns])
62 | 
63 |     return specific_args  # list of argument lists; callers index in, e.g. [0] -> [pred_helper]
64 | 
65 | 
66 | # Models
67 | def initialize_prediction_model(encoder_type: str, aggregator_type: str, decoder_type: str,
68 |                                 encoder_args: Dict, aggregator_args: Union[Dict, None], decoder_args: Dict):
69 |     """
70 |     Helper function to initialize appropriate encoder, aggregator and decoder models
71 |     """
72 |     encoder = initialize_encoder(encoder_type, encoder_args)
73 |     aggregator = initialize_aggregator(aggregator_type, aggregator_args)
74 |     decoder = initialize_decoder(decoder_type, decoder_args)
75 |     model = PredictionModel(encoder, aggregator, decoder)
76 | 
77 |     return model
78 | 
79 | 
80 | def initialize_encoder(encoder_type: str, encoder_args: Dict):
81 |     """
82 |     Initialize appropriate encoder by type.
83 |     """
84 |     # TODO: Update as we add more encoder types
85 |     encoder_mapping = {
86 |         'raster_encoder': RasterEncoder,
87 |         'polyline_subgraphs': PolylineSubgraphs,
88 |         'pgp_encoder': PGPEncoder,
89 |         'scout_encoder': SCOUTEncoder,
90 |         'pgp_scout_encoder': PGP_SCOUTEncoder
91 |     }
92 | 
93 |     return encoder_mapping[encoder_type](encoder_args)
94 | 
95 | 
96 | def initialize_aggregator(aggregator_type: str, aggregator_args: Union[Dict, None]):
97 |     """
98 |     Initialize appropriate aggregator by type.
99 |     """
100 |     # TODO: Update as we add more aggregator types
101 |     aggregator_mapping = {
102 |         'concat': Concat,
103 |         'global_attention': GlobalAttention,
104 |         'gc': GoalConditioned,
105 |         'pgp': PGP
106 |     }
107 | 
108 |     if aggregator_args:
109 |         return aggregator_mapping[aggregator_type](aggregator_args)
110 |     else:
111 |         return aggregator_mapping[aggregator_type]()
112 | 
113 | 
114 | def initialize_decoder(decoder_type: str, decoder_args: Dict):
115 |     """
116 |     Initialize appropriate decoder by type.
117 |     """
118 |     # TODO: Update as we add more decoder types
119 |     decoder_mapping = {
120 |         'mtp': MTP,
121 |         'multipath': Multipath,
122 |         'covernet': CoverNet,
123 |         'lvm': LVM
124 |     }
125 | 
126 |     return decoder_mapping[decoder_type](decoder_args)
127 | 
128 | 
129 | # Metrics
130 | def initialize_metric(metric_type: str, metric_args: Dict = None):
131 |     """
132 |     Initialize appropriate metric by type.
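    Metric classes are looked up by the string keys used in the config files
    (e.g. 'min_ade_k', 'miss_rate_k'); metric_args, when provided, is forwarded
    to the metric constructor.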
133 | """ 134 | # TODO: Update as we add more metrics 135 | metric_mapping = { 136 | 'mtp_loss': MTPLoss, 137 | 'covernet_loss': CoverNetLoss, 138 | 'min_ade_k': MinADEK, 139 | 'min_fde_k': MinFDEK, 140 | 'miss_rate_k': MissRateK, 141 | 'pi_bc': PiBehaviorCloning, 142 | 'goal_pred_nll': GoalPredictionNLL 143 | } 144 | 145 | if metric_args is not None: 146 | return metric_mapping[metric_type](metric_args) 147 | else: 148 | return metric_mapping[metric_type]() 149 | -------------------------------------------------------------------------------- /train_eval/preprocessor.py: -------------------------------------------------------------------------------- 1 | from datasets.interface import TrajectoryDataset 2 | import torch.utils.data as torch_data 3 | from typing import List, Dict 4 | import torch 5 | import os 6 | import pickle 7 | from train_eval.initialization import get_specific_args, initialize_dataset 8 | 9 | 10 | def preprocess_data(cfg: Dict, data_root: str, data_dir: str, compute_stats=True, extract=True): 11 | """ 12 | Main function for pre-processing data 13 | 14 | :param cfg: Dictionary with configuration parameters 15 | :param data_root: Root directory for the dataset 16 | :param data_dir: Directory to extract pre-processed data 17 | :param compute_stats: Flag, whether to compute stats 18 | :param extract: Flag, whether to extract data 19 | """ 20 | 21 | # String describing dataset type 22 | ds_type = cfg['dataset'] + '_' + cfg['agent_setting'] + '_' + cfg['input_representation'] 23 | 24 | # Get dataset specific args 25 | specific_args = get_specific_args(cfg['dataset'], data_root, cfg['version'] if 'version' in cfg.keys() else None)[0] 26 | 27 | # Compute stats 28 | if compute_stats: 29 | train_set = initialize_dataset(ds_type, ['compute_stats', data_dir, cfg['train_set_args']] + specific_args) 30 | val_set = initialize_dataset(ds_type, ['compute_stats', data_dir, cfg['val_set_args']] + specific_args) 31 | test_set = initialize_dataset(ds_type, ['compute_stats', data_dir, cfg['test_set_args']] + specific_args) 32 | compute_dataset_stats([train_set, val_set, test_set], cfg['batch_size'], cfg['num_workers'], 33 | verbose=cfg['verbosity']) 34 | 35 | # Extract data 36 | if extract: 37 | train_set = initialize_dataset(ds_type, ['extract_data', data_dir, cfg['train_set_args']] + specific_args) 38 | val_set = initialize_dataset(ds_type, ['extract_data', data_dir, cfg['val_set_args']] + specific_args) 39 | test_set = initialize_dataset(ds_type, ['extract_data', data_dir, cfg['test_set_args']] + specific_args) 40 | extract_data([train_set, val_set, test_set], cfg['batch_size'], cfg['num_workers'], verbose=cfg['verbosity']) 41 | 42 | 43 | def compute_dataset_stats(dataset_splits: List[TrajectoryDataset], batch_size: int, num_workers: int, verbose=False): 44 | """ 45 | Computes dataset stats 46 | 47 | :param dataset_splits: List of dataset objects usually corresponding to the train, val and test splits 48 | :param batch_size: Batch size for dataloader 49 | :param num_workers: Number of workers for dataloader 50 | :param verbose: Whether to print progress 51 | """ 52 | # Check if all datasets have been initialized with the correct mode 53 | for dataset in dataset_splits: 54 | if dataset.mode != 'compute_stats': 55 | raise Exception('Dataset mode should be compute_stats') 56 | 57 | # Initialize data loaders 58 | data_loaders = [] 59 | for dataset in dataset_splits: 60 | dl = torch_data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 61 | 
62 | 
63 |     # Initialize dataset statistics
64 |     stats = {}
65 | 
66 |     # For printing progress
67 |     print("Computing dataset stats...")
68 |     num_mini_batches = sum([len(data_loader) for data_loader in data_loaders])
69 |     mini_batch_count = 0
70 | 
71 |     # Loop over splits and mini-batches
72 |     for data_loader in data_loaders:
73 |         for i, mini_batch_stats in enumerate(data_loader):
74 |             for k, v in mini_batch_stats.items():
75 |                 if k in stats.keys():
76 |                     stats[k] = max(stats[k], torch.max(v).item())
77 |                 else:
78 |                     stats[k] = torch.max(v).item()
79 | 
80 |             # Show progress
81 |             if verbose:
82 |                 print("mini batch " + str(mini_batch_count + 1) + '/' + str(num_mini_batches))
83 |             mini_batch_count += 1
84 | 
85 |     # Save stats
86 |     filename = os.path.join(dataset_splits[0].data_dir, 'stats.pickle')
87 |     with open(filename, 'wb') as handle:
88 |         pickle.dump(stats, handle, protocol=pickle.HIGHEST_PROTOCOL)
89 | 
90 | 
91 | def extract_data(dataset_splits: List[TrajectoryDataset], batch_size: int, num_workers: int, verbose=False):
92 |     """
93 |     Extracts pre-processed data
94 | 
95 |     :param dataset_splits: List of dataset objects usually corresponding to the train, val and test splits
96 |     :param batch_size: Batch size for dataloader
97 |     :param num_workers: Number of workers for dataloader
98 |     :param verbose: Whether to print progress
99 |     """
100 |     # Check if all datasets have been initialized with the correct mode
101 |     for dataset in dataset_splits:
102 |         if dataset.mode != 'extract_data':
103 |             raise Exception('Dataset mode should be extract_data')
104 | 
105 |     # Initialize data loaders
106 |     data_loaders = []
107 |     for dataset in dataset_splits:
108 |         dl = torch_data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
109 |         data_loaders.append(dl)
110 | 
111 |     # For printing progress
112 |     print("Extracting pre-processed data...")
113 |     num_mini_batches = sum([len(data_loader) for data_loader in data_loaders])
114 |     mini_batch_count = 0
115 | 
116 |     # Loop over splits and mini-batches; iterating the dataloader is what triggers extraction
117 |     for data_loader in data_loaders:
118 |         for i, _ in enumerate(data_loader):
119 | 
120 |             # Show progress
121 |             if verbose:
122 |                 print("mini batch " + str(mini_batch_count + 1) + '/' + str(num_mini_batches))
123 |             mini_batch_count += 1
124 | 
--------------------------------------------------------------------------------
/train_eval/trainer.py:
--------------------------------------------------------------------------------
1 | import torch.optim
2 | import torch.utils.data as torch_data
3 | from typing import Dict
4 | from train_eval.initialization import initialize_prediction_model, initialize_metric,\
5 |     initialize_dataset, get_specific_args
6 | import torch
7 | import time
8 | import math
9 | import os
10 | import train_eval.utils as u
11 | 
12 | 
13 | # Initialize device:
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 | 
16 | 
17 | class Trainer:
18 |     """
19 |     Trainer class for running train-val loops
20 |     """
21 |     def __init__(self, cfg: Dict, data_root: str, data_dir: str, checkpoint_path=None, just_weights=False, writer=None, wandb_writer=None):
22 |         """
23 |         Initialize trainer object
24 |         :param cfg: Configuration parameters
25 |         :param data_root: Root directory with data
26 |         :param data_dir: Directory with extracted, pre-processed data
27 |         :param checkpoint_path: Path to checkpoint with trained weights
28 |         :param just_weights: Load just weights from checkpoint
29 |         :param writer: Tensorboard summary writer
       :param wandb_writer: Optional Weights & Biases run for additional logging
30 |         """
31 | 
32 |         # Initialize datasets:
33 |         ds_type = cfg['dataset'] + '_' + cfg['agent_setting'] + '_' + cfg['input_representation']
34 |         spec_args = get_specific_args(cfg['dataset'], data_root, cfg['version'] if 'version' in cfg.keys() else None)[0]
35 |         train_set = initialize_dataset(ds_type, ['load_data', data_dir, cfg['train_set_args']] + spec_args)
36 |         val_set = initialize_dataset(ds_type, ['load_data', data_dir, cfg['val_set_args']] + spec_args)
37 |         datasets = {'train': train_set, 'val': val_set}
38 | 
39 |         # Initialize dataloaders; SCOUT-style encoders need the DGL heterograph collate function
40 |         if 'scout' in cfg['encoder_type']:
41 |             collate_fn = u.collate_fn_dgl_hetero
42 |         else:
43 |             collate_fn = None
44 |         self.tr_dl = torch_data.DataLoader(datasets['train'], cfg['batch_size'], shuffle=True,
45 |                                            num_workers=cfg['num_workers'], pin_memory=True, collate_fn=collate_fn)
46 |         self.val_dl = torch_data.DataLoader(datasets['val'], cfg['batch_size'], shuffle=False,
47 |                                             num_workers=cfg['num_workers'], pin_memory=True, collate_fn=collate_fn)
48 | 
49 |         # Initialize model
50 |         self.model = initialize_prediction_model(cfg['encoder_type'], cfg['aggregator_type'], cfg['decoder_type'],
51 |                                                  cfg['encoder_args'], cfg['aggregator_args'], cfg['decoder_args'])
52 |         self.model = self.model.float().to(device)
53 | 
54 |         # Initialize optimizer and scheduler
55 |         self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg['optim_args']['lr'])
56 |         self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=cfg['optim_args']['scheduler_step'],
57 |                                                          gamma=cfg['optim_args']['scheduler_gamma'])
58 | 
59 |         # Initialize epochs
60 |         self.current_epoch = 0
61 | 
62 |         # Initialize losses
63 |         self.losses = [initialize_metric(cfg['losses'][i], cfg['loss_args'][i]) for i in range(len(cfg['losses']))]
64 |         self.loss_weights = cfg['loss_weights']
65 | 
66 |         # Initialize metrics
67 |         self.train_metrics = [initialize_metric(cfg['tr_metrics'][i], cfg['tr_metric_args'][i])
68 |                               for i in range(len(cfg['tr_metrics']))]
69 |         self.val_metrics = [initialize_metric(cfg['val_metrics'][i], cfg['val_metric_args'][i])
70 |                             for i in range(len(cfg['val_metrics']))]
71 |         self.val_metric = math.inf
72 |         self.min_val_metric = 1.36  # checkpointing threshold: a checkpoint is only saved once the val metric beats this value
73 | 
74 |         # Print metrics after this many minibatches to keep track of training
75 |         self.log_period = len(self.tr_dl) // cfg['log_freq']
76 | 
77 |         # Initialize tensorboard writer
78 |         self.writer = writer
79 |         self.tb_iters = 0
80 | 
81 |         self.wandb_writer = wandb_writer
82 |         if self.wandb_writer is not None:
83 |             self.wandb_writer.watch(
84 |                 self.model,
85 |                 criterion=self.val_metrics[0].name,
86 |                 log=None,
87 |                 log_freq=1000,
88 |                 log_graph=True
89 |             )
90 | 
91 |         # Load checkpoint if checkpoint path is provided
92 |         if checkpoint_path is not None:
93 |             print()
94 |             print("Loading checkpoint from " + checkpoint_path + " ...", end=" ")
95 |             self.load_checkpoint(checkpoint_path, just_weights=just_weights)
96 |             print("Done")
97 | 
98 |         # Generate anchors if using an anchor-based trajectory decoder
99 |         if hasattr(self.model.decoder, 'anchors') and torch.as_tensor(self.model.decoder.anchors == 0).all():
100 |             print()
101 |             print("Extracting anchors for decoder ...", end=" ")
102 |             self.model.decoder.generate_anchors(self.tr_dl.dataset)
103 |             print("Done")
104 | 
105 |     def train(self, num_epochs: int, output_dir: str):
106 |         """
107 |         Main function to train model
108 |         :param num_epochs: Number of epochs to run training for
109 |         :param output_dir: Output directory to store tensorboard logs and checkpoints
110 |         :return:
111 |         """
112 | 
113 |         # Run training, validation for given number of epochs
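        # Each epoch: one training pass, one validation pass under no_grad, a
        # scheduler step, and a checkpoint save whenever the first validation
        # metric improves on the best value seen so far.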
114 |         start_epoch = self.current_epoch
115 |         for epoch in range(start_epoch, start_epoch + num_epochs):
116 | 
117 |             # Set current epoch
118 |             self.current_epoch = epoch
119 |             print()
120 |             print('Epoch (' + str(self.current_epoch + 1) + '/' + str(start_epoch + num_epochs) + ')')
121 | 
122 |             # Train
123 |             train_epoch_metrics = self.run_epoch('train', self.tr_dl)
124 |             self.print_metrics(train_epoch_metrics, self.tr_dl, mode='train')
125 | 
126 |             # Validate
127 |             with torch.no_grad():
128 |                 val_epoch_metrics = self.run_epoch('val', self.val_dl)
129 |             self.print_metrics(val_epoch_metrics, self.val_dl, mode='val')
130 | 
131 |             # Scheduler step
132 |             self.scheduler.step()
133 | 
134 |             # Update validation metric
135 |             self.val_metric = val_epoch_metrics[self.val_metrics[0].name] / val_epoch_metrics['minibatch_count']
136 | 
137 |             # Save best checkpoint when applicable
138 |             if self.val_metric < self.min_val_metric:
139 |                 self.min_val_metric = self.val_metric
140 |                 self.save_checkpoint(os.path.join(output_dir, 'checkpoints', 'best.tar'))
141 |                 if self.wandb_writer is not None: self.wandb_writer.log({"best_val_ade_5": self.min_val_metric, "epoch_best_val": self.current_epoch})
142 | 
143 |             # Save checkpoint
144 |             # self.save_checkpoint(os.path.join(output_dir, 'checkpoints', str(self.current_epoch) + '.tar'))
145 | 
146 |     def run_epoch(self, mode: str, dl: torch_data.DataLoader):
147 |         """
148 |         Runs an epoch for a given dataloader
149 |         :param mode: 'train' or 'val'
150 |         :param dl: dataloader object
151 |         """
152 |         if mode == 'val':
153 |             self.model.eval()
154 |         else:
155 |             self.model.train()
156 | 
157 |         # Initialize epoch metrics
158 |         epoch_metrics = self.initialize_metrics_for_epoch(mode)
159 | 
160 |         # Main loop
161 |         st_time = time.time()
162 |         for i, data in enumerate(dl):
163 | 
164 |             # Load data
165 |             data = u.send_to_device(u.convert_double_to_float(data))
166 | 
167 |             # Forward pass
168 |             predictions = self.model(data['inputs'])
169 | 
170 |             # Compute loss and backprop if training
171 |             if mode == 'train':
172 |                 loss = self.compute_loss(predictions, data['ground_truth'])
173 |                 self.back_prop(loss)
174 | 
175 |             # Keep time
176 |             minibatch_time = time.time() - st_time
177 |             st_time = time.time()
178 | 
179 |             # Aggregate metrics
180 |             minibatch_metrics, epoch_metrics = self.aggregate_metrics(epoch_metrics, minibatch_time,
181 |                                                                       predictions, data['ground_truth'], mode)
182 | 
183 |             # Log minibatch metrics to tensorboard during training
184 |             if mode == 'train':
185 |                 self.log_tensorboard_train(minibatch_metrics)
186 | 
187 |             # Display metrics at a predefined frequency
188 |             if i % self.log_period == self.log_period - 1:
189 |                 self.print_metrics(epoch_metrics, dl, mode)
190 | 
191 |         # Log val metrics for the complete epoch to tensorboard
192 |         if mode == 'val':
193 |             self.log_tensorboard_val(epoch_metrics)
194 | 
195 |         return epoch_metrics
196 | 
197 |     def compute_loss(self, model_outputs: Dict, ground_truth: Dict) -> torch.Tensor:
198 |         """
199 |         Computes loss given model outputs and ground truth labels
200 |         """
201 |         loss_vals = [loss.compute(model_outputs, ground_truth) for loss in self.losses]
202 |         total_loss = torch.as_tensor(0, device=device).float()
203 |         for n in range(len(loss_vals)):
204 |             total_loss += self.loss_weights[n] * loss_vals[n]
205 | 
206 |         return total_loss
207 | 
208 |     def back_prop(self, loss: torch.Tensor, grad_clip_thresh=10):
209 |         """
210 |         Backpropagate loss
211 |         """
212 |         self.optimizer.zero_grad()
213 |         loss.backward()
214 |         torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_thresh)  # clip global grad norm for stability
215 |         self.optimizer.step()
216 | 
217 |     def initialize_metrics_for_epoch(self, mode: str):
218 |         """
219 |         Initialize metrics for epoch
220 |         """
221 |         metrics = self.train_metrics if mode == 'train' else self.val_metrics
222 |         epoch_metrics = {'minibatch_count': 0, 'time_elapsed': 0}
223 |         for metric in metrics:
224 |             epoch_metrics[metric.name] = 0
225 | 
226 |         return epoch_metrics
227 | 
228 |     def aggregate_metrics(self, epoch_metrics: Dict, minibatch_time: float, model_outputs: Dict, ground_truth: Dict,
229 |                           mode: str):
230 |         """
231 |         Aggregates metrics by minibatch for the entire epoch
232 |         """
233 |         metrics = self.train_metrics if mode == 'train' else self.val_metrics
234 | 
235 |         minibatch_metrics = {}
236 |         for metric in metrics:
237 |             minibatch_metrics[metric.name] = metric.compute(model_outputs, ground_truth).item()
238 | 
239 |         epoch_metrics['minibatch_count'] += 1
240 |         epoch_metrics['time_elapsed'] += minibatch_time
241 |         for metric in metrics:
242 |             epoch_metrics[metric.name] += minibatch_metrics[metric.name]
243 | 
244 |         return minibatch_metrics, epoch_metrics
245 | 
246 |     def print_metrics(self, epoch_metrics: Dict, dl: torch_data.DataLoader, mode: str):
247 |         """
248 |         Prints aggregated metrics
249 |         """
250 |         metrics = self.train_metrics if mode == 'train' else self.val_metrics
251 |         minibatches_left = len(dl) - epoch_metrics['minibatch_count']
252 |         eta = (epoch_metrics['time_elapsed'] / epoch_metrics['minibatch_count']) * minibatches_left
253 |         epoch_progress = int(epoch_metrics['minibatch_count'] / len(dl) * 100)
254 |         print('\rTraining:' if mode == 'train' else '\rValidating:', end=" ")
255 |         progress_bar = '['
256 |         for i in range(20):
257 |             if i < epoch_progress // 5:
258 |                 progress_bar += '='
259 |             else:
260 |                 progress_bar += ' '
261 |         progress_bar += ']'
262 |         print(progress_bar, str(epoch_progress), '%', end=", ")
263 |         print('ETA:', int(eta), end="s, ")
264 |         print('Metrics', end=": { ")
265 |         for metric in metrics:
266 |             metric_val = epoch_metrics[metric.name] / epoch_metrics['minibatch_count']
267 |             print(metric.name + ':', format(metric_val, '0.2f'), end=", ")
268 |         print('\b\b }', end="\n" if eta == 0 else "")
269 | 
270 |     def load_checkpoint(self, checkpoint_path, just_weights=False):
271 |         """
272 |         Loads checkpoint from given path
273 |         """
274 |         checkpoint = torch.load(checkpoint_path, map_location=device)  # map to the current device so GPU checkpoints also load on CPU-only machines
275 |         self.model.load_state_dict(checkpoint['model_state_dict'])
276 |         if not just_weights:
277 |             self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
278 |             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
279 |             self.current_epoch = checkpoint['epoch']
280 |             self.val_metric = checkpoint['val_metric']
281 |             self.min_val_metric = checkpoint['min_val_metric']
282 | 
283 |     def save_checkpoint(self, checkpoint_path):
284 |         """
285 |         Saves checkpoint to given path
286 |         """
287 |         torch.save({
288 |             'epoch': self.current_epoch + 1,
289 |             'model_state_dict': self.model.state_dict(),
290 |             'optimizer_state_dict': self.optimizer.state_dict(),
291 |             'scheduler_state_dict': self.scheduler.state_dict(),
292 |             'val_metric': self.val_metric,
293 |             'min_val_metric': self.min_val_metric
294 |         }, checkpoint_path)
295 | 
296 | 
297 |     def log_tensorboard_train(self, minibatch_metrics: Dict):
298 |         """
299 |         Logs minibatch metrics during training
300 |         """
301 |         for metric_name, metric_val in minibatch_metrics.items():
302 |             self.writer.add_scalar('train/' + metric_name, metric_val, self.tb_iters)
303 |             if self.wandb_writer is not None:
304 |                 self.wandb_writer.log({'train/' + metric_name: metric_val, 'epoch': self.current_epoch, 'batch': self.tb_iters})
305 |         self.tb_iters += 1
306 | 
307 |     def log_tensorboard_val(self, epoch_metrics):
308 |         """
309 |         Logs epoch metrics for validation set
310 |         """
311 |         for metric_name, metric_val in epoch_metrics.items():
312 |             if metric_name != 'minibatch_count' and metric_name != 'time_elapsed':
313 |                 metric_val /= epoch_metrics['minibatch_count']
314 |                 self.writer.add_scalar('val/' + metric_name, metric_val, self.tb_iters)
315 |                 if self.wandb_writer is not None:
316 |                     self.wandb_writer.log({'val/' + metric_name: metric_val, 'epoch': self.current_epoch, 'batch': self.tb_iters})
317 | 
318 | 
--------------------------------------------------------------------------------
/train_eval/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch.optim
3 | from typing import Dict, Union
4 | import torch
5 | import numpy as np
6 | import dgl
7 | import scipy.sparse as spp  # only used by the commented-out from_scipy path in collate_fn_dgl_hetero
8 | from torch.utils.data._utils.collate import default_collate
9 | 
10 | 
11 | # Initialize device:
12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13 | 
14 | 
15 | def convert_double_to_float(data: Union[Dict, torch.Tensor]):
16 |     """
17 |     Utility function to convert double tensors to float tensors in nested dictionary with Tensors
18 |     """
19 |     if type(data) is torch.Tensor and data.dtype == torch.float64:
20 |         return data.float()
21 |     elif type(data) is dict:
22 |         for k, v in data.items():
23 |             data[k] = convert_double_to_float(v)
24 |         return data
25 |     else:
26 |         return data
27 | 
28 | 
29 | def send_to_device(data: Union[Dict, torch.Tensor]):
30 |     """
31 |     Utility function to send nested dictionary with Tensors to GPU
32 |     """
33 |     if type(data) is torch.Tensor:
34 |         return data.to(device)
35 |     elif type(data) is dict:
36 |         for k, v in data.items():
37 |             data[k] = send_to_device(v)
38 |         return data
39 |     else:
40 |         return data
41 | 
42 | 
43 | def convert2tensors(data):
44 |     """
45 |     Converts data (dictionary of nd arrays etc.) to tensor with batch_size 1
46 |     """
47 |     if type(data) is np.ndarray:
48 |         return torch.as_tensor(data).unsqueeze(0)
49 |     elif type(data) is dict:
50 |         for k, v in data.items():
51 |             data[k] = convert2tensors(v)
52 |         return data
53 |     else:
54 |         return data
55 | 
56 | class Collate_heterograph(object):
57 |     def __init__(self, args):
58 |         self.mask_frames = args['mask_frames_p']        # probability of dropping past frames of nearby vehicles
59 |         self.agent_mask_prob_v = args['agent_mask_p_veh']
60 |         self.lane_mask_prob = args['lane_mask_p']       # probability of masking out a whole lane
61 |     def __call__(self, batch):
62 |         # Collate function for dataloader.
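        # Robustness variant of the heterograph collate: whole lanes are masked out
        # with probability lane_mask_p and past frames of surrounding vehicles are
        # dropped with probability mask_frames_p; the per-sample vehicle/pedestrian/lane
        # heterograph is then rebuilt over the surviving nodes before the default
        # collate stacks the batch.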
63 |         lanes_graphs = []
64 |         for element in batch:
65 |             adj = element['inputs']['surrounding_agent_representation']['adj_matrix']
66 |             lane_node_masks = element['inputs']['map_representation']['lane_node_masks']
67 |             veh_mask = element['inputs']['surrounding_agent_representation']['vehicle_masks']
68 |             num_v = np.where(veh_mask[:,:,0]==0)[0].max()+2 if len(np.where(veh_mask[:,:,0]==0)[0])>0 else 0  # max index + 1 gives the vehicle count; one more for the focal agent, which is not present in veh_mask
69 | 
70 |             ### ROBUSTNESS ANALYSIS ###
71 |             mask_out_lanes = []
72 |             if 'mask_out_lanes' in element['inputs']['map_representation']:
73 |                 mask_out_lanes = element['inputs']['map_representation']['mask_out_lanes']
74 |             elif self.lane_mask_prob > 0.:
75 |                 # Mask out whole lanes with probability lane_mask_p (1 in lane_node_masks means masked out)
76 |                 mask_out = np.tile(np.expand_dims((np.random.random((lane_node_masks.shape[0])) < self.lane_mask_prob), [-1, -2]), [1, lane_node_masks.shape[-2], lane_node_masks.shape[-1]])
77 |                 lane_node_masks = lane_node_masks.astype(int) | mask_out.astype(int)
78 |                 # Indices of masked-out lanes
79 |                 mask_out_lanes = np.where(mask_out[:,0,0])[0]
80 |                 element['inputs']['map_representation']['lane_node_masks'] = lane_node_masks
81 |                 # Update adjacency with masked-out lanes
82 |                 element['inputs']['map_representation']['succ_adj_matrix'] = element['inputs']['map_representation']['succ_adj_matrix'] * (1-lane_node_masks[:,0,0])
83 |                 element['inputs']['map_representation']['prox_adj_matrix'] = element['inputs']['map_representation']['prox_adj_matrix'] * (1-lane_node_masks[:,0,0])
84 |                 element['inputs']['agent_node_masks']['vehicles'] = element['inputs']['agent_node_masks']['vehicles'].astype(int) | np.expand_dims((lane_node_masks[:,0,0]), -1).astype(int)
85 | 
86 |             if self.mask_frames > 0.0:
87 |                 target_adj_matrix = element['inputs']['surrounding_agent_representation']['adj_matrix'][0, 1:veh_mask.shape[0]+1]
88 |                 target_adj_matrix = np.tile(np.expand_dims(target_adj_matrix, -1), [1, veh_mask.shape[-2]])
89 |                 target_adj_matrix *= (np.random.random((target_adj_matrix.shape[0], target_adj_matrix.shape[1])) > self.mask_frames).astype(int)
90 |                 veh_mask = veh_mask.astype(int) | np.tile(np.expand_dims((1-target_adj_matrix), -1), [1, 1, veh_mask.shape[-1]]).astype(int)
91 |                 # Mask out frames of nearby agents with probability mask_frames_p
92 |                 element['inputs']['surrounding_agent_representation']['vehicle_masks'] = veh_mask
93 |                 element['inputs']['agent_node_masks']['vehicles'] = element['inputs']['agent_node_masks']['vehicles'].astype(int) | np.tile(np.expand_dims(veh_mask[:,:,0].any(-1), 0), [164, 1])  # 164 is assumed to be the max number of lane nodes
94 |             #############################
95 |             # Update with newly masked-out vehicles to rebuild the graph
96 |             v_nodes_mask = veh_mask[:,:,0].sum(-1) != veh_mask.shape[-2]  # True where there is a vehicle (i.e. not fully masked)
97 |             v_nodes = np.where(v_nodes_mask)[0]  # 0 to 83
98 |             v_nodes = np.insert((v_nodes + np.ones((v_nodes.shape[0]))), 0, 0).astype(int)  # 0 is the focal vehicle
99 |             adj_matrix_v = adj[v_nodes][:, v_nodes]
100 |             veh_u, veh_v = np.nonzero(adj_matrix_v)
101 | 
102 |             # Pedestrians
103 |             ped_mask = element['inputs']['surrounding_agent_representation']['pedestrian_masks']
104 |             num_p = np.where(ped_mask[:,:,0]==0)[0].max()+1 if len(np.where(ped_mask[:,:,0]==0)[0])>0 else 0
105 |             ped_veh_u, ped_veh_v = np.where(adj[num_v:num_v+num_p, v_nodes] == 1)
106 |             # Mask those pedestrians that don't appear in the interaction graph
107 |             max_p_in_graph = ped_veh_u.max()+1 if len(ped_veh_u)>0 else 0
108 |             ped_mask[max_p_in_graph:, :, :] = 1
109 |             element['inputs']['surrounding_agent_representation']['pedestrian_masks'] = ped_mask
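            # Static objects are currently excluded from the heterograph: the block
            # below and the matching ('o', 'o_interact_v', 'v') edge type are kept
            # commented out.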
110 | 
111 |             # Objects
112 |             # obj_mask = element['inputs']['surrounding_agent_representation']['object_masks']
113 |             # num_o = np.where(obj_mask[:,:,0]==0)[0].max()+1 if len(np.where(obj_mask[:,:,0]==0)[0])>0 else 0
114 |             # obj_veh_u, obj_veh_v = np.where(adj[num_v+num_p:num_v+num_p+num_o, v_nodes] == 1)
115 |             # max_o_in_graph = obj_veh_u.max()+1 if len(obj_veh_u)>0 else 0
116 |             # obj_mask[max_o_in_graph:, :, :] = 1
117 |             # element['inputs']['surrounding_agent_representation']['object_masks'] = obj_mask
118 | 
119 |             # Lane graph
120 |             lane_veh_adj_matrix = element['inputs']['agent_node_masks']['vehicles'].transpose(1, 0)
121 |             # Remove masked lanes
122 |             lane_veh_adj_matrix = np.delete(lane_veh_adj_matrix, mask_out_lanes, 1)
123 |             # To keep the indexing consistent, we set to 0 the edge type of masked-out lanes, i.e. no edge.
124 |             if len(mask_out_lanes) > 0:
125 |                 for lane in mask_out_lanes:
126 |                     element['inputs']['map_representation']['edge_type'] = np.where(element['inputs']['map_representation']['s_next'] == lane, 0, element['inputs']['map_representation']['edge_type'])
127 |                 element['inputs']['map_representation']['s_next'][mask_out_lanes] = 0
128 |             # Update with newly masked-out vehicles
129 |             lane_veh_adj_matrix = lane_veh_adj_matrix[v_nodes_mask]  # num_nbr_vehicles x 164
130 |             # Create a mask for the lanes that are not empty
131 |             lane_mask = np.delete((~(lane_node_masks[:,:,0] != 0)).any(-1), mask_out_lanes)  # 164
132 |             # Prepend a row for the focal vehicle: 0 (i.e. an edge) for every non-empty lane
133 |             lane_veh_adj_matrix = np.vstack((~lane_mask*1, lane_veh_adj_matrix))  # num_veh x 164
134 |             veh_lane_u, veh_lane_v = np.where(lane_veh_adj_matrix == 0)
135 |             succ_adj_matrix = np.delete(np.delete(element['inputs']['map_representation']['succ_adj_matrix'], mask_out_lanes, 1), mask_out_lanes, 0)
136 |             prox_adj_matrix = np.delete(np.delete(element['inputs']['map_representation']['prox_adj_matrix'], mask_out_lanes, 1), mask_out_lanes, 0)
137 |             succ_u, succ_v = np.nonzero(succ_adj_matrix)
138 |             prox_u, prox_v = np.nonzero(prox_adj_matrix)
139 | 
140 |             # Create heterogeneous graph
141 |             lanes_graphs.append(dgl.heterograph({
142 |                 ('l', 'successor', 'l'): (torch.tensor(succ_u, dtype=torch.int), torch.tensor(succ_v, dtype=torch.int)),
143 |                 ('l', 'proximal', 'l'): (torch.tensor(prox_u, dtype=torch.int), torch.tensor(prox_v, dtype=torch.int)),
144 |                 ('v', 'v_close_l', 'l'): (torch.tensor(veh_lane_u, dtype=torch.int), torch.tensor(veh_lane_v, dtype=torch.int)),
145 |                 ('v', 'v_interact_v', 'v'): (torch.tensor(veh_u, dtype=torch.int), torch.tensor(veh_v, dtype=torch.int)),
146 |                 ('p', 'p_interact_v', 'v'): (torch.tensor(ped_veh_u, dtype=torch.int), torch.tensor(ped_veh_v, dtype=torch.int)),
147 |                 # ('o', 'o_interact_v', 'v'): (torch.tensor(obj_veh_u, dtype=torch.int), torch.tensor(obj_veh_v, dtype=torch.int)),
148 |             }))
149 | 
150 |             assert len(np.nonzero(veh_mask[:,:,0].sum(-1) < 5)[0]) + 1 == lanes_graphs[-1].num_nodes('v')  # 5 is assumed to equal veh_mask.shape[-2], the number of past frames
151 | 
152 |         lanes_batched_graph = dgl.batch(lanes_graphs)
153 |         data = default_collate(batch)
154 |         data['inputs']['lanes_graphs'] = lanes_batched_graph
155 | 
156 |         return data
157 | 
158 | 
159 | 
160 | def collate_fn_dgl_hetero(batch):
161 |     # Collate function for dataloader.
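    # Builds one DGL heterograph per sample with node types l (lane), v (vehicle,
    # index 0 being the focal agent) and p (pedestrian), connected by 'successor'
    # and 'proximal' lane-lane edges, 'v_close_l' vehicle-to-lane edges and
    # 'v_interact_v' / 'p_interact_v' interaction edges; the graphs are batched
    # and attached to data['inputs']['lanes_graphs'].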
162 |     lanes_graphs = []
163 |     for element in batch:
164 |         # Interaction graph
165 |         adj = element['inputs']['surrounding_agent_representation']['adj_matrix']
166 |         lane_veh_adj_matrix = element['inputs']['agent_node_masks']['vehicles'].transpose(1, 0)  # 84 x 164
167 |         lane_node_masks = element['inputs']['map_representation']['lane_node_masks']
168 |         # Create a mask for the lanes that are not empty
169 |         lane_mask = (~(lane_node_masks[:,:,0] != 0)).any(-1)  # 164
170 |         # Prepend a row for the focal vehicle: 0 (i.e. an edge) for every non-empty lane
171 |         lane_veh_adj_matrix = np.vstack((~lane_mask*1, lane_veh_adj_matrix))  # 85 x 164
172 |         veh_lane_u, veh_lane_v = np.where(lane_veh_adj_matrix == 0)
173 |         veh_mask = element['inputs']['surrounding_agent_representation']['vehicle_masks']
174 |         ped_mask = element['inputs']['surrounding_agent_representation']['pedestrian_masks']
175 |         num_v = np.where(veh_mask[:,:,0]==0)[0].max()+2 if len(np.where(veh_mask[:,:,0]==0)[0])>0 else 0  # max index + 1 gives the vehicle count; one more for the focal agent, which is not present in veh_mask
176 |         num_p = np.where(ped_mask[:,:,0]==0)[0].max()+1 if len(np.where(ped_mask[:,:,0]==0)[0])>0 else 0
177 |         adj_matrix_v = adj[:num_v, :num_v]
178 |         veh_u, veh_v = np.nonzero(adj_matrix_v)
179 |         ped_veh_u, ped_veh_v = np.where(adj[num_v:num_v+num_p, :num_v] == 1)
180 |         # Mask those pedestrians that don't appear in the interaction graph
181 |         max_p_in_graph = ped_veh_u.max()+1 if len(ped_veh_u)>0 else 0
182 |         ped_mask[max_p_in_graph:, :, :] = 1
183 |         element['inputs']['surrounding_agent_representation']['pedestrian_masks'] = ped_mask
184 |         # interaction_graphs.append(dgl.from_scipy(spp.coo_matrix(adj_matrix)).int())
185 |         # Lane graph
186 |         succ_adj_matrix = element['inputs']['map_representation']['succ_adj_matrix']
187 |         prox_adj_matrix = element['inputs']['map_representation']['prox_adj_matrix']
188 | 
189 |         succ_u, succ_v = np.nonzero(succ_adj_matrix)
190 |         prox_u, prox_v = np.nonzero(prox_adj_matrix)
191 |         lanes_graphs.append(
192 |             dgl.heterograph({
193 |                 ('l', 'successor', 'l'): (torch.tensor(succ_u, dtype=torch.int), torch.tensor(succ_v, dtype=torch.int)),
194 |                 ('l', 'proximal', 'l'): (torch.tensor(prox_u, dtype=torch.int), torch.tensor(prox_v, dtype=torch.int)),
195 |                 ('v', 'v_close_l', 'l'): (torch.tensor(veh_lane_u, dtype=torch.int), torch.tensor(veh_lane_v, dtype=torch.int)),
196 |                 ('v', 'v_interact_v', 'v'): (torch.tensor(veh_u, dtype=torch.int), torch.tensor(veh_v, dtype=torch.int)),
197 |                 ('p', 'p_interact_v', 'v'): (torch.tensor(ped_veh_u, dtype=torch.int), torch.tensor(ped_veh_v, dtype=torch.int)),
198 |             }))
199 | 
200 | 
201 |     lanes_batched_graph = dgl.batch(lanes_graphs)
202 | 
203 |     data = default_collate(batch)
204 |     data['inputs']['lanes_graphs'] = lanes_batched_graph
205 |     return data
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | from train_eval.visualizer import Visualizer
4 | import os
5 | 
6 | # Parse arguments
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("-c", "--config", help="Config file with dataset parameters", required=True)
9 | parser.add_argument("-r", "--data_root", help="Root directory with data", required=True)
10 | parser.add_argument("-d", "--data_dir", help="Directory to extract data", required=True)
11 | parser.add_argument("-o", "--output_dir", help="Directory to save results", required=True)
12 | parser.add_argument("-w", "--checkpoint", help="Path to pre-trained or intermediate checkpoint", required=True)
13 | parser.add_argument("--num_modes", help="Number of modes to visualize", type=int, default=10)
14 | parser.add_argument("--example", help="Example to visualize", type=int, default=1)
15 | parser.add_argument("--tf", help="Prediction horizon in seconds", type=int, default=6)
16 | parser.add_argument("--show_predictions", help="Show predictions", action="store_true", default=True)  # note: with default=True this flag is always on
17 | parser.add_argument("--counterfactual", help="Include counterfactual", action="store_true")
18 | parser.add_argument("--mask_lane", help="Mask gt lanes", action="store_true")
19 | parser.add_argument("--name", type=str, default='')
20 | args = parser.parse_args()
21 | 
22 | 
23 | # Make directories
24 | if not os.path.isdir(args.output_dir):
25 |     os.mkdir(args.output_dir)
26 | if not os.path.isdir(os.path.join(args.output_dir, 'results')):
27 |     os.mkdir(os.path.join(args.output_dir, 'results'))
28 | 
29 | 
30 | # Load config
31 | with open(args.config, 'r') as yaml_file:
32 |     cfg = yaml.safe_load(yaml_file)
33 | 
34 | 
35 | # Visualize
36 | vis = Visualizer(cfg, args.data_root, args.data_dir, args.checkpoint, args.example, args.show_predictions,
37 |                  args.tf, args.num_modes, args.counterfactual, args.mask_lane, args.name)
38 | vis.visualize(output_dir=args.output_dir, dataset_type=cfg['dataset'])
39 | 
--------------------------------------------------------------------------------
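For orientation, the pieces above compose as follows. This is a minimal sketch rather than a file from the repository: the config file names match those under configs/, but the paths, writer setup and epoch count are placeholder assumptions (train.py is the canonical entry point).

import os
import yaml
from torch.utils.tensorboard import SummaryWriter
from train_eval.preprocessor import preprocess_data
from train_eval.trainer import Trainer

data_root = '/path/to/nuscenes'     # raw nuScenes root (placeholder)
data_dir = '/path/to/preprocessed'  # extracted data directory (placeholder)
output_dir = '/path/to/output'      # logs and checkpoints (placeholder)
os.makedirs(os.path.join(output_dir, 'checkpoints'), exist_ok=True)

# One-off pre-processing: compute dataset stats, then extract tensors to data_dir.
with open('configs/preprocess_nuscenes.yml', 'r') as f:
    preprocess_cfg = yaml.safe_load(f)
preprocess_data(preprocess_cfg, data_root, data_dir, compute_stats=True, extract=True)

# Train-val loop: Trainer builds the model, optimizer, losses and metrics from the config.
with open('configs/pgp_gatx2_lvm_traversal.yml', 'r') as f:
    train_cfg = yaml.safe_load(f)
trainer = Trainer(train_cfg, data_root, data_dir, writer=SummaryWriter(log_dir=output_dir))
trainer.train(num_epochs=100, output_dir=output_dir)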