├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ └── nll.py │ └── modules │ │ ├── __init__.py │ │ ├── pos_emb.py │ │ ├── mlp.py │ │ ├── point_net.py │ │ ├── multi_modal.py │ │ ├── rpe.py │ │ ├── decoder_ensemble.py │ │ ├── attention.py │ │ └── transformer.py ├── utils │ ├── __init__.py │ ├── pose_pe.py │ ├── transform_utils.py │ └── submission.py ├── callbacks │ ├── __init__.py │ └── wandb_callbacks.py ├── data_modules │ ├── __init__.py │ ├── post_processing.py │ ├── scene_centric.py │ ├── data_h5_womd.py │ ├── data_h5_av2.py │ ├── sc_global.py │ └── sc_relative.py ├── pl_modules │ └── __init__.py ├── run.py └── pack_h5_av2.py ├── docs ├── hptr_banner.png ├── hptr_teaser.png ├── hptr_efficiency.png └── ablation_models.md ├── configs ├── resume │ ├── empty.yaml │ ├── sub_av2.yaml │ └── sub_womd.yaml ├── loggers │ └── wandb.yaml ├── datamodule │ ├── h5_av2.yaml │ └── h5_womd.yaml ├── run.yaml ├── trainer │ ├── av2.yaml │ └── womd.yaml ├── callbacks │ └── wandb.yaml └── model │ ├── acg_womd.yaml │ ├── scg_womd.yaml │ ├── scr_av2.yaml │ └── scr_womd.yaml ├── bash ├── pack_h5.sh ├── train.sh └── submission.sh ├── .gitignore ├── env_av2.yml ├── README.md └── environment.yml /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/pl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/hptr_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_banner.png -------------------------------------------------------------------------------- /docs/hptr_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_teaser.png -------------------------------------------------------------------------------- /configs/resume/empty.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: null 2 | resume_trainer: True 3 | model_overrides: {} 4 | -------------------------------------------------------------------------------- 
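Note on the `resume` config group: `configs/resume/empty.yaml` above is the default selected by `configs/run.yaml`. `checkpoint` names a W&B model artifact to load, `resume_trainer` restores the trainer state when `action=fit`, and `model_overrides` patches the saved hyper-parameters when the checkpoint is loaded (see `src/run.py`). Below is a minimal sketch of how these keys are typically set from the command line, following the override patterns used in `bash/train.sh` and `bash/submission.sh`; the W&B run name is a placeholder.

```
# Resume training from a W&B checkpoint artifact (trainer state is restored
# because resume_trainer defaults to True in configs/resume/empty.yaml):
python -u src/run.py resume.checkpoint=YOUR_WANDB_RUN_NAME:latest

# Build a WOMD submission: select configs/resume/sub_womd.yaml and validate
# on the full split, as in bash/submission.sh:
python -u src/run.py resume=sub_womd action=validate \
    trainer.limit_val_batches=1.0 resume.checkpoint=YOUR_WANDB_RUN_NAME:latest
```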
/docs/hptr_efficiency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_efficiency.png -------------------------------------------------------------------------------- /configs/loggers/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | _target_: pytorch_lightning.loggers.WandbLogger 3 | project: debug 4 | group: null 5 | name: my_name 6 | notes: my_notes 7 | tags: [] 8 | job_type: train 9 | entity: YOUR_ENTITY 10 | -------------------------------------------------------------------------------- /configs/datamodule/h5_av2.yaml: -------------------------------------------------------------------------------- 1 | _target_: data_modules.data_h5_av2.DataH5av2 2 | 3 | data_dir: /cluster/scratch/zhejzhan/h5_av2_hptr 4 | filename_train: training 5 | filename_val: validation 6 | filename_test: testing 7 | n_agent: 64 8 | 9 | batch_size: 3 10 | num_workers: 4 11 | -------------------------------------------------------------------------------- /configs/datamodule/h5_womd.yaml: -------------------------------------------------------------------------------- 1 | _target_: data_modules.data_h5_womd.DataH5womd 2 | 3 | data_dir: /cluster/scratch/zhejzhan/h5_womd_hptr 4 | filename_train: training 5 | filename_val: validation 6 | filename_test: testing 7 | n_agent: 64 8 | 9 | batch_size: 3 10 | num_workers: 4 11 | -------------------------------------------------------------------------------- /configs/run.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - trainer: womd 6 | - model: scr_womd 7 | - datamodule: h5_womd 8 | - callbacks: wandb 9 | - loggers: wandb 10 | - resume: empty 11 | 12 | hydra: 13 | run: 14 | dir: logs/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | 16 | work_dir: ${hydra:runtime.cwd} 17 | seed: 2023 18 | action: fit # fit, validate, test 19 | -------------------------------------------------------------------------------- /configs/trainer/av2.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | limit_train_batches: 0.5 4 | limit_val_batches: 0.25 5 | limit_test_batches: 1.0 6 | num_sanity_val_steps: 0 7 | max_epochs: null 8 | min_epochs: null 9 | # max_time: 00:120:00:00 10 | 11 | log_every_n_steps: 200 12 | 13 | gradient_clip_val: 0.5 14 | track_grad_norm: 2 15 | gpus: -1 16 | precision: 32 17 | benchmark: False 18 | deterministic: False 19 | sync_batchnorm: False 20 | detect_anomaly: False 21 | accumulate_grad_batches: 1 22 | resume_from_checkpoint: null 23 | enable_progress_bar: True 24 | -------------------------------------------------------------------------------- /configs/trainer/womd.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | limit_train_batches: 0.25 4 | limit_val_batches: 0.25 5 | limit_test_batches: 1.0 6 | num_sanity_val_steps: 0 7 | max_epochs: null 8 | min_epochs: null 9 | # max_time: 00:120:00:00 10 | 11 | log_every_n_steps: 200 12 | 13 | gradient_clip_val: 0.5 14 | track_grad_norm: 2 15 | gpus: -1 16 | precision: 32 17 | benchmark: False 18 | deterministic: False 19 | sync_batchnorm: False 20 | detect_anomaly: False 21 | accumulate_grad_batches: 1 22 | resume_from_checkpoint: null 23 | enable_progress_bar: True 24 | 
-------------------------------------------------------------------------------- /configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: callbacks.wandb_callbacks.ModelCheckpointWB 3 | dirpath: "checkpoints/" 4 | filename: "{epoch:02d}" 5 | monitor: "val/loss" # name of the logged metric which determines when model is improving 6 | save_top_k: 1 # save k best models (determined by above metric) 7 | save_last: True # additionally always save model from last epoch 8 | mode: "min" # can be "max" or "min" 9 | verbose: True 10 | save_only_best: True # if True, only save best model according to "monitor" metric 11 | 12 | lr_monitor: 13 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 14 | 15 | stochastic_weight_avg: 16 | _target_: pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging -------------------------------------------------------------------------------- /configs/resume/sub_av2.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: null 2 | resume_trainer: True 3 | model_overrides: 4 | n_video_batch: 0 5 | post_processing: 6 | to_dict: 7 | _target_: data_modules.post_processing.ToDict 8 | predictions: [pos, cov3, spd, vel, yaw_bbox] 9 | get_cov_mat: 10 | _target_: data_modules.post_processing.GetCovMat 11 | rho_clamp: 5.0 12 | std_min: -1.609 13 | std_max: 5.0 14 | waymo: 15 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 16 | k_pred: 6 17 | use_ade: True 18 | score_temperature: 0.5 19 | mpa_nms_thresh: [] # veh, ped, cyc 20 | gt_in_local: True 21 | sub_womd: 22 | _target_: utils.submission.SubWOMD 23 | activate: False 24 | method_name: METHOD_NAME 25 | authors: [NAME1, NAME2] 26 | affiliation: AFFILIATION 27 | description: scr_womd 28 | method_link: METHOD_LINK 29 | account_name: ACCOUNT_NAME 30 | sub_av2: 31 | _target_: utils.submission.SubAV2 32 | activate: True 33 | -------------------------------------------------------------------------------- /configs/resume/sub_womd.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: null 2 | resume_trainer: True 3 | model_overrides: 4 | n_video_batch: 0 5 | post_processing: 6 | to_dict: 7 | _target_: data_modules.post_processing.ToDict 8 | predictions: [pos, cov3, spd, vel, yaw_bbox] 9 | get_cov_mat: 10 | _target_: data_modules.post_processing.GetCovMat 11 | rho_clamp: 5.0 12 | std_min: -1.609 13 | std_max: 5.0 14 | waymo: 15 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 16 | k_pred: 6 17 | use_ade: True 18 | score_temperature: -1 19 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc 20 | gt_in_local: True 21 | sub_womd: 22 | _target_: utils.submission.SubWOMD 23 | activate: True 24 | method_name: HPTR 25 | authors: [Zhejun Zhang, Alexander Liniger, Christos Sakaridis, Fisher Yu, Luc Van Gool] 26 | affiliation: "CVL, ETH Zurich" 27 | description: "Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding. NeurIPS 2023.
https://github.com/zhejz/HPTR" 28 | method_link: "https://github.com/zhejz/HPTR" 29 | account_name: "YOUR_ACCOUNT_NAME" 30 | sub_av2: 31 | _target_: utils.submission.SubAV2 32 | activate: False 33 | -------------------------------------------------------------------------------- /bash/pack_h5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=./logs/%j.out 3 | #SBATCH --error=./logs/%j.out 4 | #SBATCH --time=120:00:00 5 | #SBATCH -n 1 6 | #SBATCH --cpus-per-task=12 7 | #SBATCH --mem-per-cpu=5000 8 | #SBATCH --tmp=200000 9 | #SBATCH --open-mode=truncate 10 | 11 | trap "echo sigterm received, exiting!" SIGTERM 12 | 13 | run () { 14 | python -u src/pack_h5_womd.py --dataset=training \ 15 | --out-dir=/cluster/scratch/zhejzhan/h5_womd_hptr \ 16 | --data-dir=/cluster/scratch/zhejzhan/womd_scenario_v_1_2_0 17 | } 18 | 19 | # ! for validation and testing 20 | # python -u src/pack_h5_womd.py --dataset=validation --rand-pos=-1 --rand-yaw=-1 \ 21 | # python -u src/pack_h5_womd.py --dataset=testing --rand-pos=-1 --rand-yaw=-1 \ 22 | 23 | # ! for packing av2 24 | # conda activate hptr_av2 25 | # run () { 26 | # python -u src/pack_h5_av2.py --dataset=training \ 27 | # --out-dir=/cluster/scratch/zhejzhan/h5_av2_hptr \ 28 | # --data-dir=/cluster/scratch/zhejzhan/av2_motion 29 | # } 30 | 31 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh 32 | conda activate hptr # for av2: conda activate hptr_av2 33 | 34 | echo Running on host: `hostname` 35 | echo In directory: `pwd` 36 | echo Starting on: `date` 37 | 38 | 39 | type run 40 | echo START: `date` 41 | run & 42 | wait 43 | echo DONE: `date` 44 | 45 | mkdir -p ./logs/slurm 46 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out 47 | 48 | echo finished at: `date` 49 | exit 0; -------------------------------------------------------------------------------- /bash/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=./logs/%j.out 3 | #SBATCH --error=./logs/%j.out 4 | #SBATCH --time=120:00:00 5 | #SBATCH -n 1 6 | #SBATCH --cpus-per-task=12 7 | #SBATCH --mem-per-cpu=5000 8 | #SBATCH --tmp=200000 9 | #SBATCH --gpus=rtx_2080_ti:4 10 | #SBATCH --open-mode=truncate 11 | 12 | trap "echo sigterm received, exiting!" SIGTERM 13 | 14 | DATASET_DIR="h5_womd_hptr" 15 | run () { 16 | python -u src/run.py \ 17 | trainer=womd \ 18 | model=scr_womd \ 19 | datamodule=h5_womd \ 20 | loggers.wandb.name="hptr_womd" \ 21 | loggers.wandb.project="hptr_train" \ 22 | loggers.wandb.entity="YOUR_ENTITY" \ 23 | datamodule.data_dir=${TMPDIR}/datasets \ 24 | hydra.run.dir='/cluster/scratch/zhejzhan/logs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 25 | } 26 | 27 | # ! For AV2 dataset. 28 | # DATASET_DIR="h5_av2_hptr" 29 | # trainer=av2 \ 30 | # model=scr_av2 \ 31 | # datamodule=h5_av2 \ 32 | 33 | # ! To resume training.
34 | # resume.checkpoint=YOUR_WANDB_RUN_NAME:latest \ 35 | 36 | 37 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh 38 | conda activate hptr 39 | 40 | echo Running on host: `hostname` 41 | echo In directory: `pwd` 42 | echo Starting on: `date` 43 | 44 | echo START copying data: `date` 45 | mkdir $TMPDIR/datasets 46 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/training.h5 $TMPDIR/datasets/ 47 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/validation.h5 $TMPDIR/datasets/ 48 | echo DONE copying: `date` 49 | 50 | type run 51 | echo START: `date` 52 | run & 53 | wait 54 | echo DONE: `date` 55 | 56 | mkdir -p ./logs/slurm 57 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out 58 | 59 | echo finished at: `date` 60 | exit 0; 61 | -------------------------------------------------------------------------------- /bash/submission.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --output=./logs/%j.out 3 | #SBATCH --error=./logs/%j.out 4 | #SBATCH --time=4:00:00 5 | #SBATCH -n 1 6 | #SBATCH --cpus-per-task=8 7 | #SBATCH --mem-per-cpu=5000 8 | #SBATCH --tmp=100000 9 | #SBATCH --gpus=rtx_2080_ti:1 10 | #SBATCH --open-mode=truncate 11 | 12 | trap "echo sigterm received, exiting!" SIGTERM 13 | 14 | DATASET_DIR="h5_womd_hptr" 15 | run () { 16 | python -u src/run.py \ 17 | trainer=womd \ 18 | model=scr_womd \ 19 | datamodule=h5_womd \ 20 | resume=sub_womd \ 21 | action=validate \ 22 | trainer.limit_val_batches=1.0 \ 23 | resume.checkpoint=YOUR_WANDB_RUN_NAME:latest \ 24 | loggers.wandb.name="hptr_womd_val" \ 25 | loggers.wandb.project="hptr_sub" \ 26 | loggers.wandb.entity="YOUR_ENTITY" \ 27 | datamodule.data_dir=${TMPDIR}/datasets \ 28 | hydra.run.dir='/cluster/scratch/zhejzhan/logs/${now:%Y-%m-%d}/${now:%H-%M-%S}' 29 | } 30 | 31 | # ! For AV2 dataset. 32 | # DATASET_DIR="h5_av2_hptr" 33 | # trainer=av2 \ 34 | # model=scr_av2 \ 35 | # datamodule=h5_av2 \ 36 | # resume=sub_av2 \ 37 | 38 | # ! For testing. 39 | # action=test \ 40 | 41 | 42 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh 43 | conda activate hptr 44 | 45 | echo Running on host: `hostname` 46 | echo In directory: `pwd` 47 | echo Starting on: `date` 48 | 49 | echo START copying data: `date` 50 | mkdir $TMPDIR/datasets 51 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/validation.h5 $TMPDIR/datasets/ 52 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/testing.h5 $TMPDIR/datasets/ 53 | echo DONE copying: `date` 54 | 55 | type run 56 | echo START: `date` 57 | run & 58 | wait 59 | echo DONE: `date` 60 | 61 | mkdir -p ./logs/slurm 62 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out 63 | 64 | echo finished at: `date` 65 | exit 0; 66 | -------------------------------------------------------------------------------- /src/models/modules/pos_emb.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import torch 3 | from torch import Tensor, nn 4 | 5 | 6 | class PositionalEmbedding(nn.Module): 7 | def __init__(self, dim: int, theta: float = 10000): 8 | super().__init__() 9 | assert dim % 2 == 0 10 | self.dim = dim 11 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 12 | # [dim] 13 | freqs = freqs.repeat_interleave(2, 0) 14 | self.register_buffer("freqs", freqs) 15 | 16 | def forward(self, x: Tensor): 17 | """ 18 | Args: 19 | x: [...]
20 | Returns: 21 | pos_enc: [..., dim] 22 | """ 23 | # [..., dim] 24 | pos_enc = x.unsqueeze(-1) * self.freqs.view([1] * x.dim() + [-1]) 25 | pos_enc = torch.cat([torch.cos(pos_enc[..., ::2]), torch.sin(pos_enc[..., 1::2])], dim=-1) 26 | return pos_enc 27 | 28 | 29 | class PositionalEmbeddingRad(nn.Module): 30 | def __init__(self, dim: int): 31 | """ 32 | if dim=2, then just [cos(theta), sin(theta)] 33 | """ 34 | super().__init__() 35 | assert dim % 2 == 0 36 | self.dim = dim 37 | # [dim]: [1,1,2,2,4,4,8,8] 38 | # freqs = 2 ** (torch.arange(0, dim // 2).float()) 39 | # [dim]: [1,1,2,2,3,3,4,4] 40 | freqs = torch.arange(0, dim // 2) + 1.0 41 | freqs = freqs.repeat_interleave(2, 0) 42 | self.register_buffer("freqs", freqs) 43 | 44 | def forward(self, x: Tensor): 45 | """ 46 | Args: 47 | x: [...], in rad 48 | Returns: 49 | pos_enc: [..., dim] 50 | """ 51 | # [..., dim] 52 | pos_enc = x.unsqueeze(-1) * self.freqs.view([1] * x.dim() + [-1]) 53 | pos_enc = torch.cat([torch.cos(pos_enc[..., ::2]), torch.sin(pos_enc[..., 1::2])], dim=-1) 54 | return pos_enc 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # my ignore 132 | .vscode/settings.json 133 | logs/ 134 | runs/ 135 | outputs/ 136 | wandb/ 137 | data/ 138 | *.pyc -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import hydra 3 | import torch 4 | from omegaconf import DictConfig 5 | from typing import List 6 | from pytorch_lightning import seed_everything, LightningDataModule, LightningModule, Trainer, Callback 7 | from pytorch_lightning.loggers import LightningLoggerBase 8 | import os 9 | 10 | 11 | def download_checkpoint(loggers, wb_ckpt) -> None: 12 | if os.environ.get("LOCAL_RANK", 0) == 0: 13 | artifact = loggers[0].experiment.use_artifact(wb_ckpt, type="model") 14 | artifact_dir = artifact.download("ckpt") 15 | 16 | 17 | @hydra.main(config_path="../configs/", config_name="run.yaml") 18 | def main(config: DictConfig) -> None: 19 | 20 | seed_everything(config.seed, workers=True) 21 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 22 | 23 | callbacks: List[Callback] = [] 24 | if "callbacks" in config: 25 | for _, cb_conf in config.callbacks.items(): 26 | callbacks.append(hydra.utils.instantiate(cb_conf)) 27 | 28 | loggers: List[LightningLoggerBase] = [] 29 | if "loggers" in config: 30 | for _, lg_conf in config.loggers.items(): 31 | loggers.append(hydra.utils.instantiate(lg_conf)) 32 | 33 | if config.resume.checkpoint is None: 34 | model: LightningModule = hydra.utils.instantiate( 35 | config.model, data_size=datamodule.tensor_size_train, _recursive_=False 36 | ) 37 | else: 38 | download_checkpoint(loggers, config.resume.checkpoint) 39 | ckpt_path = "ckpt/model.ckpt" 40 | modelClass = hydra.utils.get_class(config.model._target_) 41 | 42 | model = modelClass.load_from_checkpoint( 43 | ckpt_path, wb_artifact=config.resume.checkpoint, **config.resume.model_overrides 44 | ) 45 | 46 | if config.resume.resume_trainer and config.action == "fit": 47 | config.trainer.resume_from_checkpoint = ckpt_path 48 | 49 | # from pytorch_lightning.plugins import DDPPlugin 50 | # strategy = DDPPlugin(gradient_as_bucket_view=True) 51 | strategy = None 52 | if torch.cuda.device_count() > 1: 53 | strategy = "ddp" 54 | trainer: Trainer = hydra.utils.instantiate( 55 | config.trainer, strategy=strategy, callbacks=callbacks, logger=loggers, _convert_="partial" 56 | ) 57 | if config.action == "fit": 58 | trainer.fit(model=model, datamodule=datamodule) 59 | elif config.action == "validate": 60 | trainer.validate(model=model, datamodule=datamodule) 61 | elif config.action == "test": 62 | trainer.test(model=model, datamodule=datamodule) 63 | else: 64 | raise NotImplementedError 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- 
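`src/run.py` above is the single entry point: Hydra composes `configs/run.yaml`, instantiates the datamodule, callbacks, loggers and model from their `_target_` entries, downloads the W&B checkpoint if `resume.checkpoint` is set, and dispatches on `action` (fit, validate or test). A hedged sketch of typical invocations, mirroring `bash/train.sh` and the override style of `docs/ablation_models.md` that follows; the dataset path and W&B entity are placeholders.

```
# Train HPTR on WOMD (action defaults to fit in configs/run.yaml):
python -u src/run.py trainer=womd model=scr_womd datamodule=h5_womd \
    datamodule.data_dir=/path/to/h5_womd_hptr loggers.wandb.entity=YOUR_ENTITY

# Train an ablation by overriding nested model keys, e.g. attention without bias
# (see docs/ablation_models.md below):
python -u src/run.py model=scr_womd model.model.tf_cfg.bias=False
```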
/docs/ablation_models.md: -------------------------------------------------------------------------------- 1 | # Configurations of Ablation Models 2 | 3 | Use the following configurations and adapt [bash/train.sh](../bash/train.sh) to train the ablation models. 4 | 5 | ## Input Representation 6 | - Our `HPTR` for Waymo dataset. The model has 15.2M parameters. 7 | ``` 8 | model=scr_womd \ 9 | ``` 10 | - Our `HPTR` for AV2 dataset. 11 | ``` 12 | trainer=av2 \ 13 | model=scr_av2 \ 14 | datamodule=h5_av2 \ 15 | ``` 16 | - Agent-centric baseline [Wayformer](https://arxiv.org/abs/2207.05844), i.e. `WF baseline`. 17 | ``` 18 | model=acg_womd \ 19 | ``` 20 | - Scene-centric baseline [SceneTransformer](https://arxiv.org/abs/2106.08417), i.e. `HPTR SC`. 21 | ``` 22 | model=scg_womd \ 23 | ``` 24 | 25 | ## Hierarchical Architecture 26 | 27 | - `HPTR diag+full` with 15.4M parameters. It needs an RTX 3090 for training. 28 | ``` 29 | model.model.intra_class_encoder.n_layer_tf_map=6 \ 30 | model.model.intra_class_encoder.n_layer_tf_tl=2 \ 31 | model.model.intra_class_encoder.n_layer_tf_agent=2 \ 32 | model.model.decoder.tf_n_layer=2 \ 33 | model.model.decoder.k_reinforce_tl=-1 \ 34 | model.model.decoder.k_reinforce_agent=-1 \ 35 | model.model.decoder.k_reinforce_all=1 \ 36 | ``` 37 | - `HPTR diag` with 15.4M parameters. 38 | ``` 39 | model.model.intra_class_encoder.n_layer_tf_map=6 \ 40 | model.model.intra_class_encoder.n_layer_tf_tl=3 \ 41 | model.model.intra_class_encoder.n_layer_tf_agent=3 \ 42 | model.model.decoder.tf_n_layer=2 \ 43 | model.model.decoder.k_reinforce_tl=-1 \ 44 | model.model.decoder.k_reinforce_agent=-1 \ 45 | ``` 46 | - `HPTR full` with 15.2M parameters. It needs an RTX 3090 for training. 47 | ``` 48 | model.model.intra_class_encoder.n_layer_tf_map=-1 \ 49 | model.model.decoder.tf_n_layer=6 \ 50 | model.model.decoder.k_reinforce_tl=-1 \ 51 | model.model.decoder.k_reinforce_agent=-1 \ 52 | model.model.decoder.k_reinforce_all=1 \ 53 | ``` 54 | 55 | ## Others 56 | - Different polyline embedding. 57 | ``` 58 | model.pre_processing.relative.pose_pe.agent=xy_dir \ 59 | model.pre_processing.relative.pose_pe.map=xy_dir \ 60 | ``` 61 | - Attention without bias. 62 | ``` 63 | model.model.tf_cfg.bias=False \ 64 | ``` 65 | - Different RPE mode. 66 | ``` 67 | model.model.rpe_mode=xy_dir \ 68 | model.model.rpe_mode=pe_xy_dir \ 69 | ``` 70 | - Apply RPE to query. It needs an RTX 3090 for training. 71 | ``` 72 | model.model.tf_cfg.apply_q_rpe=True \ 73 | ``` 74 | - Without anchor reinforce (17.5M parameters). 75 | ``` 76 | model.model.decoder.tf_n_layer=3 \ 77 | model.model.decoder.k_reinforce_agent=8 \ 78 | model.model.decoder.k_reinforce_anchor=-1 \ 79 | ``` 80 | - Without anchor reinforce, larger model (23.3M parameters).
81 | ``` 82 | model.model.n_tgt_knn=50 \ 83 | model.model.intra_class_encoder.n_layer_tf_map=6 \ 84 | model.model.decoder.tf_n_layer=4 \ 85 | model.model.decoder.k_reinforce_agent=8 \ 86 | model.model.decoder.k_reinforce_anchor=-1 \ 87 | ``` -------------------------------------------------------------------------------- /src/models/modules/mlp.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import List, Tuple, Union, Optional 3 | from torch import Tensor, nn 4 | 5 | 6 | def _get_activation(activation: str, inplace: bool) -> nn.Module: 7 | if activation == "relu": 8 | return nn.ReLU(inplace=inplace) 9 | elif activation == "gelu": 10 | return nn.GELU() 11 | raise RuntimeError("activation {} not implemented".format(activation)) 12 | 13 | 14 | class MLP(nn.Module): 15 | def __init__( 16 | self, 17 | fc_dims: Union[List, Tuple], 18 | dropout_p: Optional[float] = None, 19 | use_layernorm: bool = False, 20 | activation: str = "relu", 21 | end_layer_activation: bool = True, 22 | init_weight_norm: bool = False, 23 | init_bias: Optional[float] = None, 24 | use_batchnorm: bool = False, 25 | ) -> None: 26 | super(MLP, self).__init__() 27 | assert len(fc_dims) >= 2 28 | assert not (use_layernorm and use_batchnorm) 29 | layers: List[nn.Module] = [] 30 | for i in range(0, len(fc_dims) - 1): 31 | 32 | fc = nn.Linear(fc_dims[i], fc_dims[i + 1]) 33 | 34 | if init_weight_norm: 35 | fc.weight.data *= 1.0 / fc.weight.norm(dim=1, p=2, keepdim=True) 36 | if init_bias is not None and i == len(fc_dims) - 2: 37 | fc.bias.data *= 0 38 | fc.bias.data += init_bias 39 | 40 | layers.append(fc) 41 | 42 | if i < len(fc_dims) - 2: 43 | if use_layernorm: 44 | layers.append(nn.LayerNorm(fc_dims[i + 1])) 45 | elif use_batchnorm: 46 | layers.append(nn.BatchNorm1d(fc_dims[i + 1])) 47 | if dropout_p is not None: 48 | layers.append(nn.Dropout(p=dropout_p)) 49 | layers.append(_get_activation(activation, inplace=True)) 50 | if i == len(fc_dims) - 2: 51 | if end_layer_activation: 52 | if use_layernorm: 53 | layers.append(nn.LayerNorm(fc_dims[i + 1])) 54 | elif use_batchnorm: 55 | layers.append(nn.BatchNorm1d(fc_dims[i + 1])) 56 | if dropout_p is not None: 57 | layers.append(nn.Dropout(p=dropout_p)) 58 | self.end_layer_activation = _get_activation(activation, inplace=True) 59 | else: 60 | self.end_layer_activation = None 61 | 62 | self.input_dim = fc_dims[0] 63 | self.output_dim = fc_dims[-1] 64 | self.fc_layers = nn.Sequential(*layers) 65 | 66 | def forward(self, x: Tensor, valid_mask: Optional[Tensor] = None, fill_invalid: float = 0.0) -> Tensor: 67 | """ 68 | Args: 69 | x: [..., input_dim] 70 | valid_mask: [...] 
71 | Returns: 72 | x: [..., output_dim] 73 | """ 74 | x = self.fc_layers(x.flatten(0, -2)).view(*x.shape[:-1], self.output_dim) 75 | if valid_mask is not None: 76 | x.masked_fill_(~valid_mask.unsqueeze(-1), fill_invalid) 77 | if self.end_layer_activation is not None: 78 | self.end_layer_activation(x) 79 | return x 80 | -------------------------------------------------------------------------------- /src/models/modules/point_net.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Tuple, List, Optional 3 | import torch 4 | from torch import Tensor, nn 5 | from .mlp import MLP 6 | 7 | 8 | class PointNet(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | hidden_dim: int, 13 | n_layer: int = 3, 14 | use_layernorm: bool = False, 15 | use_batchnorm: bool = False, 16 | end_layer_activation: bool = True, 17 | dropout_p: Optional[float] = None, 18 | pool_mode: str = "max", # max, mean, first 19 | ) -> None: 20 | super().__init__() 21 | self.pool_mode = pool_mode 22 | self.input_mlp = MLP( 23 | [input_dim, hidden_dim, hidden_dim], 24 | dropout_p=dropout_p, 25 | use_layernorm=use_layernorm, 26 | use_batchnorm=use_batchnorm, 27 | ) 28 | 29 | mlp_layers: List[nn.Module] = [] 30 | for _ in range(n_layer - 2): 31 | mlp_layers.append( 32 | MLP( 33 | [hidden_dim, hidden_dim // 2], 34 | dropout_p=dropout_p, 35 | use_layernorm=use_layernorm, 36 | use_batchnorm=use_batchnorm, 37 | ) 38 | ) 39 | mlp_layers.append( 40 | MLP( 41 | [hidden_dim, hidden_dim // 2], 42 | dropout_p=dropout_p, 43 | use_layernorm=use_layernorm, 44 | use_batchnorm=use_batchnorm, 45 | end_layer_activation=end_layer_activation, 46 | ) 47 | ) 48 | self.mlp_layers = nn.ModuleList(mlp_layers) 49 | 50 | def forward(self, x: Tensor, valid: Tensor) -> Tuple[Tensor, Tensor]: 51 | """c.f. VectorNet and SceneTransformer, Aggregate polyline/track level feature. 
52 | 53 | Args: 54 | x: [n_batch, n_pl, n_pl_node, attr_dim] 55 | valid: [n_batch, n_pl, n_pl_node] bool 56 | 57 | Returns: 58 | emb: [n_batch, n_pl, hidden_dim] 59 | emb_valid: [n_batch, n_pl] 60 | """ 61 | x = self.input_mlp(x, valid) # [n_batch, n_pl, n_pl_node, hidden_dim] 62 | 63 | for mlp in self.mlp_layers: 64 | feature_encoded = mlp(x, valid, float("-inf")) # [n_batch, n_pl, n_pl_node, hidden_dim//2] 65 | feature_pooled = feature_encoded.amax(dim=2, keepdim=True) 66 | x = torch.cat((feature_encoded, feature_pooled.expand(-1, -1, valid.shape[-1], -1)), dim=-1) 67 | 68 | if self.pool_mode == "max": 69 | x.masked_fill_(~valid.unsqueeze(-1), float("-inf")) # [n_batch, n_pl, n_pl_node, hidden_dim] 70 | emb = x.amax(dim=2, keepdim=False) # [n_batch, n_pl, hidden_dim] 71 | elif self.pool_mode == "first": 72 | emb = x[:, :, 0] 73 | elif self.pool_mode == "mean": 74 | x.masked_fill_(~valid.unsqueeze(-1), 0) # [n_batch, n_pl, n_pl_node, hidden_dim] 75 | emb = x.sum(dim=2, keepdim=False) # [batch_size, n_pl, hidden_dim] 76 | emb = emb / (valid.sum(dim=-1, keepdim=True) + torch.finfo(x.dtype).eps) 77 | 78 | emb_valid = valid.any(-1) # [n_batch, n_pl] 79 | emb = emb.masked_fill(~emb_valid.unsqueeze(-1), 0) # [n_batch, n_pl, hidden_dim] 80 | return emb, emb_valid 81 | -------------------------------------------------------------------------------- /src/models/modules/multi_modal.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import torch 3 | from torch import Tensor, nn 4 | from .mlp import MLP 5 | 6 | 7 | class MultiModalAnchors(nn.Module): 8 | def __init__( 9 | self, 10 | mode_emb: str, 11 | mode_init: str, 12 | hidden_dim: int, 13 | n_pred: int, 14 | emb_dim: int, 15 | use_agent_type: bool, 16 | scale: float = 1.0, 17 | ) -> None: 18 | super().__init__() 19 | self.n_pred = n_pred 20 | self.use_agent_type = use_agent_type 21 | 22 | self.mode_init = mode_init 23 | n_anchors = 3 if use_agent_type else 1 24 | if self.mode_init == "xavier": 25 | self.anchors = torch.empty((n_anchors, n_pred, hidden_dim)) 26 | nn.init.xavier_normal_(self.anchors) 27 | self.anchors = nn.Parameter(self.anchors * scale, requires_grad=True) 28 | elif self.mode_init == "uniform": 29 | self.anchors = torch.empty((n_anchors, n_pred, hidden_dim)) 30 | self.anchors.uniform_(-scale, scale) 31 | self.anchors = nn.Parameter(self.anchors, requires_grad=True) 32 | elif self.mode_init == "randn": 33 | self.anchors = nn.Parameter(torch.randn([n_anchors, n_pred, hidden_dim]) * scale, requires_grad=True) 34 | else: 35 | raise NotImplementedError 36 | 37 | self.mode_emb = mode_emb 38 | if self.mode_emb == "linear": 39 | self.mlp_anchor = nn.Linear(self.anchors.shape[-1] + emb_dim, hidden_dim, bias=False) 40 | elif self.mode_emb == "mlp": 41 | self.mlp_anchor = MLP([self.anchors.shape[-1] + emb_dim] + [hidden_dim] * 2, end_layer_activation=False) 42 | elif self.mode_emb == "add" or self.mode_emb == "none": 43 | assert emb_dim == hidden_dim 44 | if self.anchors.shape[-1] != hidden_dim: 45 | self.mlp_anchor = nn.Linear(self.anchors.shape[-1], hidden_dim, bias=False) 46 | else: 47 | self.mlp_anchor = None 48 | else: 49 | raise NotImplementedError 50 | 51 | def forward(self, valid: Tensor, emb: Tensor, agent_type: Tensor) -> Tensor: 52 | """ 53 | Args: 54 | valid: [n_scene*n_agent] 55 | emb: [n_scene*n_agent, in_dim] 56 | agent_type: [n_scene*n_agent, 3] 57 | 58 | Returns: 59 | mm_emb: [n_scene*n_agent, 
n_pred, out_dim] 60 | """ 61 | # [n_scene*n_agent, n_pred, emb_dim] 62 | if self.use_agent_type: 63 | anchors = (self.anchors.unsqueeze(0) * agent_type[:, :, None, None]).sum(1) 64 | else: 65 | anchors = self.anchors.expand(valid.shape[0], -1, -1) 66 | 67 | if self.mode_emb == "linear" or self.mode_emb == "mlp": 68 | # [n_scene*n_agent, n_pred, hidden_dim + emb_dim] 69 | mm_emb = torch.cat([emb.unsqueeze(1).expand(-1, self.n_pred, -1), anchors], dim=-1) 70 | mm_emb = self.mlp_anchor(mm_emb) 71 | elif self.mode_emb == "add": 72 | if self.mlp_anchor is not None: 73 | anchors = self.mlp_anchor(anchors) # [n_scene*n_agent, n_pred, hidden_dim] 74 | mm_emb = emb.unsqueeze(1) + anchors 75 | elif self.mode_emb == "none": 76 | if self.mlp_anchor is not None: 77 | anchors = self.mlp_anchor(anchors) # [n_scene*n_agent, n_pred, hidden_dim] 78 | mm_emb = anchors 79 | return mm_emb.masked_fill(~valid[:, None, None], 0) 80 | -------------------------------------------------------------------------------- /src/models/modules/rpe.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional, Tuple, Union 3 | import torch 4 | from torch import Tensor 5 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_rad2local 6 | 7 | 8 | @torch.no_grad() 9 | def get_rel_pose(pose: Tensor, invalid: Tensor) -> Tuple[Tensor, Tensor]: 10 | """ 11 | Args: 12 | pose: [n_scene, n_emb, 3], (x,y,yaw), in global coordinate 13 | invalid: [n_scene, n_emb] 14 | 15 | Returns: 16 | rel_pose: [n_scene, n_emb, n_emb, 3] (x,y,yaw) 17 | rel_dist: [n_scene, n_emb, n_emb] 18 | """ 19 | xy = pose[:, :, :2] # [n_scene, n_emb, 2] 20 | yaw = pose[:, :, -1] # [n_scene, n_emb] 21 | rel_pose = torch.cat( 22 | [ 23 | torch_pos2local(xy.unsqueeze(1), xy.unsqueeze(2), torch_rad2rot(yaw)), 24 | torch_rad2local(yaw.unsqueeze(1), yaw, cast=False).unsqueeze(-1), 25 | ], 26 | dim=-1, 27 | ) # [n_scene, n_emb, n_emb, 3] 28 | rel_dist = torch.norm(rel_pose[..., :2], dim=-1) # [n_scene, n_emb, n_emb] 29 | rel_dist.masked_fill_(invalid.unsqueeze(1) | invalid.unsqueeze(2), float("inf")) 30 | return rel_pose, rel_dist 31 | 32 | 33 | @torch.no_grad() 34 | def get_rel_dist(xy: Tensor, invalid: Tensor) -> Tensor: 35 | """ 36 | Args: 37 | xy: [n_scene, n_emb, 2], in global coordinate 38 | invalid: [n_scene, n_emb] 39 | 40 | Returns: 41 | rel_dist: [n_scene, n_emb, n_emb] 42 | """ 43 | rel_dist = torch.norm(xy.unsqueeze(1) - xy.unsqueeze(2), dim=-1) # [n_scene, n_emb, n_emb] 44 | rel_dist.masked_fill_(invalid.unsqueeze(1) | invalid.unsqueeze(2), float("inf")) 45 | return rel_dist 46 | 47 | 48 | @torch.no_grad() 49 | def get_tgt_knn_idx( 50 | tgt_invalid: Tensor, rel_pose: Optional[Tensor], rel_dist: Tensor, n_tgt_knn: int, dist_limit: Union[float, Tensor], 51 | ) -> Tuple[Optional[Tensor], Tensor, Optional[Tensor]]: 52 | """ 53 | Args: 54 | tgt_invalid: [n_scene, n_tgt] 55 | rel_pose: [n_scene, n_src, n_tgt, 3] 56 | rel_dist: [n_scene, n_src, n_tgt] 57 | knn: int, set to <=0 to skip knn, i.e. 
n_tgt_knn=n_tgt 58 | dist_limit: float, or Tensor [n_scene, n_tgt, 1] 59 | 60 | Returns: 61 | idx_tgt: [n_scene, n_src, n_tgt_knn], or None 62 | tgt_invalid_knn: [n_scene, n_src, n_tgt_knn] 63 | rpe: [n_scene, n_src, n_tgt_knn, 3] 64 | """ 65 | n_scene, n_src, _ = rel_dist.shape 66 | idx_scene = torch.arange(n_scene)[:, None, None] # [n_scene, 1, 1] 67 | idx_src = torch.arange(n_src)[None, :, None] # [1, n_src, 1] 68 | 69 | if 0 < n_tgt_knn < tgt_invalid.shape[1]: 70 | # [n_scene, n_src, n_tgt_knn] 71 | dist_knn, idx_tgt = torch.topk(rel_dist, n_tgt_knn, dim=-1, largest=False, sorted=False) 72 | # [n_scene, n_src, n_tgt_knn] 73 | tgt_invalid_knn = tgt_invalid.unsqueeze(1).expand(-1, n_src, -1)[idx_scene, idx_src, idx_tgt] 74 | # [n_batch, n_src, n_tgt_knn, 3] 75 | if rel_pose is None: 76 | rpe = None 77 | else: 78 | rpe = rel_pose[idx_scene, idx_src, idx_tgt] 79 | else: 80 | dist_knn = rel_dist 81 | tgt_invalid_knn = tgt_invalid.unsqueeze(1).expand(-1, n_src, -1) # [n_scene, n_src, n_tgt] 82 | rpe = rel_pose 83 | idx_tgt = None 84 | 85 | tgt_invalid_knn = tgt_invalid_knn | (dist_knn > dist_limit) 86 | if rpe is not None: 87 | rpe = rpe.masked_fill(tgt_invalid_knn.unsqueeze(-1), 0) 88 | 89 | return idx_tgt, tgt_invalid_knn, rpe 90 | -------------------------------------------------------------------------------- /env_av2.yml: -------------------------------------------------------------------------------- 1 | name: hptr_av2 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - asttokens=2.2.1=pyhd8ed1ab_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - backports=1.0=pyhd8ed1ab_3 11 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 12 | - ca-certificates=2022.12.7=ha878542_0 13 | - certifi=2022.12.7=pyhd8ed1ab_0 14 | - debugpy=1.5.1=py39h295c915_0 15 | - decorator=5.1.1=pyhd8ed1ab_0 16 | - entrypoints=0.4=pyhd8ed1ab_0 17 | - executing=1.2.0=pyhd8ed1ab_0 18 | - ipykernel=6.15.0=pyh210e3f2_0 19 | - ipython=8.11.0=pyh41d4057_0 20 | - jedi=0.18.2=pyhd8ed1ab_0 21 | - jupyter_client=7.0.6=pyhd8ed1ab_0 22 | - jupyter_core=5.2.0=py39hf3d152e_0 23 | - ld_impl_linux-64=2.38=h1181459_1 24 | - libffi=3.4.2=h6a678d5_6 25 | - libgcc-ng=11.2.0=h1234567_1 26 | - libgomp=11.2.0=h1234567_1 27 | - libsodium=1.0.18=h36c2ea0_1 28 | - libstdcxx-ng=11.2.0=h1234567_1 29 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 30 | - ncurses=6.4=h6a678d5_0 31 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 32 | - openssl=1.1.1t=h7f8727e_0 33 | - packaging=23.0=pyhd8ed1ab_0 34 | - parso=0.8.3=pyhd8ed1ab_0 35 | - pexpect=4.8.0=pyh1a96a4e_2 36 | - pickleshare=0.7.5=py_1003 37 | - pip=23.0.1=py39h06a4308_0 38 | - platformdirs=3.1.1=pyhd8ed1ab_0 39 | - prompt-toolkit=3.0.38=pyha770c72_0 40 | - prompt_toolkit=3.0.38=hd8ed1ab_0 41 | - psutil=5.9.0=py39h5eee18b_0 42 | - ptyprocess=0.7.0=pyhd3deb0d_0 43 | - pure_eval=0.2.2=pyhd8ed1ab_0 44 | - pygments=2.14.0=pyhd8ed1ab_0 45 | - python=3.9.16=h7a1cb2a_2 46 | - python-dateutil=2.8.2=pyhd8ed1ab_0 47 | - python_abi=3.9=2_cp39 48 | - pyzmq=19.0.2=py39hb69f2a1_2 49 | - readline=8.2=h5eee18b_0 50 | - setuptools=65.6.3=py39h06a4308_0 51 | - six=1.16.0=pyh6c4a22f_0 52 | - sqlite=3.40.1=h5082296_0 53 | - stack_data=0.6.2=pyhd8ed1ab_0 54 | - tk=8.6.12=h1ccaba5_0 55 | - tornado=6.1=py39hb9d737c_3 56 | - traitlets=5.9.0=pyhd8ed1ab_0 57 | - typing-extensions=4.5.0=hd8ed1ab_0 58 | - typing_extensions=4.5.0=pyha770c72_0 59 | - tzdata=2022g=h04d1e81_0 60 | - wcwidth=0.2.6=pyhd8ed1ab_0 61 | - wheel=0.38.4=py39h06a4308_0 62 | - xz=5.2.10=h5eee18b_1 
63 | - zeromq=4.3.4=h9c3ff4c_1 64 | - zlib=1.2.13=h5eee18b_0 65 | - pip: 66 | - argcomplete==2.1.1 67 | - av==10.0.0 68 | - av2==0.2.1 69 | - black==23.1.0 70 | - click==8.1.3 71 | - colorlog==6.7.0 72 | - contourpy==1.0.7 73 | - cycler==0.11.0 74 | - distlib==0.3.6 75 | - filelock==3.9.0 76 | - fonttools==4.39.0 77 | - fsspec==2023.3.0 78 | - h5py==3.8.0 79 | - importlib-resources==5.12.0 80 | - joblib==1.2.0 81 | - kiwisolver==1.4.4 82 | - llvmlite==0.39.1 83 | - markdown-it-py==2.2.0 84 | - matplotlib==3.7.1 85 | - mdurl==0.1.2 86 | - mypy-extensions==1.0.0 87 | - nox==2022.11.21 88 | - numba==0.56.4 89 | - numpy==1.23.5 90 | - nvidia-cublas-cu11==11.10.3.66 91 | - nvidia-cuda-nvrtc-cu11==11.7.99 92 | - nvidia-cuda-runtime-cu11==11.7.99 93 | - nvidia-cudnn-cu11==8.5.0.96 94 | - opencv-python==4.7.0.72 95 | - pandas==1.5.3 96 | - pathspec==0.11.1 97 | - pillow==9.4.0 98 | - pyarrow==11.0.0 99 | - pyparsing==3.0.9 100 | - pyproj==3.4.1 101 | - pytz==2022.7.1 102 | - rich==13.3.2 103 | - scipy==1.10.1 104 | - tomli==2.0.1 105 | - torch==1.13.1 106 | - tqdm==4.65.0 107 | - transforms3d==0.4.1 108 | - upath==1.0 109 | - virtualenv==20.21.0 110 | - zipp==3.15.0 111 | -------------------------------------------------------------------------------- /src/utils/pose_pe.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import torch 3 | from torch import Tensor, nn 4 | from models.modules.pos_emb import PositionalEmbedding, PositionalEmbeddingRad 5 | 6 | 7 | class PosePE(nn.Module): 8 | def __init__(self, mode: str, pe_dim: int = 256, theta_xy: float = 1e3, theta_cs: float = 1e1): 9 | super().__init__() 10 | self.mode = mode 11 | if self.mode == "xy_dir": 12 | self.out_dim = 4 13 | elif self.mode == "mpa_pl": 14 | self.out_dim = 7 15 | elif self.mode == "pe_xy_dir": 16 | self.out_dim = pe_dim 17 | self.pe_xy = PositionalEmbedding(dim=pe_dim // 4, theta=theta_xy) 18 | self.pe_dir = PositionalEmbedding(dim=pe_dim // 4, theta=theta_cs) 19 | elif self.mode == "pe_xy_yaw": 20 | self.out_dim = pe_dim 21 | self.pe_xy = PositionalEmbedding(dim=pe_dim // 4, theta=theta_xy) 22 | self.pe_yaw = PositionalEmbeddingRad(dim=pe_dim // 2) 23 | else: 24 | raise NotImplementedError 25 | 26 | def forward(self, xy: Tensor, dir: Tensor): 27 | """ 28 | Args: input either dir or yaw. 
29 | xy: [..., 2] 30 | dir: cos/sin [..., 2] or yaw [..., 1] 31 | 32 | Returns: 33 | pos_out: [..., self.out_dim] 34 | """ 35 | if self.mode == "xy_dir": 36 | if dir.shape[-1] == 1: 37 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1) 38 | pos_out = torch.cat([xy, dir], dim=-1) 39 | elif self.mode == "mpa_pl": 40 | if dir.shape[-1] == 1: 41 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1) 42 | pos_out = self.encode_polyline(xy, dir) 43 | elif self.mode == "pe_xy_dir": 44 | if dir.shape[-1] == 1: 45 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1) 46 | pos_out = torch.cat( 47 | [self.pe_xy(xy[..., 0]), self.pe_xy(xy[..., 1]), self.pe_dir(dir[..., 0]), self.pe_dir(dir[..., 1])], 48 | dim=-1, 49 | ) 50 | elif self.mode == "pe_xy_yaw": 51 | if dir.shape[-1] == 1: 52 | dir = dir.squeeze(-1) 53 | else: 54 | dir = torch.atan2(dir[..., 1], dir[..., 0]) 55 | pos_out = torch.cat([self.pe_xy(xy[..., 0]), self.pe_xy(xy[..., 1]), self.pe_yaw(dir)], dim=-1) 56 | return pos_out 57 | 58 | @staticmethod 59 | def encode_polyline(pos: Tensor, dir: Tensor) -> Tensor: 60 | """ 61 | Args: pos and dir with respect to the agent 62 | pos: [..., 2] 63 | dir: [..., 2] 64 | 65 | Returns: 66 | pl_feature: [..., 7] 67 | """ 68 | eps = torch.finfo(pos.dtype).eps 69 | # [n_scene, n_target, n_map, n_pl_node, 2] 70 | segments_start = pos 71 | segment_vec = dir 72 | # [n_scene, n_target, n_map, n_pl_node] 73 | segment_proj = (-segments_start * segment_vec).sum(-1) / ((segment_vec * segment_vec).sum(-1) + eps) 74 | # [n_scene, n_target, n_map, n_pl_node, 2] 75 | closest_points = segments_start + torch.clamp(segment_proj, min=0, max=1).unsqueeze(-1) * segment_vec 76 | # [n_scene, n_target, n_map, n_pl_node, 1] 77 | r_norm = torch.norm(closest_points, dim=-1, keepdim=True) 78 | segment_vec_norm = torch.norm(segment_vec, dim=-1, keepdim=True) 79 | pl_feature = torch.cat( 80 | [ 81 | r_norm, # 1 82 | closest_points / (r_norm + eps), # 2 83 | segment_vec / (segment_vec_norm + eps), # 2 84 | segment_vec_norm, # 1 85 | torch.norm(segments_start + segment_vec - closest_points, dim=-1, keepdim=True), # 1 86 | ], 87 | dim=-1, 88 | ) 89 | return pl_feature -------------------------------------------------------------------------------- /configs/model/acg_womd.yaml: -------------------------------------------------------------------------------- 1 | _target_: pl_modules.waymo_motion.WaymoMotion 2 | 3 | time_step_current: 10 4 | time_step_end: 90 5 | n_video_batch: 3 6 | interactive_challenge: False 7 | inference_cache_map: False 8 | inference_repeat_n: 1 9 | 10 | train_metric: 11 | _target_: models.metrics.nll.NllMetrics 12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd 13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2 14 | p_rand_train_agent: -1 # 0.2 15 | n_step_add_train_agent: [-1, -1, -1] # -1 to turn off 16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off 17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc 18 | w_pos: [1.0, 1.0, 1.0] 19 | w_yaw: [1.0, 1.0, 1.0] 20 | w_vel: [1.0, 1.0, 1.0] 21 | w_spd: [0, 0, 0] 22 | 23 | waymo_metric: 24 | _target_: models.metrics.waymo.WaymoMetrics 25 | n_max_pred_agent: 8 26 | 27 | pre_processing: 28 | agent_centric: 29 | _target_: data_modules.agent_centric.AgentCentricPreProcessing 30 | mask_invalid: False 31 | n_target: 8 32 | n_other: 48 33 | n_map: 512 34 | n_tl: 24 35 | ac_global: 36 | _target_: data_modules.ac_global.AgentCentricGlobal 37 | dropout_p_history: 0.15 38 | use_current_tl: False 39 | add_ohe: True 40 | pl_aggr: False 41 | pose_pe: # xy_dir, mpa_pl, 
pe_xy_dir, pe_xy_yaw 42 | agent: xy_dir 43 | map: mpa_pl 44 | tl: mpa_pl 45 | 46 | post_processing: 47 | to_dict: 48 | _target_: data_modules.post_processing.ToDict 49 | predictions: ${...model.decoder.mlp_head.predictions} 50 | get_cov_mat: 51 | _target_: data_modules.post_processing.GetCovMat 52 | rho_clamp: 5.0 53 | std_min: -1.609 54 | std_max: 5.0 55 | waymo: 56 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 57 | k_pred: 6 58 | use_ade: True 59 | score_temperature: -1 60 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc 61 | gt_in_local: True 62 | 63 | model: 64 | _target_: models.ac_global.AgentCentricGlobal 65 | hidden_dim: 256 66 | n_decoders: 1 67 | tf_cfg: 68 | n_head: 4 69 | dropout_p: 0.1 70 | norm_first: True 71 | bias: True 72 | intra_class_encoder: 73 | add_learned_pe: True 74 | use_point_net: False 75 | n_layer_mlp: 3 76 | mlp_cfg: 77 | end_layer_activation: True 78 | use_layernorm: False 79 | use_batchnorm: False 80 | dropout_p: null 81 | n_layer_tf: -1 82 | decoder: 83 | _target_: models.ac_global.Decoder 84 | n_latent_query: 192 85 | n_layer_tf_all2all: 6 86 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n 87 | n_layer_tf_anchor: 8 88 | n_pred: 6 89 | latent_query: 90 | use_agent_type: False 91 | mode_emb: none # linear, mlp, add, none 92 | mode_init: xavier # uniform, xavier 93 | scale: 5.0 94 | multi_modal_anchors: 95 | use_agent_type: True 96 | mode_emb: none # linear, mlp, add, none 97 | mode_init: xavier # uniform, xavier 98 | scale: 5.0 99 | anchor_self_attn: True 100 | mlp_head: 101 | predictions: [pos, cov3] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox 102 | use_agent_type: False 103 | flatten_conf_head: False 104 | out_mlp_layernorm: False 105 | out_mlp_batchnorm: False 106 | n_step_future: 80 107 | use_vmap: True 108 | 109 | # * optimizer 110 | optimizer: 111 | _target_: torch.optim.AdamW 112 | lr: 2e-4 # 2e-4, 1e-5 113 | lr_scheduler: 114 | _target_: torch.optim.lr_scheduler.StepLR 115 | gamma: 0.5 116 | step_size: 20 117 | 118 | sub_womd: 119 | _target_: utils.submission.SubWOMD 120 | activate: False 121 | method_name: METHOD_NAME 122 | authors: [NAME1, NAME2] 123 | affiliation: AFFILIATION 124 | description: scr_womd 125 | method_link: METHOD_LINK 126 | account_name: ACCOUNT_NAME 127 | sub_av2: 128 | _target_: utils.submission.SubAV2 129 | activate: False 130 | -------------------------------------------------------------------------------- /src/data_modules/post_processing.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import List, Dict 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | 7 | class ToDict(nn.Module): 8 | def __init__(self, predictions: List[str]) -> None: 9 | super().__init__() 10 | self.dims = {"pred_pos": None, "pred_spd": None, "pred_vel": None, "pred_yaw_bbox": None, "pred_cov": None} 11 | 12 | pred_dim = 0 13 | if "pos" in predictions: 14 | self.dims["pred_pos"] = (pred_dim, pred_dim + 2) 15 | pred_dim += 2 16 | if "spd" in predictions: 17 | self.dims["pred_spd"] = (pred_dim, pred_dim + 1) 18 | pred_dim += 1 19 | if "vel" in predictions: 20 | self.dims["pred_vel"] = (pred_dim, pred_dim + 2) 21 | pred_dim += 2 22 | if "yaw_bbox" in predictions: 23 | self.dims["pred_yaw_bbox"] = (pred_dim, pred_dim + 1) 24 | pred_dim += 1 25 | if "cov1" in predictions: 26 | self.dims["pred_cov"] = (pred_dim, pred_dim + 1) 27 | pred_dim += 1 28 | elif 
"cov2" in predictions: 29 | self.dims["pred_cov"] = (pred_dim, pred_dim + 2) 30 | pred_dim += 2 31 | elif "cov3" in predictions: 32 | self.dims["pred_cov"] = (pred_dim, pred_dim + 3) 33 | pred_dim += 3 34 | 35 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 36 | """ 37 | Inputs: 38 | valid: [n_scene, n_target] 39 | conf: [n_decoders, n_scene, n_target, n_pred], not normalized! 40 | pred: [n_decoders, n_scene, n_target, n_pred, n_step_future, pred_dim] 41 | """ 42 | for k, v in self.dims.items(): 43 | if v is None: 44 | pred_dict[k] = None 45 | else: 46 | pred_dict[k] = pred_dict["pred"][..., v[0] : v[1]] 47 | # del pred_dict["pred"] 48 | return pred_dict 49 | 50 | 51 | class GetCovMat(nn.Module): 52 | def __init__(self, rho_clamp: float, std_min: float, std_max: float) -> None: 53 | super().__init__() 54 | self.rho_clamp = rho_clamp 55 | self.std_min = std_min 56 | self.std_max = std_max 57 | 58 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 59 | """ 60 | Inputs: 61 | pred_cov: [n_decoders, n_scene, n_target, n_pred, n_step_future, 1/2/3] 62 | 63 | Outputs: 64 | pred_cov: [n_decoders, n_scene, n_target, n_pred, n_step_future, 2, 2], in tril form. 65 | """ 66 | if pred_dict["pred_cov"] is not None: 67 | cov_shape = pred_dict["pred_cov"].shape 68 | cov_dof = cov_shape[-1] 69 | if cov_dof == 3: 70 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp() 71 | b = torch.clamp(pred_dict["pred_cov"][..., 1], min=self.std_min, max=self.std_max).exp() 72 | c = torch.clamp(pred_dict["pred_cov"][..., 2], min=-self.rho_clamp, max=self.rho_clamp) 73 | elif cov_dof == 2: 74 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp() 75 | b = torch.clamp(pred_dict["pred_cov"][..., 1], min=self.std_min, max=self.std_max).exp() 76 | c = torch.zeros_like(a) 77 | elif cov_dof == 1: 78 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp() 79 | b = a 80 | c = torch.zeros_like(a) 81 | 82 | pred_dict["pred_cov"] = torch.stack([a, torch.zeros_like(a), c, b], dim=-1).view(*cov_shape[:-1], 2, 2) 83 | 84 | return pred_dict 85 | 86 | 87 | class OffsetToKmeans(nn.Module): 88 | def __init__(self, *args, **kwargs) -> None: 89 | super().__init__() 90 | 91 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 92 | # not used, for back compatibility 93 | return pred_dict 94 | -------------------------------------------------------------------------------- /configs/model/scg_womd.yaml: -------------------------------------------------------------------------------- 1 | _target_: pl_modules.waymo_motion.WaymoMotion 2 | 3 | time_step_current: 10 4 | time_step_end: 90 5 | n_video_batch: 3 6 | interactive_challenge: False 7 | inference_cache_map: False 8 | inference_repeat_n: 1 9 | 10 | train_metric: 11 | _target_: models.metrics.nll.NllMetrics 12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd 13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2 14 | p_rand_train_agent: -1 # 0.2 15 | n_step_add_train_agent: [-1, 40, 40] # -1 to turn off 16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off 17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc 18 | w_pos: [1.0, 1.0, 1.0] 19 | w_yaw: [1.0, 1.0, 1.0] 20 | w_vel: [1.0, 1.0, 1.0] 21 | w_spd: [1.0, 1.0, 1.0] 22 | 23 | waymo_metric: 24 | _target_: models.metrics.waymo.WaymoMetrics 25 | n_max_pred_agent: 8 26 | 27 | pre_processing: 28 | scene_centric: 29 | _target_: 
data_modules.scene_centric.SceneCentricPreProcessing 30 | gt_in_local: True 31 | mask_invalid: False 32 | global: 33 | _target_: data_modules.sc_global.SceneCentricGlobal 34 | dropout_p_history: -1 35 | use_current_tl: True 36 | add_ohe: True 37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet 38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw 39 | agent: pe_xy_dir 40 | map: pe_xy_dir 41 | tl: pe_xy_dir 42 | 43 | post_processing: 44 | to_dict: 45 | _target_: data_modules.post_processing.ToDict 46 | predictions: ${...model.decoder.mlp_head.predictions} 47 | get_cov_mat: 48 | _target_: data_modules.post_processing.GetCovMat 49 | rho_clamp: 5.0 50 | std_min: -1.609 51 | std_max: 5.0 52 | waymo: 53 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 54 | k_pred: 6 55 | use_ade: True 56 | score_temperature: -1 57 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc 58 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local} 59 | 60 | model: 61 | _target_: models.sc_global.SceneCentricGlobal 62 | hidden_dim: 256 63 | n_tgt_knn: 36 64 | dist_limit_map: 1500 65 | dist_limit_tl: 1000 66 | dist_limit_agent: [1500, 500, 1000] 67 | decoder_remove_ego_agent: False 68 | n_decoders: 1 69 | tf_cfg: 70 | n_head: 4 71 | dropout_p: 0.1 72 | norm_first: True 73 | bias: True 74 | intra_class_encoder: 75 | n_layer_mlp: 3 76 | mlp_cfg: 77 | end_layer_activation: True 78 | use_layernorm: False 79 | use_batchnorm: False 80 | dropout_p: null 81 | n_layer_tf_map: 6 82 | n_layer_tf_tl: -1 83 | n_layer_tf_agent: -1 84 | decoder: 85 | _target_: models.sc_global.Decoder 86 | n_pred: 6 87 | tf_n_layer: 2 88 | k_reinforce_tl: 2 89 | k_reinforce_agent: 4 90 | k_reinforce_all: -1 91 | k_reinforce_anchor: 10 92 | n_latent_query: -1 93 | latent_query: 94 | use_agent_type: False 95 | mode_emb: linear # linear, mlp, add, none 96 | mode_init: xavier # uniform, xavier 97 | scale: 5.0 98 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n 99 | multi_modal_anchors: 100 | use_agent_type: True 101 | mode_emb: linear # linear, mlp, add, none 102 | mode_init: xavier # uniform, xavier 103 | scale: 5.0 104 | anchor_self_attn: True 105 | mlp_head: 106 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox 107 | use_agent_type: False 108 | flatten_conf_head: False 109 | out_mlp_layernorm: False 110 | out_mlp_batchnorm: False 111 | n_step_future: 80 112 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor 113 | use_vmap: True 114 | 115 | # * optimizer 116 | optimizer: 117 | _target_: torch.optim.AdamW 118 | lr: 1e-4 # 2e-4, 1e-5 119 | lr_scheduler: 120 | _target_: torch.optim.lr_scheduler.StepLR 121 | gamma: 0.5 122 | step_size: 25 123 | 124 | sub_womd: 125 | _target_: utils.submission.SubWOMD 126 | activate: False 127 | method_name: METHOD_NAME 128 | authors: [NAME1, NAME2] 129 | affiliation: AFFILIATION 130 | description: scr_womd 131 | method_link: METHOD_LINK 132 | account_name: ACCOUNT_NAME 133 | sub_av2: 134 | _target_: utils.submission.SubAV2 135 | activate: False 136 | -------------------------------------------------------------------------------- /configs/model/scr_av2.yaml: -------------------------------------------------------------------------------- 1 | _target_: pl_modules.waymo_motion.WaymoMotion 2 | 3 | time_step_current: 49 4 | time_step_end: 109 5 | n_video_batch: 3 6 | interactive_challenge: False 7 | inference_cache_map: 
False 8 | inference_repeat_n: 1 9 | 10 | train_metric: 11 | _target_: models.metrics.nll.NllMetrics 12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd 13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2 14 | p_rand_train_agent: -1 # 0.2 15 | n_step_add_train_agent: [-1, -1, -1] # -1 to turn off 16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off 17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc 18 | w_pos: [1.0, 1.0, 1.0] 19 | w_yaw: [1.0, 1.0, 1.0] 20 | w_vel: [1.0, 1.0, 1.0] 21 | w_spd: [1.0, 1.0, 1.0] 22 | 23 | waymo_metric: 24 | _target_: models.metrics.waymo.WaymoMetrics 25 | n_max_pred_agent: 1 26 | 27 | pre_processing: 28 | scene_centric: 29 | _target_: data_modules.scene_centric.SceneCentricPreProcessing 30 | gt_in_local: True 31 | mask_invalid: False 32 | relative: 33 | _target_: data_modules.sc_relative.SceneCentricRelative 34 | dropout_p_history: -1 35 | use_current_tl: True 36 | add_ohe: True 37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet 38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw 39 | agent: mpa_pl 40 | map: mpa_pl 41 | 42 | post_processing: 43 | to_dict: 44 | _target_: data_modules.post_processing.ToDict 45 | predictions: ${...model.decoder.mlp_head.predictions} 46 | get_cov_mat: 47 | _target_: data_modules.post_processing.GetCovMat 48 | rho_clamp: 5.0 49 | std_min: -1.609 50 | std_max: 5.0 51 | waymo: 52 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 53 | k_pred: 6 54 | use_ade: True 55 | score_temperature: 0.5 56 | mpa_nms_thresh: [] # veh, ped, cyc 57 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local} 58 | 59 | model: 60 | _target_: models.sc_relative.SceneCentricRelative 61 | hidden_dim: 256 62 | rpe_mode: pe_xy_yaw # xy_dir, pe_xy_dir, pe_xy_yaw 63 | n_tgt_knn: 36 64 | dist_limit_map: 1500 65 | dist_limit_tl: 1000 66 | dist_limit_agent: [1500, 500, 1000] 67 | decoder_remove_ego_agent: False 68 | n_decoders: 1 69 | tf_cfg: 70 | n_head: 4 71 | dropout_p: 0.1 72 | norm_first: True 73 | apply_q_rpe: False 74 | bias: True 75 | intra_class_encoder: 76 | n_layer_mlp: 3 77 | mlp_cfg: 78 | end_layer_activation: True 79 | use_layernorm: False 80 | use_batchnorm: False 81 | dropout_p: null 82 | n_layer_tf_map: 6 83 | n_layer_tf_tl: -1 84 | n_layer_tf_agent: -1 85 | decoder: 86 | _target_: models.sc_relative.Decoder 87 | n_pred: 6 88 | tf_n_layer: 2 89 | k_reinforce_tl: 2 90 | k_reinforce_agent: 4 91 | k_reinforce_all: -1 92 | k_reinforce_anchor: 10 93 | n_latent_query: -1 94 | latent_query: 95 | use_agent_type: False 96 | mode_emb: linear # linear, mlp, add, none 97 | mode_init: randn # uniform, xavier 98 | scale: 5.0 99 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n 100 | multi_modal_anchors: 101 | use_agent_type: True 102 | mode_emb: linear # linear, mlp, add, none 103 | mode_init: randn # uniform, xavier 104 | scale: 5.0 105 | anchor_self_attn: True 106 | mlp_head: 107 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox 108 | use_agent_type: False 109 | flatten_conf_head: False 110 | out_mlp_layernorm: False 111 | out_mlp_batchnorm: False 112 | n_step_future: 60 113 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor 114 | use_vmap: True 115 | 116 | # * optimizer 117 | optimizer: 118 | _target_: torch.optim.AdamW 119 | lr: 1e-4 # 2e-4, 1e-5 120 | lr_scheduler: 121 | _target_: torch.optim.lr_scheduler.StepLR 122 | gamma: 0.5 
123 | step_size: 25 124 | 125 | sub_womd: 126 | _target_: utils.submission.SubWOMD 127 | activate: False 128 | method_name: METHOD_NAME 129 | authors: [NAME1, NAME2] 130 | affiliation: AFFILIATION 131 | description: scr_womd 132 | method_link: METHOD_LINK 133 | account_name: ACCOUNT_NAME 134 | sub_av2: 135 | _target_: utils.submission.SubAV2 136 | activate: False 137 | -------------------------------------------------------------------------------- /configs/model/scr_womd.yaml: -------------------------------------------------------------------------------- 1 | _target_: pl_modules.waymo_motion.WaymoMotion 2 | 3 | time_step_current: 10 4 | time_step_end: 90 5 | n_video_batch: 3 6 | interactive_challenge: False 7 | inference_cache_map: False 8 | inference_repeat_n: 1 9 | 10 | train_metric: 11 | _target_: models.metrics.nll.NllMetrics 12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd 13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2 14 | p_rand_train_agent: -1 # 0.2 15 | n_step_add_train_agent: [-1, 40, 40] # -1 to turn off 16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off 17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc 18 | w_pos: [1.0, 1.0, 1.0] 19 | w_yaw: [1.0, 1.0, 1.0] 20 | w_vel: [1.0, 1.0, 1.0] 21 | w_spd: [1.0, 1.0, 1.0] 22 | 23 | waymo_metric: 24 | _target_: models.metrics.waymo.WaymoMetrics 25 | n_max_pred_agent: 8 26 | 27 | pre_processing: 28 | scene_centric: 29 | _target_: data_modules.scene_centric.SceneCentricPreProcessing 30 | gt_in_local: True 31 | mask_invalid: False 32 | relative: 33 | _target_: data_modules.sc_relative.SceneCentricRelative 34 | dropout_p_history: -1 35 | use_current_tl: True 36 | add_ohe: True 37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet 38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw 39 | agent: mpa_pl 40 | map: mpa_pl 41 | 42 | post_processing: 43 | to_dict: 44 | _target_: data_modules.post_processing.ToDict 45 | predictions: ${...model.decoder.mlp_head.predictions} 46 | get_cov_mat: 47 | _target_: data_modules.post_processing.GetCovMat 48 | rho_clamp: 5.0 49 | std_min: -1.609 50 | std_max: 5.0 51 | waymo: 52 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing 53 | k_pred: 6 54 | use_ade: True 55 | score_temperature: -1 56 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc 57 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local} 58 | 59 | model: 60 | _target_: models.sc_relative.SceneCentricRelative 61 | hidden_dim: 256 62 | rpe_mode: pe_xy_yaw # xy_dir, pe_xy_dir, pe_xy_yaw 63 | n_tgt_knn: 36 64 | dist_limit_map: 1500 65 | dist_limit_tl: 1000 66 | dist_limit_agent: [1500, 500, 1000] 67 | decoder_remove_ego_agent: False 68 | n_decoders: 1 69 | tf_cfg: 70 | n_head: 4 71 | dropout_p: 0.1 72 | norm_first: True 73 | apply_q_rpe: False 74 | bias: True 75 | intra_class_encoder: 76 | n_layer_mlp: 3 77 | mlp_cfg: 78 | end_layer_activation: True 79 | use_layernorm: False 80 | use_batchnorm: False 81 | dropout_p: null 82 | n_layer_tf_map: 6 83 | n_layer_tf_tl: -1 84 | n_layer_tf_agent: -1 85 | decoder: 86 | _target_: models.sc_relative.Decoder 87 | n_pred: 6 88 | tf_n_layer: 2 89 | k_reinforce_tl: 2 90 | k_reinforce_agent: 4 91 | k_reinforce_all: -1 92 | k_reinforce_anchor: 10 93 | n_latent_query: -1 94 | latent_query: 95 | use_agent_type: False 96 | mode_emb: linear # linear, mlp, add, none 97 | mode_init: xavier # uniform, xavier 98 | scale: 5.0 99 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n 100 | 
multi_modal_anchors: 101 | use_agent_type: True 102 | mode_emb: linear # linear, mlp, add, none 103 | mode_init: xavier # uniform, xavier 104 | scale: 5.0 105 | anchor_self_attn: True 106 | mlp_head: 107 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox 108 | use_agent_type: False 109 | flatten_conf_head: False 110 | out_mlp_layernorm: False 111 | out_mlp_batchnorm: False 112 | n_step_future: 80 113 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor 114 | use_vmap: True 115 | 116 | # * optimizer 117 | optimizer: 118 | _target_: torch.optim.AdamW 119 | lr: 1e-4 # 2e-4, 1e-5 120 | lr_scheduler: 121 | _target_: torch.optim.lr_scheduler.StepLR 122 | gamma: 0.5 123 | step_size: 25 124 | 125 | sub_womd: 126 | _target_: utils.submission.SubWOMD 127 | activate: False 128 | method_name: METHOD_NAME 129 | authors: [NAME1, NAME2] 130 | affiliation: AFFILIATION 131 | description: scr_womd 132 | method_link: METHOD_LINK 133 | account_name: ACCOUNT_NAME 134 | sub_av2: 135 | _target_: utils.submission.SubAV2 136 | activate: False 137 | -------------------------------------------------------------------------------- /src/models/modules/decoder_ensemble.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Tuple, List 3 | import hydra 4 | import torch 5 | from torch import nn, Tensor 6 | from omegaconf import DictConfig 7 | from functorch import combine_state_for_ensemble, vmap 8 | from .mlp import MLP 9 | 10 | 11 | class DecoderEnsemble(nn.Module): 12 | def __init__(self, n_decoders: int, decoder_cfg: DictConfig) -> None: 13 | super().__init__() 14 | self.use_vmap = decoder_cfg["use_vmap"] 15 | self.n_decoders = n_decoders 16 | if self.use_vmap and self.n_decoders > 1: 17 | _decoders = [hydra.utils.instantiate(decoder_cfg) for _ in range(n_decoders)] 18 | fmodel_decoders, params_decoders, buffers_decoders = combine_state_for_ensemble(_decoders) 19 | assert buffers_decoders == () 20 | self.v_model = vmap(fmodel_decoders, randomness="different") 21 | [p.requires_grad_() for p in params_decoders] 22 | self.params_decoders = nn.ParameterList(params_decoders) 23 | else: 24 | self._decoders = nn.ModuleList([hydra.utils.instantiate(decoder_cfg) for _ in range(n_decoders)]) 25 | 26 | def forward(self, **kwargs) -> Tuple[Tensor, Tensor]: 27 | """ 28 | Returns: 29 | conf: [n_decoders, n_scene, n_agent, n_pred] 30 | pred: [n_decoders, n_scene, n_agent, n_pred, n_step_future, pred_dim] 31 | """ 32 | if self.use_vmap and self.n_decoders > 1: 33 | conf, pred = self.v_model(tuple(self.params_decoders), (), **kwargs) 34 | else: 35 | conf, pred = [], [] 36 | for decoder in self._decoders: 37 | c, p = decoder(**kwargs) 38 | conf.append(c) 39 | pred.append(p) 40 | conf = torch.stack(conf, dim=0) 41 | pred = torch.stack(pred, dim=0) 42 | return conf, pred 43 | 44 | 45 | class MLPHead(nn.Module): 46 | def __init__( 47 | self, 48 | hidden_dim: int, 49 | use_vmap: bool, 50 | n_step_future: int, 51 | out_mlp_layernorm: bool, 52 | out_mlp_batchnorm: bool, 53 | use_agent_type: bool, 54 | predictions: List[str], 55 | **kwargs, 56 | ) -> None: 57 | super().__init__() 58 | self.use_agent_type = use_agent_type 59 | self.n_step_future = n_step_future 60 | 61 | self.pred_dim = 0 62 | if "pos" in predictions: 63 | self.pred_dim += 2 64 | if "spd" in predictions: 65 | self.pred_dim += 1 66 | if "vel" 
in predictions: 67 | self.pred_dim += 2 68 | if "yaw_bbox" in predictions: 69 | self.pred_dim += 1 70 | if "cov1" in predictions: 71 | self.pred_dim += 1 72 | elif "cov2" in predictions: 73 | self.pred_dim += 2 74 | elif "cov3" in predictions: 75 | self.pred_dim += 3 76 | 77 | _d = hidden_dim * 2 78 | cfg_mlp_pred = { 79 | "fc_dims": [hidden_dim, _d, _d, self.n_step_future * self.pred_dim], 80 | "end_layer_activation": False, 81 | "use_layernorm": out_mlp_layernorm, 82 | "use_batchnorm": out_mlp_batchnorm, 83 | } 84 | cfg_mlp_conf = { 85 | "end_layer_activation": False, 86 | "use_layernorm": out_mlp_layernorm, 87 | "use_batchnorm": out_mlp_batchnorm, 88 | } 89 | n_mlp_head = 3 if use_agent_type else 1 90 | self.mlp_pred = MLPEnsemble(n_decoders=n_mlp_head, decoder_cfg=cfg_mlp_pred, use_vmap=use_vmap) 91 | 92 | cfg_mlp_conf["fc_dims"] = [hidden_dim, _d, _d, 1] 93 | self.mlp_conf = MLPEnsemble(n_decoders=n_mlp_head, decoder_cfg=cfg_mlp_conf, use_vmap=use_vmap) 94 | 95 | def forward(self, valid: Tensor, emb: Tensor, agent_type: Tensor) -> Tuple[Tensor, Tensor]: 96 | """ 97 | Args: 98 | valid: [n_scene, n_agent] 99 | emb: [n_scene, n_agent, n_pred, hidden_dim] 100 | agent_type: [n_scene, n_agent, 3] 101 | 102 | Returns: 103 | conf: [n_scene, n_agent, n_pred] 104 | pred: [n_scene, n_agent, n_pred, n_step_future, pred_dim] 105 | """ 106 | pred = self.mlp_pred(x=emb, valid_mask=valid.unsqueeze(-1)) # [1/3, n_scene, n_agent, n_pred, 400] 107 | conf = self.mlp_conf(x=emb, valid_mask=valid.unsqueeze(-1)).squeeze(-1) # [1/3, n_scene, n_agent, n_pred] 108 | 109 | if self.use_agent_type: 110 | _type = agent_type.movedim(-1, 0).unsqueeze(-1) # [3, n_scene, n_agent, 1] 111 | pred = (pred * _type.unsqueeze(-1)).sum(0) 112 | conf = (conf * _type).sum(0) 113 | else: 114 | pred = pred.squeeze(0) 115 | conf = conf.squeeze(0) 116 | 117 | n_scene, n_agent, n_pred = conf.shape 118 | return conf, pred.view(n_scene, n_agent, n_pred, self.n_step_future, self.pred_dim) 119 | 120 | 121 | class MLPEnsemble(nn.Module): 122 | def __init__(self, n_decoders: int, decoder_cfg: DictConfig, use_vmap: bool) -> None: 123 | super().__init__() 124 | self.use_vmap = use_vmap 125 | self.n_decoders = n_decoders 126 | if self.use_vmap and self.n_decoders > 1: 127 | _decoders = [MLP(**decoder_cfg) for _ in range(n_decoders)] 128 | fmodel_decoders, params_decoders, buffers_decoders = combine_state_for_ensemble(_decoders) 129 | assert buffers_decoders == () 130 | self.v_model = vmap(fmodel_decoders, randomness="different") 131 | [p.requires_grad_() for p in params_decoders] 132 | self.params_decoders = nn.ParameterList(params_decoders) 133 | else: 134 | self._decoders = nn.ModuleList([MLP(**decoder_cfg) for _ in range(n_decoders)]) 135 | 136 | def forward(self, **kwargs) -> Tensor: 137 | if self.use_vmap and self.n_decoders > 1: 138 | out = self.v_model(tuple(self.params_decoders), (), **kwargs) 139 | else: 140 | out = [] 141 | for decoder in self._decoders: 142 | x = decoder(**kwargs) 143 | out.append(x) 144 | out = torch.stack(out, dim=0) 145 | return out 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HPTR 2 | 3 |

4 | HPTR realizes real-time and on-board motion prediction without sacrificing performance. 5 | To efficiently predict the multi-modal future of numerous agents (a), HPTR minimizes the computational overhead by: (b) Sharing contexts among target agents. (c) Reusing static contexts during online inference. (d) Avoiding expensive post-processing and ensembling. 6 |

7 | 8 | > **Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding** 9 | > [Zhejun Zhang](https://zhejz.github.io/), [Alexander Liniger](https://alexliniger.github.io/), [Christos Sakaridis](https://people.ee.ethz.ch/~csakarid/), Fisher Yu and [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjcxLC0xOTcxNDY1MTc4.html).
10 | > 11 | > [NeurIPS 2023](https://neurips.cc/virtual/2023/poster/71285)
12 | > [Project Website](https://zhejz.github.io/hptr)
13 | > [arXiv Paper](https://arxiv.org/abs/2310.12970) 14 | 15 | ```bibtex 16 | @inproceedings{zhang2023hptr, 17 | title = {Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding}, 18 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, 19 | author = {Zhang, Zhejun and Liniger, Alexander and Sakaridis, Christos and Yu, Fisher and Van Gool, Luc}, 20 | year = {2023}, 21 | } 22 | ``` 23 | 24 | ## Updates 25 | - The model checkpoint for Argoverse 2 is available at this wandb artifact `zhejun/hptr_av2_ckpt/av2_ckpt:v0` ([wandb project](https://wandb.ai/zhejun/hptr_av2_ckpt)). 26 | - HPTR ranks 1st in minADE and 2nd in minFDE on the [WOMD Motion Prediction Leaderboard 2023](https://waymo.com/open/challenges/2023/motion-prediction/). 27 | 28 | 29 | ## Setup Environment 30 | - Create the main [conda](https://docs.conda.io/en/latest/miniconda.html) environment by running `conda env create -f environment.yml`. 31 | - Install the [Waymo Open Dataset](https://github.com/waymo-research/waymo-open-dataset) API manually, because the pip installation of version 1.5.2 is not supported on some Linux distributions, e.g. CentOS. Run 32 | ``` 33 | conda activate hptr 34 | wget https://files.pythonhosted.org/packages/85/1d/4cdd31fc8e88c3d689a67978c41b28b6e242bd4fe6b080cf8c99663b77e4/waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl 35 | mv waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl 36 | pip install --no-deps waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl 37 | rm waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl 38 | ``` 39 | - Create the conda environment for packing the [Argoverse 2 Motion Forecasting Dataset](https://www.argoverse.org/av2.html#forecasting-link) by running `conda env create -f env_av2.yml`. 40 | - We use [WandB](https://wandb.ai/) for logging. You can register an account for free. 41 | - Be aware: 42 | - We use 4 *NVIDIA RTX 2080Ti* for training and a single 2080Ti for evaluation. The training takes at least 5 days to converge. 43 | - This repo contains the experiments for the [Waymo Motion Prediction Challenge](https://waymo.com/open/challenges/2023/motion-prediction/) and the [Argoverse 2: Motion Forecasting Competition](https://eval.ai/web/challenges/challenge-page/1719/submission). 44 | - We cannot share pre-trained models according to the [terms](https://waymo.com/open/terms) of the Waymo Open Motion Dataset. 45 | 46 | ## Prepare Datasets 47 | - Waymo Open Motion Dataset (WOMD): 48 | - Download the [Waymo Open Motion Dataset](https://waymo.com/open/data/motion/). We use v1.2. 49 | - Run `python src/pack_h5_womd.py` or use [bash/pack_h5.sh](bash/pack_h5.sh) to pack the dataset into h5 files to accelerate data loading during training and evaluation. 50 | - You should pack three datasets: `training`, `validation` and `testing`. Packing the `training` dataset takes around 2 days. For `validation` and `testing` it should take a few hours. 51 | - Argoverse 2 Motion Forecasting Dataset (AV2): 52 | - Download the [Argoverse 2 Motion Forecasting Dataset](https://www.argoverse.org/av2.html#download-link). 53 | - Run `python src/pack_h5_av2.py` or use [bash/pack_h5.sh](bash/pack_h5.sh) to pack the dataset into h5 files to accelerate data loading during training and evaluation. 54 | - You should pack three datasets: `training`, `validation` and `testing`. Each dataset should take a few hours. A quick way to inspect the packed h5 files is sketched right after this list.
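The packed files are plain HDF5, so you can sanity-check them before launching a long training run. The snippet below is a minimal sketch and not part of the repo: the file path is a placeholder, and the layout it assumes (one group per episode keyed by its integer index as a string, with the episode count stored in the `data_len` file attribute) follows `src/data_modules/data_h5_womd.py`.

```python
import h5py

# Placeholder path: point this at the h5 file produced by bash/pack_h5.sh.
path = "h5_womd/training.h5"

with h5py.File(path, "r", libver="latest", swmr=True) as hf:
    print("number of episodes:", int(hf.attrs["data_len"]))
    episode = hf["0"]  # episodes are stored under string keys "0", "1", ...

    # Recursively print every dataset in this episode with its shape and dtype.
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(f"{name}: {obj.shape} {obj.dtype}")

    episode.visititems(show)
```

If the printed agent shapes differ from the `n_agent` configured in the datamodule, the validation dataloader falls back to dummy agent tensors (see `DatasetVal` in `src/data_modules/data_h5_womd.py`); this is intended for scalability tests rather than regular runs.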
55 | 56 | 57 | ## Training, Validation, Testing and Submission 58 | Please refer to [bash/train.sh](bash/train.sh) for training. 59 | 60 | Once the training converges, you can use the saved checkpoints (WandB artifacts) for validation and testing; please refer to [bash/submission.sh](bash/submission.sh) for more details. 61 | 62 | Once the validation/testing is finished, download the file `womd_K6.tar.gz` from WandB and submit it to the [Waymo Motion Prediction Leaderboard](https://waymo.com/open/challenges/2023/motion-prediction/). For AV2, download the file `av2_K6.parquet` from WandB and submit it to the [Argoverse 2 Motion Forecasting Competition](https://eval.ai/web/challenges/challenge-page/1719/submission). A quick sanity check of the AV2 submission file is sketched in the appendix at the end of this README. 63 | 64 | 65 | ## Performance 66 | 67 | Our submission to the [WOMD leaderboard](https://waymo.com/open/challenges/2023/motion-prediction/) can be found [here](https://waymo.com/open/challenges/entry/?challenge=MOTION_PREDICTION&challengeId=MOTION_PREDICTION_2023&emailId=5ea7a3eb-7337&timestamp=1684068775971677). 68 | 69 | Our submission to the [AV2 leaderboard](https://eval.ai/web/challenges/challenge-page/1719/overview) can be found [here](https://eval.ai/web/challenges/challenge-page/1719/leaderboard/4098). 70 | 71 | ## Ablation Models 72 | 73 | Please refer to [docs/ablation_models.md](docs/ablation_models.md) for the configurations of ablation models. 74 | 75 | Specifically, you can find the [Wayformer](https://arxiv.org/abs/2207.05844) and [SceneTransformer](https://arxiv.org/abs/2106.08417) models based on our backbone. You can also try out different hierarchical architectures. 76 | 77 | ## License 78 | 79 | This software is made available for non-commercial use under a Creative Commons [license](LICENSE). You can find a summary of the license [here](https://creativecommons.org/licenses/by-nc/4.0/). 80 | 81 | ## Acknowledgement 82 | 83 | This work is funded by Toyota Motor Europe via the research project [TRACE-Zurich](https://trace.ethz.ch) (Toyota Research on Automated Cars Europe).
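## Appendix: Checking the AV2 Submission File

As referenced in the submission section above, the snippet below is a minimal, hypothetical sanity check (not part of the repo) for the downloaded `av2_K6.parquet` before uploading it. The column names follow `SubAV2` in `src/utils/submission.py`; the file name is the one the README instructs you to download from WandB.

```python
import pandas as pd

# The parquet file downloaded from WandB after validation/testing.
df = pd.read_parquet("av2_K6.parquet")

# Columns written by SubAV2.save_sub_files:
# ['scenario_id', 'track_id', 'probability',
#  'predicted_trajectory_x', 'predicted_trajectory_y']
print(df.columns.tolist())

# For each (scenario_id, track_id) pair, the probabilities of the K predicted
# modes are normalized, so their sum should be close to 1.
print(df.groupby(["scenario_id", "track_id"])["probability"].sum().describe())
```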
-------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional 3 | from pathlib import Path 4 | import wandb 5 | from pytorch_lightning import Callback, Trainer 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger 8 | from pytorch_lightning.utilities import rank_zero_only 9 | 10 | 11 | def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]: 12 | """Safely get Weights&Biases logger from Trainer.""" 13 | 14 | if isinstance(trainer.logger, WandbLogger): 15 | return trainer.logger 16 | 17 | if isinstance(trainer.logger, LoggerCollection): 18 | for logger in trainer.logger: 19 | if isinstance(logger, WandbLogger): 20 | return logger 21 | 22 | return None 23 | # raise Exception("You are using wandb related callback, but WandbLogger was not found for some reason...") 24 | 25 | 26 | class ModelCheckpointWB(ModelCheckpoint): 27 | def __init__(self, save_only_best=False, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | self.save_only_best = save_only_best 30 | 31 | def save_checkpoint(self, trainer) -> None: 32 | super().save_checkpoint(trainer) 33 | if not hasattr(self, "_logged_model_time"): 34 | self._logged_model_time = {} 35 | logger = get_wandb_logger(trainer) 36 | if self.current_score is None: 37 | self.current_score = trainer.callback_metrics.get(self.monitor) 38 | if logger is not None: 39 | self._scan_and_log_checkpoints(logger) 40 | 41 | @rank_zero_only 42 | def _scan_and_log_checkpoints(self, wb_logger: WandbLogger) -> None: 43 | if self.save_only_best: 44 | self._log_best_checkpoint(wb_logger) 45 | else: 46 | self._log_all_checkpoints(wb_logger) 47 | 48 | def _log_all_checkpoints(self, wb_logger: WandbLogger) -> None: 49 | # adapted from pytorch_lightning 1.4.0: loggers/wandb.py 50 | checkpoints = { 51 | self.last_model_path: self.current_score, 52 | self.best_model_path: self.best_model_score, 53 | } 54 | checkpoints = sorted( 55 | (Path(p).stat().st_mtime, p, s) 56 | for p, s in checkpoints.items() 57 | if Path(p).is_file() 58 | ) 59 | checkpoints = [ 60 | c 61 | for c in checkpoints 62 | if c[1] not in self._logged_model_time.keys() 63 | or self._logged_model_time[c[1]] < c[0] 64 | ] 65 | # log iteratively all new checkpoints 66 | for t, p, s in checkpoints: 67 | metadata = { 68 | "score": s.item(), 69 | "original_filename": Path(p).name, 70 | "ModelCheckpoint": { 71 | k: getattr(self, k) 72 | for k in [ 73 | "monitor", 74 | "mode", 75 | "save_last", 76 | "save_top_k", 77 | "save_weights_only", 78 | "_every_n_train_steps", 79 | "_every_n_val_epochs", 80 | ] 81 | # ensure it does not break if `ModelCheckpoint` args change 82 | if hasattr(self, k) 83 | }, 84 | } 85 | artifact = wandb.Artifact( 86 | name=wb_logger.experiment.id, type="model", metadata=metadata 87 | ) 88 | artifact.add_file(p, name="model.ckpt") 89 | aliases = ["latest", "best"] if p == self.best_model_path else ["latest"] 90 | wb_logger.experiment.log_artifact(artifact, aliases=aliases) 91 | # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) 92 | self._logged_model_time[p] = t 93 | 94 | def _log_best_checkpoint(self, wb_logger: WandbLogger) -> None: 95 | # Only consider the best model checkpoint for logging 96 | 
if Path(self.best_model_path).is_file(): 97 | best_model_mtime = Path(self.best_model_path).stat().st_mtime 98 | # Check if the best model checkpoint is new or has been updated 99 | if ( 100 | self.best_model_path not in self._logged_model_time 101 | or self._logged_model_time[self.best_model_path] < best_model_mtime 102 | ): 103 | # Attempt to delete the previous best artifact if it exists 104 | try: 105 | api = wandb.Api() 106 | runs = api.run( 107 | f"{wb_logger.experiment.entity}/{wb_logger.experiment.project}/{wb_logger.experiment.id}" 108 | ) 109 | for artifact in runs.logged_artifacts(): 110 | if "best" in artifact.aliases: 111 | artifact.delete(delete_aliases=True) 112 | print("Deleted previous best artifact.") 113 | break 114 | else: 115 | print("No previous best artifact found to delete.") 116 | except Exception as e: 117 | print(f"Could not delete previous best artifact: {e}") 118 | # Log the best model checkpoint 119 | metadata = { 120 | "score": self.best_model_score.item(), 121 | "original_filename": Path(self.best_model_path).name, 122 | "ModelCheckpoint": { 123 | k: getattr(self, k) 124 | for k in [ 125 | "monitor", 126 | "mode", 127 | "save_last", 128 | "save_top_k", 129 | "save_weights_only", 130 | "_every_n_train_steps", 131 | "_every_n_val_epochs", 132 | ] 133 | if hasattr(self, k) 134 | }, 135 | } 136 | artifact = wandb.Artifact( 137 | name=wb_logger.experiment.id, type="model", metadata=metadata 138 | ) 139 | artifact.add_file(self.best_model_path, name="model.ckpt") 140 | wb_logger.experiment.log_artifact(artifact, aliases=["best"]) 141 | # Update the log timestamp for this model checkpoint 142 | self._logged_model_time[self.best_model_path] = best_model_mtime 143 | 144 | 145 | class WatchModel(Callback): 146 | """Make wandb watch model at the beginning of the run.""" 147 | 148 | def __init__(self, log: str = "gradients", log_freq: int = 100): 149 | self._log = log 150 | self._log_freq = log_freq 151 | 152 | def on_train_start(self, trainer, pl_module): 153 | logger = get_wandb_logger(trainer) 154 | logger.watch(model=trainer.model, log=self._log, log_freq=self._log_freq) 155 | -------------------------------------------------------------------------------- /src/utils/transform_utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | import transforms3d 6 | from typing import Union 7 | 8 | 9 | def cast_rad(angle: Union[float, np.ndarray, Tensor]) -> Union[float, np.ndarray, Tensor]: 10 | """Cast angle such that they are always in the [-pi, pi) range.""" 11 | return (angle + np.pi) % (2 * np.pi) - np.pi 12 | 13 | 14 | def _rotation33_as_yaw(rotation: np.ndarray) -> float: 15 | """Compute the yaw component of given 3x3 rotation matrix. 16 | 17 | Args: 18 | rotation (np.ndarray): 3x3 rotation matrix (np.float64 dtype recommended) 19 | 20 | Returns: 21 | float: yaw rotation in radians 22 | """ 23 | return transforms3d.euler.mat2euler(rotation)[2] 24 | 25 | 26 | def _yaw_as_rotation33(yaw: float) -> np.ndarray: 27 | """Create a 3x3 rotation matrix from given yaw. 
28 | The rotation is counter-clockwise and it is equivalent to: 29 | [cos(yaw), -sin(yaw), 0.0], 30 | [sin(yaw), cos(yaw), 0.0], 31 | [0.0, 0.0, 1.0], 32 | 33 | Args: 34 | yaw (float): yaw rotation in radians 35 | 36 | Returns: 37 | np.ndarray: 3x3 rotation matrix 38 | """ 39 | return transforms3d.euler.euler2mat(0, 0, yaw) 40 | 41 | 42 | def get_so2_from_se2(transform_se3: np.ndarray) -> np.ndarray: 43 | """Gets rotation component in SO(2) from transformation in SE(2). 44 | 45 | Args: 46 | transform_se3: se2 transformation. 47 | 48 | Returns: 49 | rotation component in so2 50 | """ 51 | rotation = np.eye(3, dtype=np.float64) 52 | rotation[:2, :2] = transform_se3[:2, :2] 53 | return rotation 54 | 55 | 56 | def get_yaw_from_se2(transform_se3: np.ndarray) -> float: 57 | """Gets yaw from transformation in SE(2). 58 | 59 | Args: 60 | transform_se3: se2 transformation. 61 | 62 | Returns: 63 | yaw 64 | """ 65 | return _rotation33_as_yaw(get_so2_from_se2(transform_se3)) 66 | 67 | 68 | def transform_points(points: np.ndarray, transf_matrix: np.ndarray) -> np.ndarray: 69 | """Transform points using transformation matrix. 70 | Note this function assumes points.shape[1] == matrix.shape[1] - 1, which means that the last 71 | row in the matrix does not influence the final result. 72 | For 2D points only the first 2x3 part of the matrix will be used. 73 | 74 | Args: 75 | points (np.ndarray): Input points (Nx2) or (Nx3). 76 | transf_matrix (np.ndarray): np.float64, 3x3 or 4x4 transformation matrix for 2D and 3D input respectively 77 | 78 | Returns: 79 | np.ndarray: array of shape (N,2) for 2D input points, or (N,3) points for 3D input points 80 | """ 81 | assert len(points.shape) == len(transf_matrix.shape) == 2, ( 82 | f"dimensions mismatch, both points ({points.shape}) and " 83 | f"transf_matrix ({transf_matrix.shape}) needs to be 2D numpy ndarrays." 84 | ) 85 | assert ( 86 | transf_matrix.shape[0] == transf_matrix.shape[1] 87 | ), f"transf_matrix ({transf_matrix.shape}) should be a square matrix." 88 | 89 | if points.shape[1] not in [2, 3]: 90 | raise AssertionError(f"Points input should be (N, 2) or (N, 3) shape, received {points.shape}") 91 | 92 | assert points.shape[1] == transf_matrix.shape[1] - 1, "points dim should be one less than matrix dim" 93 | 94 | points_transformed = (points @ transf_matrix.T[:-1, :-1]) + transf_matrix[:-1, -1] 95 | 96 | return points_transformed.astype(points.dtype) 97 | 98 | 99 | def get_transformation_matrix(agent_translation_m: np.ndarray, agent_yaw: float) -> np.ndarray: 100 | """Get transformation matrix from world to vehicle frame 101 | 102 | Args: 103 | agent_translation_m (np.ndarray): (x, y) position of the vehicle in world frame 104 | agent_yaw (float): rotation of the vehicle in the world frame 105 | 106 | Returns: 107 | (np.ndarray) transformation matrix from world to vehicle 108 | """ 109 | 110 | # Translate world to ego by applying the negative ego translation. 111 | world_to_agent_in_2d = np.eye(3, dtype=np.float64) 112 | world_to_agent_in_2d[0:2, 2] = -agent_translation_m[0:2] 113 | 114 | # Rotate counter-clockwise by negative yaw to align world such that ego faces right. 115 | world_to_agent_in_2d = _yaw_as_rotation33(-agent_yaw) @ world_to_agent_in_2d 116 | 117 | return world_to_agent_in_2d 118 | 119 | 120 | # transformation for torch 121 | def torch_rad2rot(rad: Tensor) -> Tensor: 122 | """ 123 | Args: 124 | rad: [n_batch] or [n_scene, n_agent] or etc. 
125 | 126 | Returns: 127 | rot_mat: [{rad.shape}, 2, 2] 128 | """ 129 | _cos = torch.cos(rad) 130 | _sin = torch.sin(rad) 131 | return torch.stack([torch.stack([_cos, -_sin], dim=-1), torch.stack([_sin, _cos], dim=-1)], dim=-2) 132 | 133 | 134 | def torch_sincos2rot(in_sin: Tensor, in_cos: Tensor) -> Tensor: 135 | """ 136 | Args: 137 | in_sin: [n_batch] or [n_scene, n_agent] or etc. 138 | in_cos: [n_batch] or [n_scene, n_agent] or etc. 139 | 140 | Returns: 141 | rot_mat: [{in_sin.shape}, 2, 2] 142 | """ 143 | return torch.stack([torch.stack([in_cos, -in_sin], dim=-1), torch.stack([in_sin, in_cos], dim=-1)], dim=-2) 144 | 145 | 146 | def torch_pos2local(in_pos: Tensor, local_pos: Tensor, local_rot: Tensor) -> Tensor: 147 | """Transform M position to the local coordinates. 148 | 149 | Args: 150 | in_pos: [..., M, 2] 151 | local_pos: [..., 1, 2] 152 | local_rot: [..., 2, 2] 153 | 154 | Returns: 155 | out_pos: [..., M, 2] 156 | """ 157 | return torch.matmul(in_pos - local_pos, local_rot) 158 | 159 | 160 | def torch_pos2global(in_pos: Tensor, local_pos: Tensor, local_rot: Tensor) -> Tensor: 161 | """Reverse torch_pos2local 162 | 163 | Args: 164 | in_pos: [..., M, 2] 165 | local_pos: [..., 1, 2] 166 | local_rot: [..., 2, 2] 167 | 168 | Returns: 169 | out_pos: [..., M, 2] 170 | """ 171 | return torch.matmul(in_pos.double(), local_rot.transpose(-1, -2).double()) + local_pos.double() 172 | 173 | 174 | def torch_dir2local(in_dir: Tensor, local_rot: Tensor) -> Tensor: 175 | """Transform M dir to the local coordinates. 176 | 177 | Args: 178 | in_dir: [..., M, 2] 179 | local_rot: [..., 2, 2] 180 | 181 | Returns: 182 | out_dir: [..., M, 2] 183 | """ 184 | return torch.matmul(in_dir, local_rot) 185 | 186 | 187 | def torch_dir2global(in_dir: Tensor, local_rot: Tensor) -> Tensor: 188 | """Reverse torch_dir2local 189 | 190 | Args: 191 | in_dir: [..., M, 2] 192 | local_rot: [..., 2, 2] 193 | 194 | Returns: 195 | out_dir: [..., M, 2] 196 | """ 197 | return torch.matmul(in_dir, local_rot.transpose(-1, -2)) 198 | 199 | 200 | def torch_rad2local(in_rad: Tensor, local_rad: Tensor, cast: bool = True) -> Tensor: 201 | """Transform M rad angles to the local coordinates. 202 | 203 | Args: 204 | in_rad: [..., M] 205 | local_rad: [...] 206 | 207 | Returns: 208 | out_rad: [..., M] 209 | """ 210 | out_rad = in_rad - local_rad.unsqueeze(-1) 211 | if cast: 212 | out_rad = cast_rad(out_rad) 213 | return out_rad 214 | 215 | 216 | def torch_rad2global(in_rad: Tensor, local_rad: Tensor) -> Tensor: 217 | """Reverse torch_rad2local 218 | 219 | Args: 220 | in_rad: [..., M] 221 | local_rad: [...] 
222 | 223 | Returns: 224 | out_rad: [..., M] 225 | """ 226 | return cast_rad(in_rad + local_rad.unsqueeze(-1)) 227 | -------------------------------------------------------------------------------- /src/data_modules/scene_centric.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Dict 3 | from omegaconf import DictConfig 4 | import torch 5 | from torch import nn, Tensor 6 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_dir2local, torch_rad2local 7 | 8 | 9 | class SceneCentricPreProcessing(nn.Module): 10 | def __init__(self, time_step_current: int, data_size: DictConfig, gt_in_local: bool, mask_invalid: bool) -> None: 11 | super().__init__() 12 | self.step_current = time_step_current 13 | self.n_step_hist = time_step_current + 1 14 | self.gt_in_local = gt_in_local 15 | self.mask_invalid = mask_invalid 16 | self.model_kwargs = {"gt_in_local": gt_in_local, "agent_centric": False} 17 | 18 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: 19 | """ 20 | Args: scene-centric Dict 21 | # agent states 22 | "agent/valid": [n_scene, n_step, n_agent], bool, 23 | "agent/pos": [n_scene, n_step, n_agent, 2], float32 24 | "agent/vel": [n_scene, n_step, n_agent, 2], float32, v_x, v_y 25 | "agent/spd": [n_scene, n_step, n_agent, 1], norm of vel, signed using yaw_bbox and vel_xy 26 | "agent/acc": [n_scene, n_step, n_agent, 1], m/s2, acc[t] = (spd[t]-spd[t-1])/dt 27 | "agent/yaw_bbox": [n_scene, n_step, n_agent, 1], float32, yaw of the bbox heading 28 | "agent/yaw_rate": [n_scene, n_step, n_agent, 1], rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt 29 | # agent attributes 30 | "agent/type": [n_scene, n_agent, 3], bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 31 | "agent/role": [n_scene, n_agent, 3], bool [sdc=0, interest=1, predict=2] 32 | "agent/size": [n_scene, n_agent, 3], float32: [length, width, height] 33 | # map polylines 34 | "map/valid": [n_scene, n_pl, n_pl_node], bool 35 | "map/type": [n_scene, n_pl, 11], bool one_hot 36 | "map/pos": [n_scene, n_pl, n_pl_node, 2], float32 37 | "map/dir": [n_scene, n_pl, n_pl_node, 2], float32 38 | # traffic lights 39 | "tl_stop/valid": [n_scene, n_step, n_tl_stop], bool 40 | "tl_stop/state": [n_scene, n_step, n_tl_stop, 5], bool one_hot 41 | "tl_stop/pos": [n_scene, n_step, n_tl_stop, 2], x,y 42 | "tl_stop/dir": [n_scene, n_step, n_tl_stop, 2], x,y 43 | 44 | Returns: scene-centric Dict, masked according to valid 45 | # (ref) reference information for transform back to global coordinate and submission to waymo 46 | "ref/pos": [n_scene, n_agent, 1, 2] 47 | "ref/yaw": [n_scene, n_agent, 1] 48 | "ref/rot": [n_scene, n_agent, 2, 2] 49 | "ref/role": [n_scene, n_agent, 3] 50 | "ref/type": [n_scene, n_agent, 3] 51 | # (gt) ground-truth agent future for training, not available for testing 52 | "gt/valid": [n_scene, n_agent, n_step_future], bool 53 | "gt/pos": [n_scene, n_agent, n_step_future, 2] 54 | "gt/spd": [n_scene, n_agent, n_step_future, 1] 55 | "gt/vel": [n_scene, n_agent, n_step_future, 2] 56 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1] 57 | "gt/cmd": [n_scene, n_agent, 8] 58 | # (sc) scene-centric agents states 59 | "sc/agent_valid": [n_scene, n_agent, n_step_hist] 60 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2] 61 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2] 62 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1] 63 | "sc/agent_acc": [n_scene, 
n_agent, n_step_hist, 1] 64 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1] 65 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1] 66 | # agents attributes 67 | "sc/agent_type": [n_scene, n_agent, 3] 68 | "sc/agent_role": [n_scene, n_agent, 3] 69 | "sc/agent_size": [n_scene, n_agent, 3] 70 | # map polylines 71 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool 72 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot 73 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32 74 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32 75 | # traffic lights 76 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool 77 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot 78 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y 79 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y 80 | """ 81 | prefix = "" if self.training else "history/" 82 | 83 | # ! prepare "ref/" 84 | batch["ref/type"] = batch[prefix + "agent/type"] 85 | batch["ref/role"] = batch[prefix + "agent/role"] 86 | 87 | last_valid_step = ( 88 | self.step_current - batch[prefix + "agent/valid"][:, : self.step_current + 1].flip(1).max(1)[1] 89 | ) # [n_scene, n_agent] 90 | i_scene = torch.arange(batch["ref/type"].shape[0]).unsqueeze(1) # [n_scene, 1] 91 | i_agent = torch.arange(batch["ref/type"].shape[1]).unsqueeze(0) # [1, n_agent] 92 | ref_pos = batch[prefix + "agent/pos"][i_scene, last_valid_step, i_agent].unsqueeze(-2).contiguous() 93 | ref_yaw = batch[prefix + "agent/yaw_bbox"][i_scene, last_valid_step, i_agent] 94 | ref_rot = torch_rad2rot(ref_yaw.squeeze(-1)) 95 | batch["ref/pos"] = ref_pos 96 | batch["ref/yaw"] = ref_yaw 97 | batch["ref/rot"] = ref_rot 98 | 99 | # ! prepare agents states 100 | # [n_scene, n_step, n_agent, ...] -> [n_scene, n_agent, n_step_hist, ...] 101 | for k in ("valid", "pos", "vel", "spd", "acc", "yaw_bbox", "yaw_rate"): 102 | batch[f"sc/agent_{k}"] = batch[f"{prefix}agent/{k}"][:, : self.n_step_hist].transpose(1, 2) 103 | 104 | # ! prepare agents attributes 105 | for k in ("type", "role", "size"): 106 | batch[f"sc/agent_{k}"] = batch[f"{prefix}agent/{k}"] 107 | 108 | # ! training/validation time, prepare "gt/" for losses 109 | if "agent/valid" in batch.keys(): 110 | batch["gt/cmd"] = batch["agent/cmd"] 111 | for k in ("valid", "spd", "pos", "vel", "yaw_bbox"): 112 | batch[f"gt/{k}"] = batch[f"agent/{k}"][:, self.n_step_hist :].transpose(1, 2).contiguous() 113 | 114 | if self.gt_in_local: 115 | # [n_scene, n_agent, n_step_hist, 2] 116 | batch["gt/pos"] = torch_pos2local(batch["gt/pos"], ref_pos, ref_rot) 117 | batch["gt/vel"] = torch_dir2local(batch["gt/vel"], ref_rot) 118 | # [n_scene, n_agent, n_step_hist, 1] 119 | batch["gt/yaw_bbox"] = torch_rad2local(batch["gt/yaw_bbox"], ref_yaw, cast=False) 120 | 121 | # ! prepare map polylines 122 | for k in ("valid", "type", "pos", "dir"): 123 | batch[f"sc/map_{k}"] = batch[f"map/{k}"] 124 | 125 | # ! 
prepare traffic lights 126 | for k in ("valid", "state", "pos", "dir"): 127 | batch[f"sc/tl_{k}"] = batch[f"{prefix}tl_stop/{k}"][:, : self.n_step_hist] 128 | 129 | if self.mask_invalid: 130 | self.zero_mask_invalid(batch) 131 | return batch 132 | 133 | @staticmethod 134 | def zero_mask_invalid(batch: Dict[str, Tensor]): 135 | 136 | agent_invalid = ~batch["sc/agent_valid"].unsqueeze(-1) 137 | for k in ["pos", "vel", "spd", "acc", "yaw_bbox", "yaw_rate"]: 138 | _key = f"sc/agent_{k}" 139 | batch[_key] = batch[_key].masked_fill(agent_invalid, 0) 140 | 141 | agent_invalid = ~(batch["sc/agent_valid"].any(-1, keepdim=True)) 142 | for k in ["type", "role", "size"]: 143 | _key = f"sc/agent_{k}" 144 | batch[_key] = batch[_key].masked_fill(agent_invalid, 0) 145 | 146 | map_invalid = ~batch["sc/map_valid"].unsqueeze(-1) 147 | batch["sc/map_pos"] = batch["sc/map_pos"].masked_fill(map_invalid, 0) 148 | batch["sc/map_dir"] = batch["sc/map_dir"].masked_fill(map_invalid, 0) 149 | map_invalid = ~(batch["sc/map_valid"].any(-1, keepdim=True)) 150 | batch["sc/map_type"] = batch["sc/map_type"].masked_fill(map_invalid, 0) 151 | 152 | tl_invalid = ~batch["sc/tl_valid"].unsqueeze(-1) 153 | for k in ["state", "pos", "dir"]: 154 | _key = f"sc/tl_{k}" 155 | batch[_key] = batch[_key].masked_fill(tl_invalid, 0) 156 | 157 | gt_invalid = ~batch["gt/valid"].unsqueeze(-1) 158 | for k in ["pos", "spd", "vel", "yaw_bbox"]: 159 | _key = f"gt/{k}" 160 | batch[_key] = batch[_key].masked_fill(gt_invalid, 0) 161 | -------------------------------------------------------------------------------- /src/models/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional, Tuple 3 | import math 4 | import torch 5 | from torch import Tensor, nn 6 | from torch.nn import functional as F 7 | 8 | # KNARPE 9 | class AttentionRPE(nn.Module): 10 | def __init__( 11 | self, 12 | d_model: int, 13 | n_head: int, 14 | dropout_p: float = 0.0, 15 | bias: bool = True, 16 | d_rpe: int = -1, 17 | apply_q_rpe: bool = False, 18 | ) -> None: 19 | """ 20 | Always batch first. Always src and tgt have the same d_model. 
21 | """ 22 | super(AttentionRPE, self).__init__() 23 | 24 | self.d_model = d_model 25 | self.n_head = n_head 26 | self.d_head = d_model // n_head 27 | self.apply_q_rpe = apply_q_rpe 28 | self.d_rpe = d_rpe 29 | 30 | assert self.d_head * n_head == d_model, "d_model must be divisible by n_head" 31 | 32 | if self.d_rpe > 0: 33 | n_project_rpe = 3 if apply_q_rpe else 2 34 | self.mlp_rpe = nn.Linear(d_rpe, n_project_rpe * d_model, bias=bias) 35 | 36 | self.in_proj_weight = nn.Parameter(torch.empty((3 * d_model, d_model))) 37 | self.out_proj_weight = nn.Parameter(torch.empty((d_model, d_model))) 38 | if bias: 39 | self.in_proj_bias = nn.Parameter(torch.empty(3 * d_model)) 40 | self.out_proj_bias = nn.Parameter(torch.empty(d_model)) 41 | else: 42 | self.register_parameter("in_proj_bias", None) 43 | self.register_parameter("out_proj_bias", None) 44 | 45 | self.dropout = nn.Dropout(p=dropout_p, inplace=False) if dropout_p > 0 else None 46 | 47 | self._reset_parameters() 48 | 49 | def _reset_parameters(self): 50 | nn.init.xavier_uniform_(self.in_proj_weight) 51 | nn.init.xavier_uniform_(self.out_proj_weight) 52 | if self.in_proj_bias is not None: 53 | nn.init.constant_(self.in_proj_bias, 0.0) 54 | if self.out_proj_bias is not None: 55 | nn.init.constant_(self.out_proj_bias, 0.0) 56 | 57 | def forward( 58 | self, 59 | src: Tensor, 60 | tgt: Optional[Tensor] = None, 61 | tgt_padding_mask: Optional[Tensor] = None, 62 | attn_mask: Optional[Tensor] = None, 63 | rpe: Optional[Tensor] = None, 64 | need_weights=False, 65 | ) -> Tuple[Tensor, Optional[Tensor]]: 66 | """ 67 | Args: 68 | src: [n_batch, n_src, d_model] 69 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe. 70 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe. 71 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt. 72 | rpe: [n_batch, n_src, n_tgt, d_rpe] 73 | 74 | Returns: 75 | out: [n_batch, n_src, d_model] 76 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None 77 | 78 | Remarks: 79 | absoulte_pe should be already added to src/tgt. 80 | if for a batch entry all tgt are invalid, then returns 0 for that batch entry. 
81 | """ 82 | n_batch, n_src, _ = src.shape 83 | if tgt is None: 84 | n_tgt = n_src 85 | # self-attention 86 | qkv = F.linear(src, self.in_proj_weight, self.in_proj_bias) 87 | q, k, v = qkv.chunk(3, dim=-1) 88 | else: 89 | n_tgt = tgt.shape[-2] 90 | # encoder-decoder attention 91 | w_src, w_tgt = self.in_proj_weight.split([self.d_model, self.d_model * 2]) 92 | b_src, b_tgt = None, None 93 | if self.in_proj_bias is not None: 94 | b_src, b_tgt = self.in_proj_bias.split([self.d_model, self.d_model * 2]) 95 | q = F.linear(src, w_src, b_src) 96 | kv = F.linear(tgt, w_tgt, b_tgt) 97 | k, v = kv.chunk(2, dim=-1) 98 | # q: [n_batch, n_src, d_model], k,v: [n_batch, (n_src), n_tgt, d_model] 99 | 100 | attn_invalid_mask = None # [n_batch, n_src, n_tgt] 101 | if tgt_padding_mask is not None: # [n_batch, n_tgt], bool 102 | attn_invalid_mask = tgt_padding_mask 103 | if attn_invalid_mask.dim() == 2: 104 | attn_invalid_mask = attn_invalid_mask.unsqueeze(1).expand(-1, n_src, -1) 105 | if attn_mask is not None: # [n_batch, n_src, n_tgt], bool 106 | if attn_invalid_mask is None: 107 | attn_invalid_mask = attn_mask 108 | else: 109 | attn_invalid_mask = attn_invalid_mask | attn_mask 110 | 111 | mask_no_tgt_valid = None # [n_batch, n_src] 112 | if attn_invalid_mask is not None: 113 | mask_no_tgt_valid = attn_invalid_mask.all(-1) 114 | if mask_no_tgt_valid.any(): 115 | attn_invalid_mask = attn_invalid_mask & (~mask_no_tgt_valid.unsqueeze(-1)) # to avoid softmax nan 116 | else: 117 | mask_no_tgt_valid = None 118 | 119 | # get attn: [n_batch, n_head, n_src, n_tgt] 120 | if rpe is None: 121 | if k.dim() == 3: 122 | # ! normal attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_tgt, d_model] 123 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).contiguous() 124 | k = k.view(n_batch, n_tgt, self.n_head, self.d_head).transpose(1, 2).contiguous() 125 | v = v.view(n_batch, n_tgt, self.n_head, self.d_head).transpose(1, 2).contiguous() 126 | attn = torch.matmul(q, k.transpose(-2, -1)) # [n_batch, n_head, n_src, n_tgt] 127 | # q: [n_batch, n_head, n_src, d_head], k,v: [n_batch, n_head, n_tgt, d_head] 128 | else: 129 | # ! KNN attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_src, n_tgt, d_model] 130 | # k,v: [n_batch, n_src, n_tgt, d_model] -> [n_batch, n_head, n_src, n_tgt_knn, d_head] 131 | k = k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 132 | v = v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 133 | # [n_batch, n_src, d_model] -> [n_batch, n_head, n_src, 1, d_head] 134 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).unsqueeze(3) 135 | attn = torch.sum(q * k, dim=-1) # [n_batch, n_head, n_src, n_tgt_knn] 136 | else: 137 | # ! 
rpe attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_tgt, d_model] 138 | assert self.d_rpe > 0 139 | # k,v: [n_batch, n_src, n_tgt, d_model] -> [n_batch, n_head, n_src, n_tgt_knn, d_head] 140 | k = k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 141 | v = v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 142 | # [n_batch, n_src, d_model] -> [n_batch, n_head, n_src, 1, d_head] 143 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).unsqueeze(3) 144 | 145 | # project rpe to rpe_q, rpe_k, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head] 146 | rpe = self.mlp_rpe(rpe) 147 | if self.apply_q_rpe: 148 | rpe_q, rpe_k, rpe_v = rpe.chunk(3, dim=-1) 149 | rpe_q = rpe_q.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 150 | else: 151 | rpe_k, rpe_v = rpe.chunk(2, dim=-1) 152 | rpe_k = rpe_k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 153 | rpe_v = rpe_v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1) 154 | 155 | # get attn: [n_batch, n_head, n_src, n_tgt] 156 | if self.apply_q_rpe: 157 | attn = torch.sum((q + rpe_q) * (k + rpe_k), dim=-1) 158 | # attn = torch.sum((q + rpe_q) * (k + rpe_k) - rpe_q * rpe_k, dim=-1) 159 | else: 160 | attn = torch.sum(q * (k + rpe_k), dim=-1) 161 | # q: [n_batch, n_head, n_src, 1, d_head] 162 | # k,v: [n_batch, n_head, n_src, n_tgt, d_head] 163 | # rpe_q, rpe_k, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head] 164 | 165 | if attn_invalid_mask is not None: 166 | # attn_invalid_mask: [n_batch, n_src, n_tgt], attn: [n_batch, n_head, n_src, n_tgt] 167 | attn = attn.masked_fill(attn_invalid_mask.unsqueeze(1), float("-inf")) 168 | 169 | attn = torch.softmax(attn / math.sqrt(self.d_head), dim=-1) 170 | if self.dropout is not None: 171 | attn = self.dropout(attn) 172 | 173 | # attn: [n_batch, n_head, n_src, n_tgt] 174 | if rpe is None: 175 | if v.dim() == 4: 176 | out = torch.matmul(attn, v) # v, [n_batch, n_head, n_tgt, d_head] 177 | else: 178 | out = torch.sum(v * attn.unsqueeze(-1), dim=3) # v: [n_batch, n_head, n_src, n_tgt, d_head] 179 | else: 180 | # v, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head] 181 | out = torch.sum((v + rpe_v) * attn.unsqueeze(-1), dim=3) 182 | 183 | # out: [n_batch, n_head, n_src, d_head] 184 | out = out.transpose(1, 2).flatten(2, 3) # [n_batch, n_src, d_model] 185 | out = F.linear(out, self.out_proj_weight, self.out_proj_bias) 186 | 187 | if mask_no_tgt_valid is not None: 188 | # mask_no_tgt_valid: [n_batch, n_src], out: [n_batch, n_src, d_model] 189 | out = out.masked_fill(mask_no_tgt_valid.unsqueeze(-1), 0) 190 | 191 | if need_weights: 192 | attn_weights = attn.mean(1) # [n_batch, n_src, n_tgt] 193 | if mask_no_tgt_valid is not None: 194 | attn_weights = attn_weights.masked_fill(mask_no_tgt_valid.unsqueeze(-1), 0) 195 | return out, attn_weights 196 | else: 197 | return out, None 198 | -------------------------------------------------------------------------------- /src/utils/submission.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import List 3 | from omegaconf import ListConfig 4 | from pathlib import Path 5 | import tarfile 6 | import os 7 | import pandas as pd 8 | from torch import Tensor 9 | from waymo_open_dataset.protos import motion_submission_pb2 10 | from pytorch_lightning.loggers import WandbLogger 11 | from .transform_utils import torch_pos2global, torch_rad2rot 
12 | 13 | # ! single GPU only 14 | 15 | 16 | class SubWOMD: 17 | def __init__( 18 | self, 19 | k_futures: int, 20 | wb_artifact: str, 21 | interactive_challenge: bool, 22 | activate: bool, 23 | method_name: str, 24 | authors: ListConfig[str], 25 | affiliation: str, 26 | description: str, 27 | method_link: str, 28 | account_name: str, 29 | ) -> None: 30 | self.activate = activate 31 | if activate: 32 | self.submissions = {} 33 | for _k in range(1, k_futures + 1): 34 | self.submissions[_k] = motion_submission_pb2.MotionChallengeSubmission() 35 | self.submissions[_k].account_name = account_name 36 | self.submissions[_k].unique_method_name = f"{method_name}_K{_k}" 37 | self.submissions[_k].authors.extend(list(authors)) 38 | self.submissions[_k].affiliation = affiliation 39 | self.submissions[_k].description = description 40 | self.submissions[_k].method_link = method_link 41 | if interactive_challenge: 42 | self.submissions[_k].submission_type = 2 43 | else: 44 | self.submissions[_k].submission_type = 1 45 | 46 | def add_to_submissions( 47 | self, 48 | waymo_trajs: Tensor, 49 | waymo_scores: Tensor, 50 | mask_pred: Tensor, 51 | object_id: Tensor, 52 | scenario_center: Tensor, 53 | scenario_yaw: Tensor, 54 | scenario_id: List[str], 55 | ) -> None: 56 | """ 57 | Args: 58 | waymo_trajs: [n_batch, step_start+1...step_end, n_agent, K, 2] 59 | waymo_scores: [n_batch, n_agent, K] 60 | mask_pred: [n_batch, n_agent] bool 61 | object_id: [n_batch, n_agent] int 62 | scenario_center: [n_batch, 2] 63 | scenario_yaw: [n_batch] 64 | scenario_id: list of str 65 | """ 66 | if not self.activate: 67 | return 68 | 69 | waymo_trajs = waymo_trajs[:, 4::5].permute(0, 2, 3, 1, 4) # [n_batch, n_agent, K, n_step, 2] 70 | waymo_trajs = torch_pos2global( 71 | waymo_trajs.flatten(1, 3), scenario_center.unsqueeze(1), torch_rad2rot(scenario_yaw) 72 | ).view(waymo_trajs.shape) 73 | 74 | waymo_trajs = waymo_trajs.cpu().numpy() 75 | waymo_scores = waymo_scores.cpu().numpy() 76 | mask_pred = mask_pred.cpu().numpy() 77 | object_id = object_id.cpu().numpy() 78 | 79 | for i_batch in range(waymo_trajs.shape[0]): 80 | agent_pos = waymo_trajs[i_batch, mask_pred[i_batch]] # [n_agent_pred, K, n_step, 2] 81 | agent_id = object_id[i_batch, mask_pred[i_batch]] # [n_agent_pred] 82 | agent_score = waymo_scores[i_batch, mask_pred[i_batch]] # [n_agent_pred, K] 83 | 84 | for n_K, submission in self.submissions.items(): 85 | scenario_prediction = motion_submission_pb2.ChallengeScenarioPredictions() 86 | scenario_prediction.scenario_id = scenario_id[i_batch] 87 | 88 | if submission.submission_type == 1: 89 | # single prediction 90 | for i_track in range(agent_pos.shape[0]): 91 | prediction = motion_submission_pb2.SingleObjectPrediction() 92 | prediction.object_id = agent_id[i_track] 93 | for _k in range(n_K): 94 | scored_trajectory = motion_submission_pb2.ScoredTrajectory() 95 | scored_trajectory.confidence = agent_score[i_track, _k] 96 | scored_trajectory.trajectory.center_x.extend(agent_pos[i_track, _k, :, 0]) 97 | scored_trajectory.trajectory.center_y.extend(agent_pos[i_track, _k, :, 1]) 98 | prediction.trajectories.append(scored_trajectory) 99 | scenario_prediction.single_predictions.predictions.append(prediction) 100 | else: 101 | # joint prediction 102 | for _k in range(n_K): 103 | scored_joint_trajectory = motion_submission_pb2.ScoredJointTrajectory() 104 | scored_joint_trajectory.confidence = agent_score[:, _k].sum(0) 105 | for i_track in range(agent_pos.shape[0]): 106 | object_trajectory = motion_submission_pb2.ObjectTrajectory() 107 | 
object_trajectory.object_id = agent_id[i_track] 108 | object_trajectory.trajectory.center_x.extend(agent_pos[i_track, _k, :, 0]) 109 | object_trajectory.trajectory.center_y.extend(agent_pos[i_track, _k, :, 1]) 110 | scored_joint_trajectory.trajectories.append(object_trajectory) 111 | scenario_prediction.joint_prediction.joint_trajectories.append(scored_joint_trajectory) 112 | 113 | submission.scenario_predictions.append(scenario_prediction) 114 | 115 | def save_sub_files(self, logger: WandbLogger) -> List[str]: 116 | if not self.activate: 117 | return [] 118 | 119 | print(f"saving womd submission files to {os.getcwd()}") 120 | file_paths = [] 121 | for k, submission in self.submissions.items(): 122 | submission_dir = Path(f"womd_K{k}") 123 | submission_dir.mkdir(exist_ok=True) 124 | f = open(submission_dir / f"womd_K{k}.bin", "wb") 125 | f.write(submission.SerializeToString()) 126 | f.close() 127 | tar_file_name = submission_dir.as_posix() + ".tar.gz" 128 | with tarfile.open(tar_file_name, "w:gz") as tar: 129 | tar.add(submission_dir, arcname=submission_dir.name) 130 | if isinstance(logger, WandbLogger): 131 | logger.experiment.save(tar_file_name) 132 | else: 133 | file_paths.append(tar_file_name) 134 | return file_paths 135 | 136 | 137 | class SubAV2: 138 | def __init__(self, k_futures: int, activate: bool) -> None: 139 | self._SUBMISSION_COL_NAMES = [ 140 | "scenario_id", 141 | "track_id", 142 | "probability", 143 | "predicted_trajectory_x", 144 | "predicted_trajectory_y", 145 | ] 146 | self.activate = activate 147 | if activate: 148 | self.submissions = {} 149 | for _k in range(1, k_futures + 1): 150 | self.submissions[_k] = [] 151 | 152 | def add_to_submissions( 153 | self, 154 | waymo_trajs: Tensor, 155 | waymo_scores: Tensor, 156 | mask_pred: Tensor, 157 | object_id: Tensor, 158 | scenario_center: Tensor, 159 | scenario_yaw: Tensor, 160 | scenario_id: List[str], 161 | ) -> None: 162 | """ 163 | Args: 164 | waymo_trajs: [n_batch, step_start+1...step_end, n_agent, K, 2] 165 | waymo_scores: [n_batch, n_agent, K] 166 | mask_pred: [n_batch, n_agent] bool 167 | object_id: [n_batch, n_agent] int 168 | scenario_center: [n_batch, 2] 169 | scenario_yaw: [n_batch] 170 | scenario_id: list of str 171 | """ 172 | if not self.activate: 173 | return 174 | 175 | waymo_trajs = waymo_trajs.permute(0, 2, 3, 1, 4) # [n_batch, n_agent, K, n_step, 2] 176 | waymo_trajs = torch_pos2global( 177 | waymo_trajs.flatten(1, 3), scenario_center.unsqueeze(1), torch_rad2rot(scenario_yaw) 178 | ).view(waymo_trajs.shape) 179 | 180 | waymo_trajs = waymo_trajs.cpu().numpy() 181 | waymo_scores = waymo_scores.cpu().numpy() 182 | mask_pred = mask_pred.cpu().numpy() 183 | object_id = object_id.cpu().numpy() 184 | 185 | for i_batch in range(waymo_trajs.shape[0]): 186 | agent_pos = waymo_trajs[i_batch, mask_pred[i_batch]] # [n_agent_pred, K, n_step, 2] 187 | agent_id = object_id[i_batch, mask_pred[i_batch]] # [n_agent_pred] 188 | agent_score = waymo_scores[i_batch, mask_pred[i_batch]] # [n_agent_pred, K] 189 | 190 | for i_agent in range(agent_id.shape[0]): 191 | for n_K in self.submissions.keys(): 192 | confidence = agent_score[i_agent, :n_K] / agent_score[i_agent, :n_K].sum() 193 | for _k in range(n_K): 194 | self.submissions[n_K].append( 195 | ( 196 | scenario_id[i_batch], 197 | str(agent_id[i_agent]), 198 | confidence[_k], 199 | agent_pos[i_agent, _k, :, 0], 200 | agent_pos[i_agent, _k, :, 1], 201 | ) 202 | ) 203 | 204 | def save_sub_files(self, logger: WandbLogger) -> List[str]: 205 | if not self.activate: 206 | return [] 
207 | 208 | print(f"saving av2 submission files to {os.getcwd()}") 209 | file_paths = [] 210 | for k, prediction_rows in self.submissions.items(): 211 | parquet_file_path = Path(f"av2_K{k}.parquet").as_posix() 212 | submission_df = pd.DataFrame(prediction_rows, columns=self._SUBMISSION_COL_NAMES) 213 | submission_df.to_parquet(parquet_file_path) 214 | if isinstance(logger, WandbLogger): 215 | logger.experiment.save(parquet_file_path) 216 | else: 217 | file_paths.append(parquet_file_path) 218 | return file_paths -------------------------------------------------------------------------------- /src/data_modules/data_h5_womd.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional, Dict, Any, Tuple 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import h5py 7 | 8 | 9 | class DatasetBase(Dataset[Dict[str, np.ndarray]]): 10 | def __init__(self, filepath: str, tensor_size: Dict[str, Tuple]) -> None: 11 | super().__init__() 12 | self.tensor_size = tensor_size 13 | self.filepath = filepath 14 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 15 | self.dataset_len = int(hf.attrs["data_len"]) 16 | 17 | def __len__(self) -> int: 18 | return self.dataset_len 19 | 20 | 21 | class DatasetTrain(DatasetBase): 22 | """ 23 | The waymo 9-sec trainging.h5 is repetitive, start at {0, 2, 4, 5, 6, 8, 10} seconds within the 20-sec episode. 24 | Always train with the whole training.h5 dataset. 25 | limit_train_batches just for controlling the validation frequency. 26 | """ 27 | 28 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: 29 | idx = np.random.randint(self.dataset_len) 30 | idx_key = str(idx) 31 | out_dict = {"episode_idx": idx} 32 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 33 | for k in self.tensor_size.keys(): 34 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k]) 35 | return out_dict 36 | 37 | 38 | class DatasetVal(DatasetBase): 39 | # for validation.h5 and testing.h5 40 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: 41 | idx_key = str(idx) 42 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 43 | out_dict = { 44 | "episode_idx": idx, 45 | "scenario_id": hf[idx_key].attrs["scenario_id"], 46 | "scenario_center": hf[idx_key].attrs["scenario_center"], 47 | "scenario_yaw": hf[idx_key].attrs["scenario_yaw"], 48 | "with_map": hf[idx_key].attrs["with_map"], # some epidosdes in the testing dataset do not have map. 49 | } 50 | for k, _size in self.tensor_size.items(): 51 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k]) 52 | if out_dict[k].shape != _size: 53 | assert "agent" in k 54 | out_dict[k] = np.ones(_size, dtype=out_dict[k].dtype) 55 | return out_dict 56 | 57 | 58 | class DataH5womd(LightningDataModule): 59 | def __init__( 60 | self, 61 | data_dir: str, 62 | filename_train: str = "training", 63 | filename_val: str = "validation", 64 | filename_test: str = "testing", 65 | batch_size: int = 3, 66 | num_workers: int = 4, 67 | n_agent: int = 64, # if not the same as h5 dataset, use dummy agents, for scalability tests. 
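# Dummy-agent behaviour: when n_agent does not match the shapes packed into the h5 file,
# DatasetVal keeps the episode but replaces the mismatching "agent" tensors with ones of the
# requested size, so the model always sees n_agent agent slots (see DatasetVal.__getitem__ above).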
68 | ) -> None: 69 | super().__init__() 70 | self.interactive_challenge = "interactive" in filename_val or "interactive" in filename_test 71 | 72 | self.path_train_h5 = f"{data_dir}/{filename_train}.h5" 73 | self.path_val_h5 = f"{data_dir}/{filename_val}.h5" 74 | self.path_test_h5 = f"{data_dir}/{filename_test}.h5" 75 | self.batch_size = batch_size 76 | self.num_workers = num_workers 77 | 78 | n_step = 91 79 | n_step_history = 11 80 | n_agent_no_sim = 256 81 | n_pl = 1024 82 | n_tl = 100 83 | n_tl_stop = 40 84 | n_pl_node = 20 85 | self.tensor_size_train = { 86 | # agent states 87 | "agent/valid": (n_step, n_agent), # bool, 88 | "agent/pos": (n_step, n_agent, 2), # float32 89 | # v[1] = p[1]-p[0]. if p[1] invalid, v[1] also invalid, v[2]=v[3] 90 | "agent/vel": (n_step, n_agent, 2), # float32, v_x, v_y 91 | "agent/spd": (n_step, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy 92 | "agent/acc": (n_step, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt 93 | "agent/yaw_bbox": (n_step, n_agent, 1), # float32, yaw of the bbox heading 94 | "agent/yaw_rate": (n_step, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt 95 | # agent attributes 96 | "agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 97 | "agent/cmd": (n_agent, 8), # bool one_hot 98 | "agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2] 99 | "agent/size": (n_agent, 3), # float32: [length, width, height] 100 | "agent/goal": (n_agent, 4), # float32: [x, y, theta, v] 101 | "agent/dest": (n_agent,), # int64: index to map n_pl 102 | # map polylines 103 | "map/valid": (n_pl, n_pl_node), # bool 104 | "map/type": (n_pl, 11), # bool one_hot 105 | "map/pos": (n_pl, n_pl_node, 2), # float32 106 | "map/dir": (n_pl, n_pl_node, 2), # float32 107 | "map/boundary": (4,), # xmin, xmax, ymin, ymax 108 | # traffic lights 109 | "tl_lane/valid": (n_step, n_tl), # bool 110 | "tl_lane/state": (n_step, n_tl, 5), # bool one_hot 111 | "tl_lane/idx": (n_step, n_tl), # int, -1 means not valid 112 | "tl_stop/valid": (n_step, n_tl_stop), # bool 113 | "tl_stop/state": (n_step, n_tl_stop, 5), # bool one_hot 114 | "tl_stop/pos": (n_step, n_tl_stop, 2), # x,y 115 | "tl_stop/dir": (n_step, n_tl_stop, 2), # x,y 116 | } 117 | 118 | self.tensor_size_test = { 119 | # object_id for waymo metrics 120 | "history/agent/object_id": (n_agent,), 121 | "history/agent_no_sim/object_id": (n_agent_no_sim,), 122 | # agent_sim 123 | "history/agent/valid": (n_step_history, n_agent), # bool, 124 | "history/agent/pos": (n_step_history, n_agent, 2), # float32 125 | "history/agent/vel": (n_step_history, n_agent, 2), # float32, v_x, v_y 126 | "history/agent/spd": (n_step_history, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy 127 | "history/agent/acc": (n_step_history, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt 128 | "history/agent/yaw_bbox": (n_step_history, n_agent, 1), # float32, yaw of the bbox heading 129 | "history/agent/yaw_rate": (n_step_history, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt 130 | "history/agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 131 | "history/agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2] 132 | "history/agent/size": (n_agent, 3), # float32: [length, width, height] 133 | # agent_no_sim not used by the models currently 134 | "history/agent_no_sim/valid": (n_step_history, n_agent_no_sim), 135 | "history/agent_no_sim/pos": (n_step_history, n_agent_no_sim, 2), 136 | "history/agent_no_sim/vel": (n_step_history, 
n_agent_no_sim, 2), 137 | "history/agent_no_sim/spd": (n_step_history, n_agent_no_sim, 1), 138 | "history/agent_no_sim/yaw_bbox": (n_step_history, n_agent_no_sim, 1), 139 | "history/agent_no_sim/type": (n_agent_no_sim, 3), 140 | "history/agent_no_sim/size": (n_agent_no_sim, 3), 141 | # map 142 | "map/valid": (n_pl, n_pl_node), # bool 143 | "map/type": (n_pl, 11), # bool one_hot 144 | "map/pos": (n_pl, n_pl_node, 2), # float32 145 | "map/dir": (n_pl, n_pl_node, 2), # float32 146 | "map/boundary": (4,), # xmin, xmax, ymin, ymax 147 | # traffic_light 148 | "history/tl_lane/valid": (n_step_history, n_tl), # bool 149 | "history/tl_lane/state": (n_step_history, n_tl, 5), # bool one_hot 150 | "history/tl_lane/idx": (n_step_history, n_tl), # int, -1 means not valid 151 | "history/tl_stop/valid": (n_step_history, n_tl_stop), # bool 152 | "history/tl_stop/state": (n_step_history, n_tl_stop, 5), # bool one_hot 153 | "history/tl_stop/pos": (n_step_history, n_tl_stop, 2), # x,y 154 | "history/tl_stop/dir": (n_step_history, n_tl_stop, 2), # dx,dy 155 | } 156 | 157 | self.tensor_size_val = { 158 | "agent/object_id": (n_agent,), 159 | "agent_no_sim/object_id": (n_agent_no_sim,), 160 | # agent_no_sim 161 | "agent_no_sim/valid": (n_step, n_agent_no_sim), # bool, 162 | "agent_no_sim/pos": (n_step, n_agent_no_sim, 2), # float32 163 | "agent_no_sim/vel": (n_step, n_agent_no_sim, 2), # float32, v_x, v_y 164 | "agent_no_sim/spd": (n_step, n_agent_no_sim, 1), # norm of vel, signed using yaw_bbox and vel_xy 165 | "agent_no_sim/yaw_bbox": (n_step, n_agent_no_sim, 1), # float32, yaw of the bbox heading 166 | "agent_no_sim/type": (n_agent_no_sim, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 167 | "agent_no_sim/size": (n_agent_no_sim, 3), # float32: [length, width, height] 168 | } 169 | 170 | self.tensor_size_val = self.tensor_size_val | self.tensor_size_train | self.tensor_size_test 171 | 172 | def setup(self, stage: Optional[str] = None) -> None: 173 | if stage == "fit" or stage is None: 174 | self.train_dataset = DatasetTrain(self.path_train_h5, self.tensor_size_train) 175 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val) 176 | elif stage == "validate": 177 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val) 178 | elif stage == "test": 179 | self.test_dataset = DatasetVal(self.path_test_h5, self.tensor_size_test) 180 | 181 | def train_dataloader(self) -> DataLoader[Any]: 182 | return self._get_dataloader(self.train_dataset, self.batch_size, self.num_workers) 183 | 184 | def val_dataloader(self) -> DataLoader[Any]: 185 | return self._get_dataloader(self.val_dataset, self.batch_size, self.num_workers) 186 | 187 | def test_dataloader(self) -> DataLoader[Any]: 188 | return self._get_dataloader(self.test_dataset, self.batch_size, self.num_workers) 189 | 190 | @staticmethod 191 | def _get_dataloader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader[Any]: 192 | return DataLoader( 193 | ds, 194 | batch_size=batch_size, 195 | num_workers=num_workers, 196 | pin_memory=True, 197 | shuffle=False, 198 | drop_last=False, 199 | persistent_workers=True, 200 | ) 201 | -------------------------------------------------------------------------------- /src/models/modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional, Tuple 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 
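# Usage sketch (illustrative, assuming d_model=128 and a padded batch x of shape [n_batch, n_src, 128]):
#   block = TransformerBlock(d_model=128, n_head=2, d_feedforward=256, n_layer=2)
#   out, _ = block(src=x, src_padding_mask=pad_mask)                # self-attention: tgt stays None
#   out, _ = block(src=q, tgt=kv, tgt_padding_mask=kv_pad_mask)     # cross-attention
# Relative positional encodings are passed through the rpe argument
# ([n_batch, n_src, n_tgt, d_rpe]) when the block is constructed with a positive d_rpe.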
5 | from .attention import AttentionRPE 6 | 7 | 8 | def _get_activation_fn(activation): 9 | if activation == "relu": 10 | return F.relu 11 | elif activation == "gelu": 12 | return F.gelu 13 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 14 | 15 | 16 | class TransformerBlock(nn.Module): 17 | __constants__ = ["norm"] 18 | 19 | def __init__( 20 | self, 21 | d_model: int, 22 | n_head: int = 2, 23 | d_feedforward: int = 256, 24 | dropout_p: float = 0.1, 25 | activation: str = "relu", 26 | n_layer: int = 1, 27 | norm_first: bool = True, 28 | decoder_self_attn: bool = False, 29 | bias: bool = True, 30 | d_rpe: int = -1, 31 | apply_q_rpe: bool = False, 32 | ) -> None: 33 | super(TransformerBlock, self).__init__() 34 | self.layers = nn.ModuleList( 35 | [ 36 | TransformerCrossAttention( 37 | d_model=d_model, 38 | n_head=n_head, 39 | d_feedforward=d_feedforward, 40 | dropout_p=dropout_p, 41 | activation=activation, 42 | norm_first=norm_first, 43 | decoder_self_attn=decoder_self_attn, 44 | bias=bias, 45 | d_rpe=d_rpe, 46 | apply_q_rpe=apply_q_rpe, 47 | ) 48 | for _ in range(n_layer) 49 | ] 50 | ) 51 | 52 | # self.layers = _get_clones(encoder_layer, n_layer) 53 | # self.n_layer = n_layer 54 | # self.norm = nn.LayerNorm(d_model) if norm_first else None 55 | 56 | def forward( 57 | self, 58 | src: Tensor, 59 | src_padding_mask: Optional[Tensor] = None, 60 | tgt: Optional[Tensor] = None, 61 | tgt_padding_mask: Optional[Tensor] = None, 62 | rpe: Optional[Tensor] = None, 63 | decoder_tgt: Optional[Tensor] = None, 64 | decoder_tgt_padding_mask: Optional[Tensor] = None, 65 | decoder_rpe: Optional[Tensor] = None, 66 | attn_mask: Optional[Tensor] = None, 67 | need_weights: bool = False, 68 | ) -> Tuple[Tensor, Optional[Tensor]]: 69 | """ 70 | Args: 71 | src: [n_batch, n_src, d_model] 72 | src_padding_mask: [n_batch, n_src], bool, if True, src is invalid. 73 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe. 74 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe. 75 | rpe: [n_batch, n_src, n_tgt, d_rpe] 76 | decoder_tgt: [n_batch, (n_src), n_tgt_decoder, d_model], (n_src) if using rpe. 77 | decoder_tgt_padding_mask: [n_batch, (n_src), n_tgt_decoder], (n_src) if using rpe. 78 | decoder_rpe: [n_batch, n_src, n_tgt_decoder, d_rpe] 79 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt. 80 | 81 | Returns: 82 | src: [n_batch, n_src, d_model] 83 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None 84 | 85 | Remarks: 86 | absoulte_pe should be already added to src/tgt. 
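Example: for plain self-attention over a padded batch it is enough to call
forward(src, src_padding_mask=mask); tgt defaults to None and each layer then reuses
src_padding_mask as the target padding mask.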
87 | """ 88 | attn_weights = None 89 | for mod in self.layers: 90 | src, attn_weights = mod( 91 | src=src, 92 | src_padding_mask=src_padding_mask, 93 | tgt=tgt, 94 | tgt_padding_mask=tgt_padding_mask, 95 | rpe=rpe, 96 | decoder_tgt=decoder_tgt, 97 | decoder_tgt_padding_mask=decoder_tgt_padding_mask, 98 | decoder_rpe=decoder_rpe, 99 | attn_mask=attn_mask, 100 | need_weights=need_weights, 101 | ) 102 | # if self.norm is not None: 103 | # src = self.norm(src) 104 | return src, attn_weights 105 | 106 | 107 | class TransformerCrossAttention(nn.Module): 108 | def __init__( 109 | self, 110 | d_model: int, 111 | n_head: int, 112 | d_feedforward: int, 113 | dropout_p: float, 114 | activation: str, 115 | norm_first: bool, 116 | decoder_self_attn: bool, 117 | bias: bool, 118 | d_rpe: int = -1, 119 | apply_q_rpe: bool = False, 120 | ) -> None: 121 | super(TransformerCrossAttention, self).__init__() 122 | self.norm_first = norm_first 123 | self.d_feedforward = d_feedforward 124 | self.decoder_self_attn = decoder_self_attn 125 | inplace = False 126 | 127 | self.dropout = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None 128 | self.activation = _get_activation_fn(activation) 129 | self.norm1 = nn.LayerNorm(d_model) 130 | 131 | if self.decoder_self_attn: 132 | self.attn_src = AttentionRPE( 133 | d_model=d_model, n_head=n_head, dropout_p=dropout_p, bias=bias, d_rpe=d_rpe, apply_q_rpe=apply_q_rpe 134 | ) 135 | self.norm_src = nn.LayerNorm(d_model) 136 | self.dropout_src = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None 137 | 138 | if self.norm_first: 139 | self.norm_tgt = nn.LayerNorm(d_model) 140 | 141 | self.attn = AttentionRPE( 142 | d_model=d_model, n_head=n_head, dropout_p=dropout_p, bias=bias, d_rpe=d_rpe, apply_q_rpe=apply_q_rpe 143 | ) 144 | if self.d_feedforward > 0: 145 | self.linear1 = nn.Linear(d_model, d_feedforward) 146 | self.linear2 = nn.Linear(d_feedforward, d_model) 147 | self.norm2 = nn.LayerNorm(d_model) 148 | self.dropout1 = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None 149 | self.dropout2 = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None 150 | 151 | def forward( 152 | self, 153 | src: Tensor, 154 | src_padding_mask: Optional[Tensor] = None, 155 | tgt: Optional[Tensor] = None, 156 | tgt_padding_mask: Optional[Tensor] = None, 157 | rpe: Optional[Tensor] = None, 158 | decoder_tgt: Optional[Tensor] = None, 159 | decoder_tgt_padding_mask: Optional[Tensor] = None, 160 | decoder_rpe: Optional[Tensor] = None, 161 | attn_mask: Optional[Tensor] = None, 162 | need_weights: bool = False, 163 | ) -> Tuple[Tensor, Optional[Tensor]]: 164 | """ 165 | Args: 166 | src: [n_batch, n_src, d_model] 167 | src_padding_mask: [n_batch, n_src], bool, if True, src is invalid. 168 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe. 169 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe. 170 | rpe: [n_batch, n_src, n_tgt, d_rpe] 171 | decoder_tgt: [n_batch, n_src, n_tgt_decoder, d_model], when use decoder_rpe 172 | decoder_tgt_padding_mask: [n_batch, n_src, n_tgt_decoder], when use decoder_rpe 173 | decoder_rpe: [n_batch, n_src, n_tgt_decoder, d_rpe] 174 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt. 175 | 176 | Returns: 177 | out: [n_batch, n_src, d_model] 178 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None 179 | 180 | Remarks: 181 | absoulte_pe should be already added to src/tgt. 
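With norm_first=True the layer is pre-LN: LayerNorm is applied before the attention and the
feed-forward part, and the residuals are added to the un-normalized stream. With
norm_first=False it is post-LN. If d_feedforward <= 0 the feed-forward MLP is skipped and the
attention output only goes through the activation before the residual connection
(DenseTNT/VectorNet style).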
182 | """ 183 | if self.decoder_self_attn: 184 | # transformer decoder 185 | if self.norm_first: 186 | _s = self.norm_src(src) 187 | if decoder_tgt is None: 188 | _s = self.attn_src(_s, tgt_padding_mask=src_padding_mask)[0] 189 | else: 190 | decoder_tgt = self.norm_src(decoder_tgt) 191 | _s = self.attn_src(_s, decoder_tgt, tgt_padding_mask=decoder_tgt_padding_mask, rpe=decoder_rpe)[0] 192 | 193 | if self.dropout_src is None: 194 | src = src + _s 195 | else: 196 | src = src + self.dropout_src(_s) 197 | else: 198 | if decoder_tgt is None: 199 | _s = self.attn_src(src, tgt_padding_mask=src_padding_mask)[0] 200 | else: 201 | _s = self.attn_src(src, decoder_tgt, tgt_padding_mask=decoder_tgt_padding_mask, rpe=decoder_rpe)[0] 202 | 203 | if self.dropout_src is None: 204 | src = self.norm_src(src + _s) 205 | else: 206 | src = self.norm_src(src + self.dropout_src(_s)) 207 | 208 | if tgt is None: 209 | tgt_padding_mask = src_padding_mask 210 | 211 | if self.norm_first: 212 | src2 = self.norm1(src) 213 | if tgt is not None: 214 | tgt = self.norm_tgt(tgt) 215 | else: 216 | src2 = src 217 | 218 | # [n_batch, n_src, d_model] 219 | src2, attn_weights = self.attn( 220 | src=src2, 221 | tgt=tgt, 222 | tgt_padding_mask=tgt_padding_mask, 223 | attn_mask=attn_mask, 224 | rpe=rpe, 225 | need_weights=need_weights, 226 | ) 227 | 228 | if self.d_feedforward > 0: 229 | if self.dropout1 is None: 230 | src = src + src2 231 | else: 232 | src = src + self.dropout1(src2) 233 | 234 | if self.norm_first: 235 | src2 = self.norm2(src) 236 | else: 237 | src = self.norm1(src) 238 | src2 = src 239 | 240 | src2 = self.activation(self.linear1(src2)) 241 | if self.dropout is None: 242 | src2 = self.linear2(src2) 243 | else: 244 | src2 = self.linear2(self.dropout(src2)) 245 | 246 | if self.dropout2 is None: 247 | src = src + src2 248 | else: 249 | src = src + self.dropout2(src2) 250 | 251 | if not self.norm_first: 252 | src = self.norm2(src) 253 | else: 254 | # densetnt vectornet 255 | src2 = self.activation(src2) 256 | if self.dropout is None: 257 | src = src + src2 258 | else: 259 | src = src + self.dropout(src2) 260 | if not self.norm_first: 261 | src = self.norm1(src) 262 | 263 | if src_padding_mask is not None: 264 | src.masked_fill_(src_padding_mask.unsqueeze(-1), 0.0) 265 | if need_weights: 266 | attn_weights.masked_fill_(src_padding_mask.unsqueeze(-1), 0.0) 267 | return src, attn_weights 268 | -------------------------------------------------------------------------------- /src/data_modules/data_h5_av2.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Optional, Dict, Any, Tuple 3 | from pytorch_lightning import LightningDataModule 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import h5py 7 | 8 | 9 | class DatasetBase(Dataset[Dict[str, np.ndarray]]): 10 | def __init__(self, filepath: str, tensor_size: Dict[str, Tuple]) -> None: 11 | super().__init__() 12 | self.tensor_size = tensor_size 13 | self.filepath = filepath 14 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 15 | self.dataset_len = int(hf.attrs["data_len"]) 16 | 17 | def __len__(self) -> int: 18 | return self.dataset_len 19 | 20 | 21 | class DatasetTrain(DatasetBase): 22 | """ 23 | Always train with the whole training.h5 dataset. 24 | limit_train_batches just for controlling the validation frequency. 
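Note that __getitem__ ignores the incoming index and draws a uniformly random episode instead,
so an "epoch" is governed by limit_train_batches rather than by a full pass over the file.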
25 | """ 26 | 27 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: 28 | idx = np.random.randint(self.dataset_len) 29 | idx_key = str(idx) 30 | out_dict = {"episode_idx": idx} 31 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 32 | for k, _size in self.tensor_size.items(): 33 | if k in hf[idx_key]: 34 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k]) 35 | else: 36 | if "/valid" in k or "/state" in k: 37 | _dtype = np.bool 38 | elif "/idx" in k: 39 | _dtype = np.int64 40 | else: 41 | _dtype = np.float32 42 | out_dict[k] = np.zeros(_size, dtype=_dtype) 43 | return out_dict 44 | 45 | 46 | class DatasetVal(DatasetBase): 47 | # for validation.h5 and testing.h5 48 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: 49 | idx_key = str(idx) 50 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf: 51 | out_dict = { 52 | "episode_idx": idx, 53 | "scenario_id": hf[idx_key].attrs["scenario_id"], 54 | "scenario_center": hf[idx_key].attrs["scenario_center"], 55 | "scenario_yaw": hf[idx_key].attrs["scenario_yaw"], 56 | "with_map": hf[idx_key].attrs["with_map"], # some epidosdes in the testing dataset do not have map. 57 | } 58 | for k, _size in self.tensor_size.items(): 59 | if k in hf[idx_key]: 60 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k]) 61 | else: 62 | if "/valid" in k or "/state" in k: 63 | _dtype = np.bool 64 | elif "/idx" in k: 65 | _dtype = np.int64 66 | else: 67 | _dtype = np.float32 68 | out_dict[k] = np.zeros(_size, dtype=_dtype) 69 | if out_dict[k].shape != _size: 70 | assert "agent" in k 71 | out_dict[k] = np.ones(_size, dtype=out_dict[k].dtype) 72 | return out_dict 73 | 74 | 75 | class DataH5av2(LightningDataModule): 76 | def __init__( 77 | self, 78 | data_dir: str, 79 | filename_train: str = "training", 80 | filename_val: str = "validation", 81 | filename_test: str = "testing", 82 | batch_size: int = 3, 83 | num_workers: int = 4, 84 | n_agent: int = 64, # if not the same as h5 dataset, use dummy agents, for scalability tests. 85 | ) -> None: 86 | super().__init__() 87 | self.interactive_challenge = False 88 | 89 | self.path_train_h5 = f"{data_dir}/{filename_train}.h5" 90 | self.path_val_h5 = f"{data_dir}/{filename_val}.h5" 91 | self.path_test_h5 = f"{data_dir}/{filename_test}.h5" 92 | self.batch_size = batch_size 93 | self.num_workers = num_workers 94 | 95 | n_step = 110 96 | n_step_history = 50 97 | n_agent_no_sim = 256 98 | n_pl = 1024 99 | n_pl_node = 20 100 | 101 | n_tl = 1 102 | n_tl_stop = 1 103 | self.tensor_size_train = { 104 | # agent states 105 | "agent/valid": (n_step, n_agent), # bool, 106 | "agent/pos": (n_step, n_agent, 2), # float32 107 | # v[1] = p[1]-p[0]. 
if p[1] invalid, v[1] also invalid, v[2]=v[3] 108 | "agent/vel": (n_step, n_agent, 2), # float32, v_x, v_y 109 | "agent/spd": (n_step, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy 110 | "agent/acc": (n_step, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt 111 | "agent/yaw_bbox": (n_step, n_agent, 1), # float32, yaw of the bbox heading 112 | "agent/yaw_rate": (n_step, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt 113 | # agent attributes 114 | "agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 115 | "agent/cmd": (n_agent, 8), # bool one_hot 116 | "agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2] 117 | "agent/size": (n_agent, 3), # float32: [length, width, height] 118 | "agent/goal": (n_agent, 4), # float32: [x, y, theta, v] 119 | "agent/dest": (n_agent,), # int64: index to map n_pl 120 | # map polylines 121 | "map/valid": (n_pl, n_pl_node), # bool 122 | "map/type": (n_pl, 11), # bool one_hot 123 | "map/pos": (n_pl, n_pl_node, 2), # float32 124 | "map/dir": (n_pl, n_pl_node, 2), # float32 125 | "map/boundary": (4,), # xmin, xmax, ymin, ymax 126 | # dummy traffic lights 127 | "tl_lane/valid": (n_step, n_tl), # bool 128 | "tl_lane/state": (n_step, n_tl, 5), # bool one_hot 129 | "tl_lane/idx": (n_step, n_tl), # int, -1 means not valid 130 | "tl_stop/valid": (n_step, n_tl_stop), # bool 131 | "tl_stop/state": (n_step, n_tl_stop, 5), # bool one_hot 132 | "tl_stop/pos": (n_step, n_tl_stop, 2), # x,y 133 | "tl_stop/dir": (n_step, n_tl_stop, 2), # x,y 134 | } 135 | 136 | self.tensor_size_test = { 137 | # object_id for waymo metrics 138 | "history/agent/object_id": (n_agent,), 139 | "history/agent_no_sim/object_id": (n_agent_no_sim,), 140 | # agent_sim 141 | "history/agent/valid": (n_step_history, n_agent), # bool, 142 | "history/agent/pos": (n_step_history, n_agent, 2), # float32 143 | "history/agent/vel": (n_step_history, n_agent, 2), # float32, v_x, v_y 144 | "history/agent/spd": (n_step_history, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy 145 | "history/agent/acc": (n_step_history, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt 146 | "history/agent/yaw_bbox": (n_step_history, n_agent, 1), # float32, yaw of the bbox heading 147 | "history/agent/yaw_rate": (n_step_history, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt 148 | "history/agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 149 | "history/agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2] 150 | "history/agent/size": (n_agent, 3), # float32: [length, width, height] 151 | # agent_no_sim not used by the models currently 152 | "history/agent_no_sim/valid": (n_step_history, n_agent_no_sim), 153 | "history/agent_no_sim/pos": (n_step_history, n_agent_no_sim, 2), 154 | "history/agent_no_sim/vel": (n_step_history, n_agent_no_sim, 2), 155 | "history/agent_no_sim/spd": (n_step_history, n_agent_no_sim, 1), 156 | "history/agent_no_sim/yaw_bbox": (n_step_history, n_agent_no_sim, 1), 157 | "history/agent_no_sim/type": (n_agent_no_sim, 3), 158 | "history/agent_no_sim/size": (n_agent_no_sim, 3), 159 | # map 160 | "map/valid": (n_pl, n_pl_node), # bool 161 | "map/type": (n_pl, 11), # bool one_hot 162 | "map/pos": (n_pl, n_pl_node, 2), # float32 163 | "map/dir": (n_pl, n_pl_node, 2), # float32 164 | "map/boundary": (4,), # xmin, xmax, ymin, ymax 165 | # dummy traffic_light 166 | "history/tl_lane/valid": (n_step_history, n_tl), # bool 167 | "history/tl_lane/state": (n_step_history, n_tl, 5), # bool one_hot 168 | 
"history/tl_lane/idx": (n_step_history, n_tl), # int, -1 means not valid 169 | "history/tl_stop/valid": (n_step_history, n_tl_stop), # bool 170 | "history/tl_stop/state": (n_step_history, n_tl_stop, 5), # bool one_hot 171 | "history/tl_stop/pos": (n_step_history, n_tl_stop, 2), # x,y 172 | "history/tl_stop/dir": (n_step_history, n_tl_stop, 2), # dx,dy 173 | } 174 | 175 | self.tensor_size_val = { 176 | "agent/object_id": (n_agent,), 177 | "agent_no_sim/object_id": (n_agent_no_sim,), 178 | # agent_no_sim 179 | "agent_no_sim/valid": (n_step, n_agent_no_sim), # bool, 180 | "agent_no_sim/pos": (n_step, n_agent_no_sim, 2), # float32 181 | "agent_no_sim/vel": (n_step, n_agent_no_sim, 2), # float32, v_x, v_y 182 | "agent_no_sim/spd": (n_step, n_agent_no_sim, 1), # norm of vel, signed using yaw_bbox and vel_xy 183 | "agent_no_sim/yaw_bbox": (n_step, n_agent_no_sim, 1), # float32, yaw of the bbox heading 184 | "agent_no_sim/type": (n_agent_no_sim, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2] 185 | "agent_no_sim/size": (n_agent_no_sim, 3), # float32: [length, width, height] 186 | } 187 | 188 | self.tensor_size_val = self.tensor_size_val | self.tensor_size_train | self.tensor_size_test 189 | 190 | def setup(self, stage: Optional[str] = None) -> None: 191 | if stage == "fit" or stage is None: 192 | self.train_dataset = DatasetTrain(self.path_train_h5, self.tensor_size_train) 193 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val) 194 | elif stage == "validate": 195 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val) 196 | elif stage == "test": 197 | self.test_dataset = DatasetVal(self.path_test_h5, self.tensor_size_test) 198 | 199 | def train_dataloader(self) -> DataLoader[Any]: 200 | return self._get_dataloader(self.train_dataset, self.batch_size, self.num_workers) 201 | 202 | def val_dataloader(self) -> DataLoader[Any]: 203 | return self._get_dataloader(self.val_dataset, self.batch_size, self.num_workers) 204 | 205 | def test_dataloader(self) -> DataLoader[Any]: 206 | return self._get_dataloader(self.test_dataset, self.batch_size, self.num_workers) 207 | 208 | @staticmethod 209 | def _get_dataloader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader[Any]: 210 | return DataLoader( 211 | ds, 212 | batch_size=batch_size, 213 | num_workers=num_workers, 214 | pin_memory=True, 215 | shuffle=False, 216 | drop_last=False, 217 | persistent_workers=True, 218 | ) 219 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: hptr 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - absl-py=1.4.0=pyhd8ed1ab_0 11 | - aiohttp=3.7.0=py39h07f9747_0 12 | - antlr-python-runtime=4.8=pyhd8ed1ab_3 13 | - anyio=3.5.0=py39h06a4308_0 14 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 15 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0 16 | - asttokens=2.0.5=pyhd3eb1b0_0 17 | - async-timeout=3.0.1=py_1000 18 | - attrs=22.1.0=pyh71513ae_1 19 | - babel=2.9.1=pyhd3eb1b0_0 20 | - backcall=0.2.0=pyhd3eb1b0_0 21 | - beautifulsoup4=4.11.1=py39h06a4308_0 22 | - blas=1.0=mkl 23 | - bleach=4.1.0=pyhd3eb1b0_0 24 | - blinker=1.5=pyhd8ed1ab_0 25 | - brotli=1.0.9=h5eee18b_7 26 | - brotli-bin=1.0.9=h5eee18b_7 27 | - brotlipy=0.7.0=py39h27cfd23_1003 28 | - bzip2=1.0.8=h7b6447c_0 29 | - c-ares=1.18.1=h7f98852_0 30 | - ca-certificates=2022.9.24=ha878542_0 31 | - 
cachetools=5.2.0=pyhd8ed1ab_0 32 | - certifi=2022.9.24=pyhd8ed1ab_0 33 | - cffi=1.15.1=py39h5eee18b_2 34 | - chardet=3.0.4=py39h079e4ff_1008 35 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 36 | - colorama=0.4.6=pyhd8ed1ab_0 37 | - cryptography=38.0.1=py39h9ce1e76_0 38 | - cuda=11.7.1=0 39 | - cuda-cccl=11.7.91=0 40 | - cuda-command-line-tools=11.7.1=0 41 | - cuda-compiler=11.7.1=0 42 | - cuda-cudart=11.7.99=0 43 | - cuda-cudart-dev=11.7.99=0 44 | - cuda-cuobjdump=11.7.91=0 45 | - cuda-cupti=11.7.101=0 46 | - cuda-cuxxfilt=11.7.91=0 47 | - cuda-demo-suite=11.8.86=0 48 | - cuda-documentation=11.8.86=0 49 | - cuda-driver-dev=11.7.99=0 50 | - cuda-gdb=11.8.86=0 51 | - cuda-libraries=11.7.1=0 52 | - cuda-libraries-dev=11.7.1=0 53 | - cuda-memcheck=11.8.86=0 54 | - cuda-nsight=11.8.86=0 55 | - cuda-nsight-compute=11.8.0=0 56 | - cuda-nvcc=11.7.99=0 57 | - cuda-nvdisasm=11.8.86=0 58 | - cuda-nvml-dev=11.7.91=0 59 | - cuda-nvprof=11.8.87=0 60 | - cuda-nvprune=11.7.91=0 61 | - cuda-nvrtc=11.7.99=0 62 | - cuda-nvrtc-dev=11.7.99=0 63 | - cuda-nvtx=11.7.91=0 64 | - cuda-nvvp=11.8.87=0 65 | - cuda-runtime=11.7.1=0 66 | - cuda-sanitizer-api=11.8.86=0 67 | - cuda-toolkit=11.7.1=0 68 | - cuda-tools=11.7.1=0 69 | - cuda-visual-tools=11.7.1=0 70 | - cudatoolkit=11.3.1=h2bc3f7f_2 71 | - cudnn=8.2.1=cuda11.3_0 72 | - cycler=0.11.0=pyhd3eb1b0_0 73 | - dbus=1.13.18=hb2f20db_0 74 | - debugpy=1.5.1=py39h295c915_0 75 | - decorator=5.1.1=pyhd3eb1b0_0 76 | - defusedxml=0.7.1=pyhd3eb1b0_0 77 | - docker-pycreds=0.4.0=py_0 78 | - entrypoints=0.4=py39h06a4308_0 79 | - executing=0.8.3=pyhd3eb1b0_0 80 | - expat=2.4.9=h6a678d5_0 81 | - ffmpeg=4.3.2=hca11adc_0 82 | - fftw=3.3.9=h27cfd23_1 83 | - flit-core=3.6.0=pyhd3eb1b0_0 84 | - fontconfig=2.14.1=hef1e5e3_0 85 | - fonttools=4.25.0=pyhd3eb1b0_0 86 | - freetype=2.12.1=h4a9f257_0 87 | - fsspec=2022.11.0=pyhd8ed1ab_0 88 | - future=0.18.2=pyhd8ed1ab_6 89 | - gds-tools=1.4.0.31=0 90 | - giflib=5.2.1=h7b6447c_0 91 | - gitdb=4.0.10=pyhd8ed1ab_0 92 | - gitpython=3.1.29=pyhd8ed1ab_0 93 | - glib=2.69.1=he621ea3_2 94 | - gmp=6.2.1=h295c915_3 95 | - gnutls=3.6.15=he1e5248_0 96 | - google-auth=2.15.0=pyh1a96a4e_0 97 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 98 | - grpcio=1.38.1=py39hff7568b_0 99 | - gst-plugins-base=1.14.0=h8213a91_2 100 | - gstreamer=1.14.0=h28cd5cc_2 101 | - hdf5=1.10.6=h3ffc7dd_1 102 | - hydra-core=1.1.1=pyhd8ed1ab_0 103 | - icu=58.2=he6710b0_3 104 | - idna=3.4=py39h06a4308_0 105 | - importlib-metadata=5.1.0=pyha770c72_0 106 | - importlib_resources=5.10.1=pyhd8ed1ab_0 107 | - intel-openmp=2021.4.0=h06a4308_3561 108 | - ipykernel=6.15.2=py39h06a4308_0 109 | - ipython=8.6.0=py39h06a4308_0 110 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 111 | - ipywidgets=7.6.5=pyhd3eb1b0_1 112 | - jedi=0.18.1=py39h06a4308_1 113 | - jinja2=3.1.2=py39h06a4308_0 114 | - joblib=1.1.1=py39h06a4308_0 115 | - jpeg=9e=h7f8727e_0 116 | - json5=0.9.6=pyhd3eb1b0_0 117 | - jsonschema=4.16.0=py39h06a4308_0 118 | - jupyter=1.0.0=py39h06a4308_8 119 | - jupyter_client=7.4.7=py39h06a4308_0 120 | - jupyter_console=6.4.3=pyhd3eb1b0_0 121 | - jupyter_core=4.11.2=py39h06a4308_0 122 | - jupyter_server=1.18.1=py39h06a4308_0 123 | - jupyterlab=3.5.0=py39h06a4308_0 124 | - jupyterlab_pygments=0.1.2=py_0 125 | - jupyterlab_server=2.16.3=py39h06a4308_0 126 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 127 | - kiwisolver=1.4.2=py39h295c915_0 128 | - krb5=1.19.2=hac12032_0 129 | - lame=3.100=h7b6447c_0 130 | - lcms2=2.12=h3be6417_0 131 | - ld_impl_linux-64=2.38=h1181459_1 132 | - lerc=3.0=h295c915_0 133 | - 
libblas=3.9.0=12_linux64_mkl 134 | - libbrotlicommon=1.0.9=h5eee18b_7 135 | - libbrotlidec=1.0.9=h5eee18b_7 136 | - libbrotlienc=1.0.9=h5eee18b_7 137 | - libcblas=3.9.0=12_linux64_mkl 138 | - libclang=10.0.1=default_hb85057a_2 139 | - libcublas=11.11.3.6=0 140 | - libcublas-dev=11.11.3.6=0 141 | - libcufft=10.9.0.58=0 142 | - libcufft-dev=10.9.0.58=0 143 | - libcufile=1.4.0.31=0 144 | - libcufile-dev=1.4.0.31=0 145 | - libcurand=10.3.0.86=0 146 | - libcurand-dev=10.3.0.86=0 147 | - libcusolver=11.4.1.48=0 148 | - libcusolver-dev=11.4.1.48=0 149 | - libcusparse=11.7.5.86=0 150 | - libcusparse-dev=11.7.5.86=0 151 | - libdeflate=1.8=h7f8727e_5 152 | - libedit=3.1.20210910=h7f8727e_0 153 | - libevent=2.1.12=h8f2d780_0 154 | - libffi=3.4.2=h6a678d5_6 155 | - libgcc-ng=11.2.0=h1234567_1 156 | - libgfortran-ng=11.2.0=h00389a5_1 157 | - libgfortran5=11.2.0=h1234567_1 158 | - libgomp=11.2.0=h1234567_1 159 | - libiconv=1.16=h7f8727e_2 160 | - libidn2=2.3.2=h7f8727e_0 161 | - liblapack=3.9.0=12_linux64_mkl 162 | - libllvm10=10.0.1=hbcb73fb_5 163 | - libnpp=11.8.0.86=0 164 | - libnpp-dev=11.8.0.86=0 165 | - libnvjpeg=11.9.0.86=0 166 | - libnvjpeg-dev=11.9.0.86=0 167 | - libpng=1.6.37=hbc83047_0 168 | - libpq=12.9=h16c4e8d_3 169 | - libprotobuf=3.20.1=h4ff587b_0 170 | - libsodium=1.0.18=h7b6447c_0 171 | - libstdcxx-ng=11.2.0=h1234567_1 172 | - libtasn1=4.16.0=h27cfd23_0 173 | - libtiff=4.4.0=hecacb30_2 174 | - libunistring=0.9.10=h27cfd23_0 175 | - libwebp=1.2.4=h11a3e52_0 176 | - libwebp-base=1.2.4=h5eee18b_0 177 | - libxcb=1.15=h7f8727e_0 178 | - libxkbcommon=1.0.1=hfa300c1_0 179 | - libxml2=2.9.14=h74e7548_0 180 | - libxslt=1.1.35=h4e12654_0 181 | - lxml=4.9.1=py39h1edc446_0 182 | - lz4-c=1.9.3=h295c915_1 183 | - markdown=3.4.1=pyhd8ed1ab_0 184 | - markupsafe=2.1.1=py39hb9d737c_1 185 | - matplotlib=3.5.3=py39h06a4308_0 186 | - matplotlib-base=3.5.3=py39hf590b9c_0 187 | - matplotlib-inline=0.1.6=py39h06a4308_0 188 | - mistune=0.8.4=py39h27cfd23_1000 189 | - mkl=2021.4.0=h06a4308_640 190 | - mkl-service=2.4.0=py39h7f8727e_0 191 | - mkl_fft=1.3.1=py39hd3c417c_0 192 | - mkl_random=1.2.2=py39h51133e4_0 193 | - multidict=6.0.2=py39hb9d737c_1 194 | - munkres=1.1.4=py_0 195 | - nbclassic=0.4.8=py39h06a4308_0 196 | - nbclient=0.5.13=py39h06a4308_0 197 | - nbconvert=6.5.4=py39h06a4308_0 198 | - nbformat=5.7.0=py39h06a4308_0 199 | - ncurses=6.3=h5eee18b_3 200 | - nest-asyncio=1.5.5=py39h06a4308_0 201 | - nettle=3.7.3=hbbd107a_1 202 | - notebook=6.5.2=py39h06a4308_0 203 | - notebook-shim=0.2.2=py39h06a4308_0 204 | - nsight-compute=2022.3.0.22=0 205 | - nspr=4.33=h295c915_0 206 | - nss=3.74=h0370c37_0 207 | - oauthlib=3.2.2=pyhd8ed1ab_0 208 | - omegaconf=2.1.1=py39hf3d152e_1 209 | - openh264=2.1.1=h4ff587b_0 210 | - openssl=1.1.1s=h7f8727e_0 211 | - packaging=21.3=pyhd8ed1ab_0 212 | - pandocfilters=1.5.0=pyhd3eb1b0_0 213 | - parso=0.8.3=pyhd3eb1b0_0 214 | - pathtools=0.1.2=py_1 215 | - pcre=8.45=h295c915_0 216 | - pexpect=4.8.0=pyhd3eb1b0_3 217 | - pickleshare=0.7.5=pyhd3eb1b0_1003 218 | - pillow=9.2.0=py39hace64e9_1 219 | - pip=22.2.2=py39h06a4308_0 220 | - ply=3.11=py39h06a4308_0 221 | - prometheus_client=0.14.1=py39h06a4308_0 222 | - promise=2.3=py39hf3d152e_7 223 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 224 | - prompt_toolkit=3.0.20=hd3eb1b0_0 225 | - protobuf=3.20.1=py39h295c915_0 226 | - psutil=5.9.0=py39h5eee18b_0 227 | - ptyprocess=0.7.0=pyhd3eb1b0_2 228 | - pure_eval=0.2.2=pyhd3eb1b0_0 229 | - pyasn1=0.4.8=py_0 230 | - pyasn1-modules=0.2.7=py_0 231 | - pycparser=2.21=pyhd3eb1b0_0 232 | - 
pydeprecate=0.3.2=pyhd8ed1ab_0 233 | - pygments=2.11.2=pyhd3eb1b0_0 234 | - pyjwt=2.6.0=pyhd8ed1ab_0 235 | - pyopenssl=22.0.0=pyhd3eb1b0_0 236 | - pyparsing=3.0.9=pyhd8ed1ab_0 237 | - pyqt=5.15.7=py39h6a678d5_1 238 | - pyqt5-sip=12.11.0=py39h6a678d5_1 239 | - pyrsistent=0.18.0=py39heee7806_0 240 | - pysocks=1.7.1=py39h06a4308_0 241 | - python=3.9.15=h7a1cb2a_2 242 | - python-dateutil=2.8.2=pyhd3eb1b0_0 243 | - python-fastjsonschema=2.16.2=py39h06a4308_0 244 | - python_abi=3.9=2_cp39 245 | - pytorch=1.13.0=py3.9_cuda11.7_cudnn8.5.0_0 246 | - pytorch-cuda=11.7=h67b0de4_0 247 | - pytorch-lightning=1.5.10=pyhd8ed1ab_0 248 | - pytorch-mutex=1.0=cuda 249 | - pytz=2022.1=py39h06a4308_0 250 | - pyu2f=0.1.5=pyhd8ed1ab_0 251 | - pyyaml=6.0=py39hb9d737c_4 252 | - pyzmq=23.2.0=py39h6a678d5_0 253 | - qt-main=5.15.2=h327a75a_7 254 | - qt-webengine=5.15.9=hd2b0992_4 255 | - qtconsole=5.3.2=py39h06a4308_0 256 | - qtpy=2.2.0=py39h06a4308_0 257 | - qtwebkit=5.212=h4eab89a_4 258 | - readline=8.2=h5eee18b_0 259 | - requests=2.28.1=py39h06a4308_0 260 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 261 | - rsa=4.9=pyhd8ed1ab_0 262 | - scikit-learn=1.1.3=py39h6a678d5_0 263 | - scipy=1.9.3=py39h14f4228_0 264 | - send2trash=1.8.0=pyhd3eb1b0_1 265 | - sentry-sdk=1.11.1=pyhd8ed1ab_0 266 | - setproctitle=1.1.10=py39h3811e60_1004 267 | - setuptools=59.5.0=py39hf3d152e_0 268 | - shortuuid=1.0.11=pyhd8ed1ab_0 269 | - sip=6.6.2=py39h6a678d5_0 270 | - smmap=3.0.5=pyh44b312d_0 271 | - sniffio=1.2.0=py39h06a4308_1 272 | - soupsieve=2.3.2.post1=py39h06a4308_0 273 | - sqlite=3.40.0=h5082296_0 274 | - stack_data=0.2.0=pyhd3eb1b0_0 275 | - tensorboard=2.11.0=pyhd8ed1ab_0 276 | - tensorboard-data-server=0.6.0=py39hd97740a_2 277 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 278 | - terminado=0.13.1=py39h06a4308_0 279 | - threadpoolctl=2.2.0=pyh0d69192_0 280 | - tinycss2=1.2.1=py39h06a4308_0 281 | - tk=8.6.12=h1ccaba5_0 282 | - toml=0.10.2=pyhd3eb1b0_0 283 | - tomli=2.0.1=py39h06a4308_0 284 | - torchaudio=0.13.0=py39_cu117 285 | - torchmetrics=0.11.0=pyhd8ed1ab_0 286 | - torchvision=0.14.0=py39_cu117 287 | - tornado=6.2=py39h5eee18b_0 288 | - tqdm=4.64.1=py39h06a4308_0 289 | - traitlets=5.1.1=pyhd3eb1b0_0 290 | - transforms3d=0.4.1=pyhd8ed1ab_0 291 | - tzdata=2022g=h04d1e81_0 292 | - urllib3=1.26.12=py39h06a4308_0 293 | - wandb=0.13.6=pyhd8ed1ab_0 294 | - wcwidth=0.2.5=pyhd3eb1b0_0 295 | - webencodings=0.5.1=py39h06a4308_1 296 | - websocket-client=0.58.0=py39h06a4308_4 297 | - werkzeug=2.2.2=pyhd8ed1ab_0 298 | - wheel=0.37.1=pyhd3eb1b0_0 299 | - widgetsnbextension=3.5.2=py39h06a4308_0 300 | - x264=1!161.3030=h7f98852_1 301 | - xz=5.2.8=h5eee18b_0 302 | - yaml=0.2.5=h7f98852_2 303 | - yarl=1.7.2=py39hb9d737c_2 304 | - zeromq=4.3.4=h2531618_0 305 | - zipp=3.11.0=pyhd8ed1ab_0 306 | - zlib=1.2.13=h5eee18b_0 307 | - zstd=1.5.2=ha4553b6_0 308 | - pip: 309 | - appdirs==1.4.4 310 | - astunparse==1.6.3 311 | - black==20.8b0 312 | - clang==5.0 313 | - click==7.1.2 314 | - cloudpickle==2.2.0 315 | - flatbuffers==23.5.26 316 | - gast==0.4.0 317 | - google-pasta==0.2.0 318 | - gym==0.21.0 319 | - h5py==3.1.0 320 | - immutabledict==2.2.0 321 | - keras==2.11.0 322 | - keras-preprocessing==1.1.2 323 | - mypy-extensions==0.4.3 324 | - numpy==1.21.5 325 | - opencv-python==4.6.0.66 326 | - opt-einsum==3.3.0 327 | - pandas==1.5.3 328 | - pathspec==0.10.2 329 | - plotly==5.13.1 330 | - pyarrow==10.0.0 331 | - regex==2022.10.31 332 | - six==1.15.0 333 | - tenacity==8.2.2 334 | - tensorflow-cpu==2.11.0 335 | - tensorflow_probability==0.19.0 336 | - 
termcolor==1.1.0 337 | - typed-ast==1.5.4 338 | - typing-extensions==3.7.4.3 339 | - wrapt==1.12.1 340 | -------------------------------------------------------------------------------- /src/models/metrics/nll.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import torch 3 | from omegaconf import ListConfig 4 | from typing import Dict, Optional 5 | from torch import Tensor, tensor 6 | from torch.nn import functional as F 7 | from torchmetrics.metric import Metric 8 | from torch.distributions import MultivariateNormal 9 | 10 | 11 | def compute_nll_mtr(dmean: Tensor, cov: Tensor) -> Tensor: 12 | dx = dmean[..., 0] 13 | dy = dmean[..., 1] 14 | sx = cov[..., 0, 0] 15 | sy = cov[..., 1, 1] 16 | rho = torch.tanh(cov[..., 1, 0]) # mtr uses clamp to [-0.5, 0.5] 17 | one_minus_rho2 = 1 - rho ** 2 18 | log_prob = ( 19 | torch.log(sx) 20 | + torch.log(sy) 21 | + 0.5 * torch.log(one_minus_rho2) 22 | + 0.5 / one_minus_rho2 * ((dx / sx) ** 2 + (dy / sy) ** 2 - 2 * rho * dx * dy / (sx * sy)) 23 | ) 24 | return log_prob 25 | 26 | 27 | class NllMetrics(Metric): 28 | full_state_update = False 29 | 30 | def __init__( 31 | self, 32 | prefix: str, 33 | winner_takes_all: str, 34 | p_rand_train_agent: float, 35 | n_decoders: int, 36 | n_pred: int, 37 | l_pos: str, 38 | n_step_add_train_agent: ListConfig, 39 | focal_gamma_conf: ListConfig, 40 | w_conf: ListConfig, 41 | w_pos: ListConfig, 42 | w_yaw: ListConfig, # cos 43 | w_spd: ListConfig, # huber 44 | w_vel: ListConfig, # huber 45 | ) -> None: 46 | super().__init__(dist_sync_on_step=False) 47 | self.prefix = prefix 48 | self.winner_takes_all = winner_takes_all 49 | self.p_rand_train_agent = p_rand_train_agent 50 | self.n_decoders = n_decoders 51 | self.n_pred = n_pred 52 | self.l_pos = l_pos 53 | self.n_step_add_train_agent = n_step_add_train_agent 54 | self.focal_gamma_conf = list(focal_gamma_conf) 55 | self.w_conf = list(w_conf) 56 | self.w_pos = list(w_pos) 57 | self.w_yaw = list(w_yaw) 58 | self.w_spd = list(w_spd) 59 | self.w_vel = list(w_vel) 60 | 61 | self.add_state("counter_traj", default=tensor(0.0), dist_reduce_fx="sum") 62 | self.add_state("counter_conf", default=tensor(0.0), dist_reduce_fx="sum") 63 | self.add_state("error_pos", default=tensor(0.0), dist_reduce_fx="sum") 64 | self.add_state("error_conf", default=tensor(0.0), dist_reduce_fx="sum") 65 | self.add_state("error_yaw", default=tensor(0.0), dist_reduce_fx="sum") 66 | self.add_state("error_spd", default=tensor(0.0), dist_reduce_fx="sum") 67 | self.add_state("error_vel", default=tensor(0.0), dist_reduce_fx="sum") 68 | 69 | for i in range(self.n_decoders): 70 | for j in range(self.n_pred): 71 | self.add_state(f"counter_d{i}_p{j}", default=tensor(0.0), dist_reduce_fx="sum") 72 | self.add_state(f"conf_d{i}_p{j}", default=tensor(0.0), dist_reduce_fx="sum") 73 | 74 | def update( 75 | self, 76 | pred_valid: Tensor, 77 | pred_conf: Tensor, 78 | pred_pos: Tensor, 79 | pred_spd: Optional[Tensor], 80 | pred_vel: Optional[Tensor], 81 | pred_yaw_bbox: Optional[Tensor], 82 | pred_cov: Optional[Tensor], 83 | ref_role: Tensor, 84 | ref_type: Tensor, 85 | gt_valid: Tensor, 86 | gt_pos: Tensor, 87 | gt_spd: Tensor, 88 | gt_vel: Tensor, 89 | gt_yaw_bbox: Tensor, 90 | gt_cmd: Tensor, 91 | **kwargs, 92 | ) -> None: 93 | """ 94 | Args: 95 | pred_valid: [n_scene, n_agent], bool 96 | pred_conf: [n_decoder, n_scene, n_agent, n_pred], not normalized! 
97 | pred_pos: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2] 98 | pred_spd: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 1] 99 | pred_vel: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2] 100 | pred_yaw_bbox: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 1] 101 | pred_cov: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2, 2] 102 | gt_valid: [n_scene, n_agent, n_step_future], bool 103 | gt_pos: [n_scene, n_agent, n_step_future, 2] 104 | gt_spd: [n_scene, n_agent, n_step_future, 1] 105 | gt_vel: [n_scene, n_agent, n_step_future, 2] 106 | gt_yaw_bbox: [n_scene, n_agent, n_step_future, 1] 107 | ref_role: [n_scene, n_agent, 3], one hot bool [sdc=0, interest=1, predict=2] 108 | ref_type: [n_scene, n_agent, 3], one hot bool [veh=0, ped=1, cyc=2] 109 | agent_cmd: [n_scene, n_agent, 8], one hot bool 110 | """ 111 | n_agent_type = ref_type.shape[-1] 112 | n_decoder, n_scene, n_agent, n_pred = pred_conf.shape 113 | assert (ref_role.any(-1) & pred_valid == ref_role.any(-1)).all(), "All relevat agents shall be predicted!" 114 | 115 | # ! prepare avails 116 | avails = ref_role.any(-1) # [n_scene, n_agent] 117 | # add rand agents for training 118 | if self.p_rand_train_agent > 0: 119 | avails = avails | (torch.bernoulli(self.p_rand_train_agent * torch.ones_like(avails)).bool()) 120 | # add long tracked agents for training 121 | _track_len = gt_valid.sum(-1) # [n_scene, n_agent] 122 | for i in range(n_agent_type): 123 | if self.n_step_add_train_agent[i] > 0: 124 | avails = avails | (ref_type[:, :, i] & (_track_len > self.n_step_add_train_agent[i])) 125 | 126 | avails = gt_valid & avails.unsqueeze(-1) # [n_scene, n_agent, n_step_future] 127 | avails = avails.unsqueeze(0).expand(n_decoder, -1, -1, -1) # [n_decoder, n_scene, n_agent, n_step_future] 128 | if n_decoder > 1: 129 | # [n_decoder], randomly train ensembles with 50% of chance 130 | mask_ensemble = torch.bernoulli(0.5 * torch.ones_like(pred_conf[:, 0, 0, 0])).bool() 131 | # make sure at least one ensemble is trained 132 | if not mask_ensemble.any(): 133 | mask_ensemble[torch.randint(0, n_decoder, (1,))] |= True 134 | avails = avails & mask_ensemble[:, None, None, None] 135 | # [n_decoder, n_scene, n_agent, n_pred, n_step_future] 136 | avails = avails.unsqueeze(3).expand(-1, -1, -1, n_pred, -1) 137 | 138 | # ! normalize pred_conf 139 | # [n_decoder, n_scene, n_agent, n_pred], per ensemble 140 | pred_conf = torch.softmax(pred_conf, dim=-1) 141 | 142 | # ! save conf histogram 143 | _prob = pred_conf.masked_fill(~(pred_valid[None, :, :, None]), 0.0) 144 | for i in range(self.n_decoders): 145 | for j in range(self.n_pred): 146 | x = getattr(self, f"conf_d{i}_p{j}") 147 | x += (_prob[i, :, :, j] * (avails[i, :, :, j].any(-1))).sum() 148 | 149 | # ! 
winnter takes all 150 | with torch.no_grad(): 151 | decoder_idx = torch.arange(n_decoder)[:, None, None, None] # [n_decoder, 1, 1, 1] 152 | scene_idx = torch.arange(n_scene)[None, :, None, None] # [1, n_scene, 1, 1] 153 | agent_idx = torch.arange(n_agent)[None, None, :, None] # [1, 1, n_agent, 1] 154 | 155 | if "hard" in self.winner_takes_all: 156 | # [n_decoder, n_scene, n_agent, n_pred, n_step_future] 157 | dist = torch.norm(pred_pos - gt_pos[None, :, :, None, :, :], dim=-1) 158 | dist = dist.masked_fill(~avails, 0.0).sum(-1) # [n_decoder, n_scene, n_agent, n_pred] 159 | if "joint" in self.winner_takes_all: 160 | dist = dist.sum(2, keepdim=True) # [n_decoder, n_scene, 1, n_pred] 161 | k_top = int(self.winner_takes_all[-1]) 162 | i = torch.randint(high=k_top, size=()) 163 | # [n_decoder, n_scene, n_agent, 1] 164 | mode_idx = dist.topk(k_top, dim=-1, largest=False, sorted=False)[1][..., [i]] 165 | elif self.winner_takes_all == "cmd": 166 | assert n_pred == gt_cmd.shape[-1] 167 | mode_idx = (gt_cmd + 0.0).argmax(-1, keepdim=True) # [n_scene, n_agent, 1] 168 | mode_idx = mode_idx.unsqueeze(0).expand(n_decoder, -1, -1, -1) # [n_decoder, n_scene, n_agent, 1] 169 | 170 | # ! save hard assignment histogram: [n_decoder, n_scene, n_agent, n_pred] 171 | counter_modes = torch.nn.functional.one_hot(mode_idx.squeeze(-1), self.n_pred) 172 | for i in range(self.n_decoders): 173 | for j in range(self.n_pred): 174 | x = getattr(self, f"counter_d{i}_p{j}") 175 | x += (counter_modes[i, :, :, j] * (avails[i, :, :, j].any(-1))).sum() 176 | 177 | # ! avails and counter 178 | # avails: [n_decoder, n_scene, n_agent, n_pred, n_step_future] 179 | avails = avails[decoder_idx, scene_idx, agent_idx, mode_idx] 180 | self.counter_traj += avails.sum() 181 | self.counter_conf += avails[:, :, :, 0, :].any(-1).sum() 182 | 183 | # ! prepare agent dependent loss weights 184 | focal_gamma_conf, w_conf, w_pos, w_yaw, w_spd, w_vel = 0, 0, 0, 0, 0, 0 185 | for i in range(n_agent_type): # [n_scene, n_agent] 186 | focal_gamma_conf += ref_type[:, :, i] * self.focal_gamma_conf[i] 187 | w_conf += ref_type[:, :, i] * self.w_conf[i] 188 | w_pos += ref_type[:, :, i] * self.w_pos[i] 189 | w_yaw += ref_type[:, :, i] * self.w_yaw[i] 190 | w_spd += ref_type[:, :, i] * self.w_spd[i] 191 | w_vel += ref_type[:, :, i] * self.w_vel[i] 192 | 193 | # ! error_conf 194 | # pred_conf: [n_decoder, n_scene, n_agent, n_pred], not normalized! 195 | pred_conf = pred_conf[decoder_idx, scene_idx, agent_idx, mode_idx] 196 | focal_gamma_conf = torch.pow(1 - pred_conf, focal_gamma_conf[None, :, :, None]) 197 | w_conf = w_conf[None, :, :, None] 198 | self.error_conf += (-torch.log(pred_conf) * w_conf * focal_gamma_conf).masked_fill(~(avails.any(-1)), 0.0).sum() 199 | 200 | # ! 
error_pos 201 | pred_pos = pred_pos[decoder_idx, scene_idx, agent_idx, mode_idx] 202 | if self.l_pos == "huber": 203 | errors_pos = F.huber_loss(pred_pos, gt_pos[None, :, :, None, :, :], reduction="none").sum(-1) 204 | elif self.l_pos == "l2": 205 | errors_pos = torch.norm(pred_pos - gt_pos[None, :, :, None, :, :], p=2, dim=-1) 206 | elif self.l_pos == "nll_mtr": 207 | pred_cov = pred_cov[decoder_idx, scene_idx, agent_idx, mode_idx] 208 | errors_pos = compute_nll_mtr(pred_pos - gt_pos[None, :, :, None, :, :], pred_cov) 209 | elif self.l_pos == "nll_torch": 210 | gmm = MultivariateNormal(pred_pos, scale_tril=pred_cov[decoder_idx, scene_idx, agent_idx, mode_idx]) 211 | errors_pos = -gmm.log_prob(gt_pos[None, :, :, None, :, :]) 212 | self.error_pos += (errors_pos * w_pos[None, :, :, None, None]).masked_fill(~avails, 0.0).sum() 213 | 214 | # ! error_spd 215 | if sum(self.w_spd) > 0 and pred_spd is not None: 216 | pred_spd = pred_spd[decoder_idx, scene_idx, agent_idx, mode_idx] 217 | errors_spd = F.huber_loss(pred_spd, gt_spd[None, :, :, None, :, :], reduction="none").squeeze(-1) 218 | self.error_spd += (errors_spd * w_spd[None, :, :, None, None]).masked_fill(~avails, 0.0).sum() 219 | 220 | # ! error_vel 221 | if sum(self.w_vel) > 0 and pred_vel is not None: 222 | pred_vel = pred_vel[decoder_idx, scene_idx, agent_idx, mode_idx] 223 | errors_vel = F.huber_loss(pred_vel, gt_vel[None, :, :, None, :, :], reduction="none").sum(-1) 224 | self.error_vel += (errors_vel * w_vel[None, :, :, None, None]).masked_fill(~avails, 0.0).sum() 225 | 226 | # ! error_yaw 227 | if sum(self.w_yaw) > 0 and pred_yaw_bbox is not None: 228 | pred_yaw_bbox = pred_yaw_bbox[decoder_idx, scene_idx, agent_idx, mode_idx] 229 | errors_yaw = -torch.cos(pred_yaw_bbox - gt_yaw_bbox[None, :, :, None, :, :]).squeeze(-1) 230 | self.error_yaw += (errors_yaw * w_yaw[None, :, :, None, None]).masked_fill(~avails, 0.0).sum() 231 | 232 | def compute(self) -> Dict[str, Tensor]: 233 | 234 | out_dict = { 235 | f"{self.prefix}/counter_traj": self.counter_traj, 236 | f"{self.prefix}/counter_conf": self.counter_conf, 237 | f"{self.prefix}/error_pos": self.error_pos, 238 | f"{self.prefix}/error_conf": self.error_conf, 239 | f"{self.prefix}/error_yaw": self.error_yaw, 240 | f"{self.prefix}/error_spd": self.error_spd, 241 | f"{self.prefix}/error_vel": self.error_vel, 242 | } 243 | out_dict[f"{self.prefix}/loss"] = ( 244 | self.error_pos + self.error_yaw + self.error_spd + self.error_vel 245 | ) / self.counter_traj + self.error_conf / self.counter_conf 246 | 247 | for i in range(self.n_decoders): 248 | for j in range(self.n_pred): 249 | out_dict[f"{self.prefix}/counter_d{i}_p{j}"] = getattr(self, f"counter_d{i}_p{j}") 250 | out_dict[f"{self.prefix}/conf_d{i}_p{j}"] = getattr(self, f"conf_d{i}_p{j}") 251 | 252 | return out_dict 253 | -------------------------------------------------------------------------------- /src/pack_h5_av2.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | import sys 3 | 4 | sys.path.append(".") 5 | 6 | from argparse import ArgumentParser 7 | from tqdm import tqdm 8 | import h5py 9 | import numpy as np 10 | from pathlib import Path 11 | from av2.datasets.motion_forecasting import scenario_serialization 12 | from av2.map.map_api import ArgoverseStaticMap 13 | from av2.geometry.interpolate import interp_arc 14 | import src.utils.pack_h5 as pack_utils 15 | 16 | # "map/type" 17 | # VEHICLE = 0 18 | 
# BUS = 1 19 | # BIKE = 2 20 | # UNKNOWN = 3 21 | # DOUBLE_DASH = 4 22 | # DASHED = 5 23 | # SOLID = 6 24 | # DOUBLE_SOLID = 7 25 | # DASH_SOLID = 8 26 | # SOLID_DASH = 9 27 | # CROSSWALK = 10 28 | N_PL_TYPE = 11 29 | PL_TYPE = { 30 | "VEHICLE": 0, 31 | "BUS": 1, 32 | "BIKE": 2, 33 | "NONE": 3, 34 | "UNKNOWN": 3, 35 | "DOUBLE_DASH_WHITE": 4, 36 | "DOUBLE_DASH_YELLOW": 4, 37 | "DASHED_WHITE": 5, 38 | "DASHED_YELLOW": 5, 39 | "SOLID_WHITE": 6, 40 | "SOLID_BLUE": 6, 41 | "SOLID_YELLOW": 6, 42 | "DOUBLE_SOLID_WHITE": 7, 43 | "DOUBLE_SOLID_YELLOW": 7, 44 | "DASH_SOLID_YELLOW": 8, 45 | "DASH_SOLID_WHITE": 8, 46 | "SOLID_DASH_WHITE": 9, 47 | "SOLID_DASH_YELLOW": 9, 48 | "CROSSWALK": 10, 49 | } 50 | DIM_VEH_LANES = [0, 1] 51 | DIM_CYC_LANES = [2] 52 | DIM_PED_LANES = [3, 6, 10] 53 | 54 | # "agent/type" 55 | AGENT_TYPE = { 56 | "vehicle": 0, 57 | "bus": 0, 58 | "pedestrian": 1, 59 | "motorcyclist": 2, 60 | "cyclist": 2, 61 | "riderless_bicycle": 2, 62 | } 63 | AGENT_SIZE = { 64 | "vehicle": [4.7, 2.1, 1.7], 65 | "bus": [11, 3, 3.5], 66 | "pedestrian": [0.85, 0.85, 1.75], 67 | "motorcyclist": [2, 0.8, 1.8], 68 | "cyclist": [2, 0.8, 1.8], 69 | "riderless_bicycle": [2, 0.8, 1.8], 70 | } 71 | 72 | 73 | N_PL_MAX = 1500 74 | N_AGENT_MAX = 256 75 | 76 | N_PL = 1024 77 | N_AGENT = 64 78 | N_AGENT_NO_SIM = N_AGENT_MAX - N_AGENT 79 | 80 | THRESH_MAP = 120 81 | THRESH_AGENT = 120 82 | 83 | STEP_CURRENT = 49 84 | N_STEP = 110 85 | 86 | 87 | def collate_agent_features(scenario_path, n_step): 88 | tracks = scenario_serialization.load_argoverse_scenario_parquet(scenario_path).tracks 89 | 90 | agent_id = [] 91 | agent_type = [] 92 | agent_states = [] 93 | agent_role = [] 94 | for _track in tracks: 95 | if _track.object_type not in AGENT_TYPE: 96 | continue 97 | # role: [sdc=0, interest=1, predict=2] 98 | if _track.track_id == "AV": 99 | agent_id.append(0) 100 | agent_role.append([True, False, False]) 101 | else: 102 | assert int(_track.track_id) > 0 103 | agent_id.append(int(_track.track_id)) 104 | # [sdc=0, interest=1, predict=2] 105 | if _track.category.value == 2: # SCORED_TRACK 106 | agent_role.append([False, True, False]) 107 | elif _track.category.value == 3: # FOCAL_TRACK 108 | agent_role.append([False, False, True]) 109 | else: # TRACK_FRAGMENT or UNSCORED_TRACK 110 | agent_role.append([False, False, False]) 111 | 112 | agent_type.append(AGENT_TYPE[_track.object_type]) 113 | 114 | step_states = [[0.0] * 9 + [False]] * n_step 115 | agent_size = AGENT_SIZE[_track.object_type] 116 | for s in _track.object_states: 117 | step_states[int(s.timestep)] = [ 118 | s.position[0], # center_x 119 | s.position[1], # center_y 120 | 3, # center_z 121 | agent_size[0], # length 122 | agent_size[1], # width 123 | agent_size[2], # height 124 | s.heading, # heading in radian 125 | s.velocity[0], # velocity_x 126 | s.velocity[1], # velocity_y 127 | True, # valid 128 | ] 129 | 130 | agent_states.append(step_states) 131 | 132 | return agent_id, agent_type, agent_states, agent_role 133 | 134 | 135 | def collate_map_features(map_path): 136 | static_map = ArgoverseStaticMap.from_json(map_path) 137 | 138 | def _interpolate_centerline(left_ln_boundary, right_ln_boundary): 139 | num_interp_pts = ( 140 | np.linalg.norm(np.diff(left_ln_boundary, axis=0), axis=-1).sum() 141 | + np.linalg.norm(np.diff(right_ln_boundary, axis=0), axis=-1).sum() 142 | ) / 2.0 143 | num_interp_pts = int(num_interp_pts) + 1 144 | left_even_pts = interp_arc(num_interp_pts, points=left_ln_boundary) 145 | right_even_pts = interp_arc(num_interp_pts, 
points=right_ln_boundary) 146 | centerline_pts = (left_even_pts + right_even_pts) / 2.0 147 | return centerline_pts, left_even_pts, right_even_pts 148 | 149 | mf_id = [] 150 | mf_xyz = [] 151 | mf_type = [] 152 | mf_edge = [] 153 | lane_boundary_set = [] 154 | 155 | for _id, ped_xing in static_map.vector_pedestrian_crossings.items(): 156 | v0, v1 = ped_xing.edge1.xyz 157 | v2, v3 = ped_xing.edge2.xyz 158 | pl_crosswalk = pack_utils.get_polylines_from_polygon(np.array([v0, v1, v3, v2])) 159 | mf_id.extend([_id] * len(pl_crosswalk)) 160 | mf_type.extend([PL_TYPE["CROSSWALK"]] * len(pl_crosswalk)) 161 | mf_xyz.extend(pl_crosswalk) 162 | 163 | for _id, lane_segment in static_map.vector_lane_segments.items(): 164 | centerline_pts, left_even_pts, right_even_pts = _interpolate_centerline( 165 | lane_segment.left_lane_boundary.xyz, lane_segment.right_lane_boundary.xyz 166 | ) 167 | 168 | mf_id.append(_id) 169 | mf_xyz.append(centerline_pts) 170 | mf_type.append(PL_TYPE[lane_segment.lane_type]) 171 | 172 | if (lane_segment.left_lane_boundary not in lane_boundary_set) and not ( 173 | lane_segment.is_intersection and lane_segment.left_mark_type in ["NONE", "UNKOWN"] 174 | ): 175 | lane_boundary_set.append(lane_segment.left_lane_boundary) 176 | mf_xyz.append(left_even_pts) 177 | mf_id.append(-2) 178 | mf_type.append(PL_TYPE[lane_segment.left_mark_type]) 179 | 180 | if (lane_segment.right_lane_boundary not in lane_boundary_set) and not ( 181 | lane_segment.is_intersection and lane_segment.right_mark_type in ["NONE", "UNKOWN"] 182 | ): 183 | lane_boundary_set.append(lane_segment.right_lane_boundary) 184 | mf_xyz.append(right_even_pts) 185 | mf_id.append(-2) 186 | mf_type.append(PL_TYPE[lane_segment.right_mark_type]) 187 | 188 | for _id_exit in lane_segment.successors: 189 | mf_edge.append([_id, _id_exit]) 190 | else: 191 | mf_edge.append([_id, -1]) 192 | 193 | return mf_id, mf_xyz, mf_type, mf_edge 194 | 195 | 196 | def main(): 197 | parser = ArgumentParser(allow_abbrev=True) 198 | parser.add_argument("--data-dir", default="/cluster/scratch/zhejzhan/av2_motion") 199 | parser.add_argument("--dataset", default="training") 200 | parser.add_argument("--out-dir", default="/cluster/scratch/zhejzhan/h5_av2_hptr") 201 | parser.add_argument("--rand-pos", default=50.0, type=float, help="Meter. Set to -1 to disable.") 202 | parser.add_argument("--rand-yaw", default=3.14, type=float, help="Radian. 
Set to -1 to disable.") 203 | parser.add_argument("--dest-no-pred", action="store_true") 204 | args = parser.parse_args() 205 | 206 | if "training" in args.dataset: 207 | pack_all = True # ["agent/valid"] 208 | pack_history = False # ["history/agent/valid"] 209 | n_step = N_STEP 210 | elif "validation" in args.dataset: 211 | pack_all = True 212 | pack_history = True 213 | n_step = N_STEP 214 | elif "testing" in args.dataset: 215 | pack_all = False 216 | pack_history = True 217 | n_step = STEP_CURRENT + 1 218 | 219 | out_path = Path(args.out_dir) 220 | out_path.mkdir(exist_ok=True) 221 | out_h5_path = out_path / (args.dataset + ".h5") 222 | 223 | scenario_list = sorted(list((Path(args.data_dir) / args.dataset).glob("*"))) 224 | n_pl_max, n_agent_max, n_agent_sim, n_agent_no_sim = 0, 0, 0, 0 225 | with h5py.File(out_h5_path, "w") as hf: 226 | hf.attrs["data_len"] = len(scenario_list) 227 | for i in tqdm(range(hf.attrs["data_len"])): 228 | scenario_folder = scenario_list[i] 229 | mf_id, mf_xyz, mf_type, mf_edge = collate_map_features( 230 | scenario_folder / f"log_map_archive_{scenario_folder.name}.json" 231 | ) 232 | agent_id, agent_type, agent_states, agent_role = collate_agent_features( 233 | scenario_folder / f"scenario_{scenario_folder.name}.parquet", n_step 234 | ) 235 | 236 | episode = {} 237 | n_pl = pack_utils.pack_episode_map( 238 | episode=episode, mf_id=mf_id, mf_xyz=mf_xyz, mf_type=mf_type, mf_edge=mf_edge, n_pl_max=N_PL_MAX 239 | ) 240 | n_agent = pack_utils.pack_episode_agents( 241 | episode=episode, 242 | agent_id=agent_id, 243 | agent_type=agent_type, 244 | agent_states=agent_states, 245 | agent_role=agent_role, 246 | pack_all=pack_all, 247 | pack_history=pack_history, 248 | n_agent_max=N_AGENT_MAX, 249 | step_current=STEP_CURRENT, 250 | ) 251 | scenario_center, scenario_yaw = pack_utils.center_at_sdc(episode, args.rand_pos, args.rand_yaw) 252 | n_pl_max = max(n_pl_max, n_pl) 253 | n_agent_max = max(n_agent_max, n_agent) 254 | 255 | episode_reduced = {} 256 | pack_utils.filter_episode_map(episode, N_PL, THRESH_MAP, thresh_z=-1) 257 | assert episode["map/valid"].any(1).sum() > 0 258 | pack_utils.repack_episode_map(episode, episode_reduced, N_PL, N_PL_TYPE) 259 | 260 | if "training" in args.dataset: 261 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents( 262 | episode=episode, 263 | episode_reduced=episode_reduced, 264 | n_agent=N_AGENT, 265 | prefix="", 266 | dim_veh_lanes=DIM_VEH_LANES, 267 | dist_thresh_agent=THRESH_AGENT, 268 | step_current=STEP_CURRENT, 269 | ) 270 | pack_utils.repack_episode_agents( 271 | episode=episode, 272 | episode_reduced=episode_reduced, 273 | mask_sim=mask_sim, 274 | n_agent=N_AGENT, 275 | prefix="", 276 | dim_veh_lanes=DIM_VEH_LANES, 277 | dim_cyc_lanes=DIM_CYC_LANES, 278 | dim_ped_lanes=DIM_PED_LANES, 279 | dest_no_pred=args.dest_no_pred, 280 | ) 281 | elif "validation" in args.dataset: 282 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents( 283 | episode=episode, 284 | episode_reduced=episode_reduced, 285 | n_agent=N_AGENT, 286 | prefix="history/", 287 | dim_veh_lanes=DIM_VEH_LANES, 288 | dist_thresh_agent=THRESH_AGENT, 289 | step_current=STEP_CURRENT, 290 | ) 291 | pack_utils.repack_episode_agents( 292 | episode=episode, 293 | episode_reduced=episode_reduced, 294 | mask_sim=mask_sim, 295 | n_agent=N_AGENT, 296 | prefix="", 297 | dim_veh_lanes=DIM_VEH_LANES, 298 | dim_cyc_lanes=DIM_CYC_LANES, 299 | dim_ped_lanes=DIM_PED_LANES, 300 | dest_no_pred=args.dest_no_pred, 301 | ) 302 | pack_utils.repack_episode_agents(episode, 
episode_reduced, mask_sim, N_AGENT, "history/") 303 | pack_utils.repack_episode_agents_no_sim(episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "") 304 | pack_utils.repack_episode_agents_no_sim( 305 | episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "history/" 306 | ) 307 | elif "testing" in args.dataset: 308 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents( 309 | episode=episode, 310 | episode_reduced=episode_reduced, 311 | n_agent=N_AGENT, 312 | prefix="history/", 313 | dim_veh_lanes=DIM_VEH_LANES, 314 | dist_thresh_agent=THRESH_AGENT, 315 | step_current=STEP_CURRENT, 316 | ) 317 | pack_utils.repack_episode_agents(episode, episode_reduced, mask_sim, N_AGENT, "history/") 318 | pack_utils.repack_episode_agents_no_sim( 319 | episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "history/" 320 | ) 321 | n_agent_sim = max(n_agent_sim, mask_sim.sum()) 322 | n_agent_no_sim = max(n_agent_no_sim, mask_no_sim.sum()) 323 | 324 | episode_reduced["map/boundary"] = pack_utils.get_map_boundary( 325 | episode_reduced["map/valid"], episode_reduced["map/pos"] 326 | ) 327 | 328 | hf_episode = hf.create_group(str(i)) 329 | hf_episode.attrs["scenario_id"] = scenario_folder.name 330 | hf_episode.attrs["scenario_center"] = scenario_center 331 | hf_episode.attrs["scenario_yaw"] = scenario_yaw 332 | hf_episode.attrs["with_map"] = True 333 | 334 | for k, v in episode_reduced.items(): 335 | hf_episode.create_dataset(k, data=v, compression="gzip", compression_opts=4, shuffle=True) 336 | 337 | print(f"n_pl_max: {n_pl_max}, n_agent_max: {n_agent_max}") 338 | print(f"n_agent_sim: {n_agent_sim}, n_agent_no_sim: {n_agent_no_sim}") 339 | 340 | 341 | if __name__ == "__main__": 342 | main() 343 | -------------------------------------------------------------------------------- /src/data_modules/sc_global.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Dict 3 | from omegaconf import DictConfig 4 | import torch 5 | from torch import nn, Tensor 6 | from utils.pose_pe import PosePE 7 | 8 | 9 | class SceneCentricGlobal(nn.Module): 10 | def __init__( 11 | self, 12 | time_step_current: int, 13 | data_size: DictConfig, 14 | dropout_p_history: float, 15 | use_current_tl: bool, 16 | add_ohe: bool, 17 | pl_aggr: bool, 18 | pose_pe: DictConfig, 19 | ) -> None: 20 | super().__init__() 21 | self.dropout_p_history = dropout_p_history # [0, 1], turn off if set to negative 22 | self.step_current = time_step_current 23 | self.n_step_hist = time_step_current + 1 24 | self.use_current_tl = use_current_tl 25 | self.add_ohe = add_ohe 26 | self.pl_aggr = pl_aggr 27 | self.n_pl_node = data_size["map/valid"][-1] 28 | 29 | self.pose_pe_agent = PosePE(pose_pe["agent"]) 30 | self.pose_pe_map = PosePE(pose_pe["map"]) 31 | self.pose_pe_tl = PosePE(pose_pe["tl"]) 32 | 33 | tl_attr_dim = self.pose_pe_tl.out_dim + data_size["tl_stop/state"][-1] 34 | if self.pl_aggr: 35 | agent_attr_dim = ( 36 | self.pose_pe_agent.out_dim * self.n_step_hist 37 | + data_size["agent/spd"][-1] * self.n_step_hist # 1 38 | + data_size["agent/vel"][-1] * self.n_step_hist # 2 39 | + data_size["agent/yaw_rate"][-1] * self.n_step_hist # 1 40 | + data_size["agent/acc"][-1] * self.n_step_hist # 1 41 | + data_size["agent/size"][-1] # 3 42 | + data_size["agent/type"][-1] # 3 43 | + self.n_step_hist # valid 44 | ) 45 | map_attr_dim = self.pose_pe_map.out_dim * self.n_pl_node + data_size["map/type"][-1] + 
self.n_pl_node 46 | else: 47 | agent_attr_dim = ( 48 | self.pose_pe_agent.out_dim 49 | + data_size["agent/spd"][-1] # 1 50 | + data_size["agent/vel"][-1] # 2 51 | + data_size["agent/yaw_rate"][-1] # 1 52 | + data_size["agent/acc"][-1] # 1 53 | + data_size["agent/size"][-1] # 3 54 | + data_size["agent/type"][-1] # 3 55 | ) 56 | map_attr_dim = self.pose_pe_map.out_dim + data_size["map/type"][-1] 57 | 58 | if self.add_ohe: 59 | self.register_buffer("history_step_ohe", torch.eye(self.n_step_hist)) 60 | self.register_buffer("pl_node_ohe", torch.eye(self.n_pl_node)) 61 | if not self.pl_aggr: 62 | map_attr_dim += self.n_pl_node 63 | agent_attr_dim += self.n_step_hist 64 | if not self.use_current_tl: 65 | tl_attr_dim += self.n_step_hist 66 | 67 | self.model_kwargs = { 68 | "agent_attr_dim": agent_attr_dim, 69 | "map_attr_dim": map_attr_dim, 70 | "tl_attr_dim": tl_attr_dim, 71 | "n_step_hist": self.n_step_hist, 72 | "n_pl_node": self.n_pl_node, 73 | "use_current_tl": self.use_current_tl, 74 | "pl_aggr": self.pl_aggr, 75 | } 76 | 77 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: 78 | """ 79 | Args: scene-centric Dict 80 | # (ref) reference information for transform back to global coordinate and submission to waymo 81 | "ref/pos": [n_scene, n_agent, 1, 2] 82 | "ref/yaw": [n_scene, n_agent, 1] 83 | "ref/rot": [n_scene, n_agent, 2, 2] 84 | "ref/role": [n_scene, n_agent, 3] 85 | "ref/type": [n_scene, n_agent, 3] 86 | # (gt) ground-truth agent future for training, not available for testing 87 | "gt/valid": [n_scene, n_agent, n_step_future], bool 88 | "gt/pos": [n_scene, n_agent, n_step_future, 2] 89 | "gt/spd": [n_scene, n_agent, n_step_future, 1] 90 | "gt/vel": [n_scene, n_agent, n_step_future, 2] 91 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1] 92 | "gt/cmd": [n_scene, n_agent, 8] 93 | # (sc) scene-centric agents states 94 | "sc/agent_valid": [n_scene, n_agent, n_step_hist] 95 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2] 96 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2] 97 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1] 98 | "sc/agent_acc": [n_scene, n_agent, n_step_hist, 1] 99 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1] 100 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1] 101 | # agent attributes 102 | "sc/agent_type": [n_scene, n_agent, 3] 103 | "sc/agent_role": [n_scene, n_agent, 3] 104 | "sc/agent_size": [n_scene, n_agent, 3] 105 | # map polylines 106 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool 107 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot 108 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32 109 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32 110 | # traffic lights 111 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool 112 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot 113 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y 114 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y 115 | 116 | Returns: add following keys to batch Dict 117 | # agent type: no need to be aggregated. 118 | "input/agent_type": [n_scene, n_agent, 3] 119 | # agent history 120 | if pl_aggr: # processed by mlp encoder, c.f. our method. 121 | "input/agent_valid": [n_scene, n_agent], bool 122 | "input/agent_attr": [n_scene, n_agent, agent_attr_dim] 123 | "input/agent_pos": [n_scene, n_agent, 2], (x,y), pos of the current step. 
124 | "input/map_valid": [n_scene, n_pl], bool 125 | "input/map_attr": [n_scene, n_pl, map_attr_dim] 126 | "input/map_pos": [n_scene, n_pl, 2], (x,y), polyline starting node in global coordinate. 127 | else: # processed by pointnet encoder, c.f. vectornet. 128 | "input/agent_valid": [n_scene, n_agent, n_step_hist], bool 129 | "input/agent_attr": [n_scene, n_agent, n_step_hist, agent_attr_dim] 130 | "input/agent_pos": [n_scene, n_agent, 2] 131 | "input/map_valid": [n_scene, n_pl, n_pl_node], bool 132 | "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim] 133 | "input/"map_pos"": [n_scene, n_pl, 2] 134 | # traffic lights: stop point, cannot be aggregated, detections are not tracked, singular node polyline. 135 | if use_current_tl: 136 | "input/tl_valid": [n_scene, 1, n_tl], bool 137 | "input/tl_attr": [n_scene, 1, n_tl, tl_attr_dim] 138 | "input/tl_pos": [n_scene, 1, n_tl, 2] (x,y) 139 | else: 140 | "input/tl_valid": [n_scene, n_step_hist, n_tl], bool 141 | "input/tl_attr": [n_scene, n_step_hist, n_tl, tl_attr_dim] 142 | "input/tl_pos": [n_scene, n_step_hist, n_tl, 2] 143 | """ 144 | batch["input/agent_type"] = batch["sc/agent_type"] 145 | batch["input/agent_valid"] = batch["sc/agent_valid"] 146 | batch["input/tl_valid"] = batch["sc/tl_valid"] 147 | batch["input/map_valid"] = batch["sc/map_valid"] 148 | 149 | # ! randomly mask history agent/tl/map 150 | if self.training and (0 < self.dropout_p_history <= 1.0): 151 | prob_mask = torch.ones_like(batch["input/agent_valid"][..., :-1]) * (1 - self.dropout_p_history) 152 | batch["input/agent_valid"][..., :-1] &= torch.bernoulli(prob_mask).bool() 153 | prob_mask = torch.ones_like(batch["input/tl_valid"]) * (1 - self.dropout_p_history) 154 | batch["input/tl_valid"] &= torch.bernoulli(prob_mask).bool() 155 | prob_mask = torch.ones_like(batch["input/map_valid"]) * (1 - self.dropout_p_history) 156 | batch["input/map_valid"] &= torch.bernoulli(prob_mask).bool() 157 | 158 | # ! 
prepare "input/agent" 159 | batch["input/agent_pos"] = batch["ref/pos"].squeeze(2) 160 | if self.pl_aggr: # [n_scene, n_agent, agent_attr_dim] 161 | agent_invalid = ~batch["input/agent_valid"].unsqueeze(-1) # [n_scene, n_agent, n_step_hist, 1] 162 | agent_invalid_reduced = agent_invalid.all(-2) # [n_scene, n_agent, 1] 163 | batch["input/agent_attr"] = torch.cat( 164 | [ 165 | self.pose_pe_agent(batch["sc/agent_pos"], batch["sc/agent_yaw_bbox"]) 166 | .masked_fill(agent_invalid, 0) 167 | .flatten(-2, -1), 168 | batch["sc/agent_vel"].masked_fill(agent_invalid, 0).flatten(-2, -1), # n_step_hist*2 169 | batch["sc/agent_spd"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 170 | batch["sc/agent_yaw_rate"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 171 | batch["sc/agent_acc"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 172 | batch["sc/agent_size"].masked_fill(agent_invalid_reduced, 0), # 3 173 | batch["sc/agent_type"].masked_fill(agent_invalid_reduced, 0), # 3 174 | batch["input/agent_valid"], # n_step_hist 175 | ], 176 | dim=-1, 177 | ) 178 | batch["input/agent_valid"] = batch["input/agent_valid"].any(-1) # [n_scene, n_agent] 179 | else: # [n_scene, n_agent, n_step_hist, agent_attr_dim] 180 | batch["input/agent_attr"] = torch.cat( 181 | [ 182 | self.pose_pe_agent(batch["sc/agent_pos"], batch["sc/agent_yaw_bbox"]), 183 | batch["sc/agent_vel"], # vel xy, 2 184 | batch["sc/agent_spd"], # speed, 1 185 | batch["sc/agent_yaw_rate"], # yaw rate, 1 186 | batch["sc/agent_acc"], # acc, 1 187 | batch["sc/agent_size"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3 188 | batch["sc/agent_type"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3 189 | ], 190 | dim=-1, 191 | ) 192 | 193 | # ! prepare "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim] 194 | batch["input/map_pos"] = batch["sc/map_pos"][:, :, 0] 195 | if self.pl_aggr: # [n_scene, n_pl, map_attr_dim] 196 | map_invalid = ~batch["input/map_valid"].unsqueeze(-1) # [n_scene, n_pl, n_pl_node, 1] 197 | map_invalid_reduced = map_invalid.all(-2) # [n_scene, n_pl, 1] 198 | batch["input/map_attr"] = torch.cat( 199 | [ 200 | self.pose_pe_map(batch["sc/map_pos"], batch["sc/map_dir"]) 201 | .masked_fill(map_invalid, 0) 202 | .flatten(-2, -1), 203 | batch["sc/map_type"].masked_fill(map_invalid_reduced, 0), # n_pl_type 204 | batch["input/map_valid"], # n_pl_node 205 | ], 206 | dim=-1, 207 | ) 208 | batch["input/map_valid"] = batch["input/map_valid"].any(-1) # [n_scene, n_pl] 209 | else: # [n_scene, n_pl, n_pl_node, map_attr_dim] 210 | batch["input/map_attr"] = torch.cat( 211 | [ 212 | self.pose_pe_map(batch["sc/map_pos"], batch["sc/map_dir"]), # pl_dim 213 | batch["sc/map_type"].unsqueeze(-2).expand(-1, -1, self.n_pl_node, -1), # n_pl_type 214 | ], 215 | dim=-1, 216 | ) 217 | 218 | # ! prepare "input/tl_attr": [n_scene, n_step_hist/1, n_tl, tl_attr_dim] 219 | # [n_scene, n_step_hist, n_tl, 2] 220 | tl_pos = batch["sc/tl_pos"] 221 | tl_dir = batch["sc/tl_dir"] 222 | tl_state = batch["sc/tl_state"] 223 | if self.use_current_tl: 224 | tl_pos = tl_pos[:, [-1]] # [n_scene, 1, n_tl, 2] 225 | tl_dir = tl_dir[:, [-1]] # [n_scene, 1, n_tl, 2] 226 | tl_state = tl_state[:, [-1]] # [n_scene, 1, n_tl, 5] 227 | batch["input/tl_valid"] = batch["input/tl_valid"][:, [-1]] # [n_scene, 1, n_tl] 228 | batch["input/tl_attr"] = torch.cat([self.pose_pe_tl(tl_pos, tl_dir), tl_state], dim=-1) 229 | batch["input/tl_pos"] = tl_pos 230 | # ! 
add one-hot encoding for sequence (temporal, order of polyline nodes) 231 | if self.add_ohe: 232 | n_scene, n_agent, _ = batch["sc/agent_valid"].shape 233 | n_pl = batch["sc/map_valid"].shape[1] 234 | if not self.pl_aggr: # there is no need to add ohe if self.pl_aggr 235 | batch["input/agent_attr"] = torch.cat( 236 | [ 237 | batch["input/agent_attr"], 238 | self.history_step_ohe[None, None, :, :].expand(n_scene, n_agent, -1, -1), 239 | ], 240 | dim=-1, 241 | ) 242 | batch["input/map_attr"] = torch.cat( 243 | [batch["input/map_attr"], self.pl_node_ohe[None, None, :, :].expand(n_scene, n_pl, -1, -1),], 244 | dim=-1, 245 | ) 246 | 247 | if not self.use_current_tl: # there is no need to add ohe if use_current_tl 248 | n_tl = batch["input/tl_valid"].shape[-1] 249 | batch["input/tl_attr"] = torch.cat( 250 | [batch["input/tl_attr"], self.history_step_ohe[None, :, None, :].expand(n_scene, -1, n_tl, -1)], 251 | dim=-1, 252 | ) 253 | 254 | return batch 255 | -------------------------------------------------------------------------------- /src/data_modules/sc_relative.py: -------------------------------------------------------------------------------- 1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) 2 | from typing import Dict 3 | from omegaconf import DictConfig 4 | import torch 5 | from torch import nn, Tensor 6 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_dir2local, torch_rad2local 7 | from utils.pose_pe import PosePE 8 | 9 | 10 | class SceneCentricRelative(nn.Module): 11 | def __init__( 12 | self, 13 | time_step_current: int, 14 | data_size: DictConfig, 15 | dropout_p_history: float, 16 | use_current_tl: bool, 17 | add_ohe: bool, 18 | pl_aggr: bool, 19 | pose_pe: DictConfig, 20 | ) -> None: 21 | super().__init__() 22 | self.dropout_p_history = dropout_p_history # [0, 1], turn off if set to negative 23 | self.step_current = time_step_current 24 | self.n_step_hist = time_step_current + 1 25 | self.use_current_tl = use_current_tl 26 | self.add_ohe = add_ohe 27 | self.pl_aggr = pl_aggr 28 | self.n_pl_node = data_size["map/valid"][-1] 29 | 30 | self.pose_pe_agent = PosePE(pose_pe["agent"]) 31 | self.pose_pe_map = PosePE(pose_pe["map"]) 32 | 33 | tl_attr_dim = data_size["tl_stop/state"][-1] 34 | if self.pl_aggr: 35 | agent_attr_dim = ( 36 | self.pose_pe_agent.out_dim * self.n_step_hist 37 | + data_size["agent/spd"][-1] * self.n_step_hist # 1 38 | + data_size["agent/vel"][-1] * self.n_step_hist # 2 39 | + data_size["agent/yaw_rate"][-1] * self.n_step_hist # 1 40 | + data_size["agent/acc"][-1] * self.n_step_hist # 1 41 | + data_size["agent/size"][-1] # 3 42 | + data_size["agent/type"][-1] # 3 43 | + self.n_step_hist # valid 44 | ) 45 | map_attr_dim = self.pose_pe_map.out_dim * self.n_pl_node + data_size["map/type"][-1] + self.n_pl_node 46 | else: 47 | agent_attr_dim = ( 48 | self.pose_pe_agent.out_dim 49 | + data_size["agent/spd"][-1] # 1 50 | + data_size["agent/vel"][-1] # 2 51 | + data_size["agent/yaw_rate"][-1] # 1 52 | + data_size["agent/acc"][-1] # 1 53 | + data_size["agent/size"][-1] # 3 54 | + data_size["agent/type"][-1] # 3 55 | ) 56 | map_attr_dim = self.pose_pe_map.out_dim + data_size["map/type"][-1] 57 | 58 | if self.add_ohe: 59 | self.register_buffer("history_step_ohe", torch.eye(self.n_step_hist)) 60 | self.register_buffer("pl_node_ohe", torch.eye(self.n_pl_node)) 61 | if not self.pl_aggr: 62 | map_attr_dim += self.n_pl_node 63 | agent_attr_dim += self.n_step_hist 64 | if not self.use_current_tl: 65 | tl_attr_dim += 
self.n_step_hist 66 | 67 | self.model_kwargs = { 68 | "agent_attr_dim": agent_attr_dim, 69 | "map_attr_dim": map_attr_dim, 70 | "tl_attr_dim": tl_attr_dim, 71 | "n_step_hist": self.n_step_hist, 72 | "n_pl_node": self.n_pl_node, 73 | "use_current_tl": self.use_current_tl, 74 | "pl_aggr": self.pl_aggr, 75 | } 76 | 77 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: 78 | """ 79 | Args: scene-centric Dict, masked according to valid 80 | # (ref) reference information for transform back to global coordinate and submission to waymo 81 | "ref/pos": [n_scene, n_agent, 1, 2] 82 | "ref/yaw": [n_scene, n_agent, 1] 83 | "ref/rot": [n_scene, n_agent, 2, 2] 84 | "ref/role": [n_scene, n_agent, 3] 85 | "ref/type": [n_scene, n_agent, 3] 86 | # (gt) ground-truth agent future for training, not available for testing 87 | "gt/valid": [n_scene, n_agent, n_step_future], bool 88 | "gt/pos": [n_scene, n_agent, n_step_future, 2] 89 | "gt/spd": [n_scene, n_agent, n_step_future, 1] 90 | "gt/vel": [n_scene, n_agent, n_step_future, 2] 91 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1] 92 | "gt/cmd": [n_scene, n_agent, 8] 93 | # (sc) scene-centric agents states 94 | "sc/agent_valid": [n_scene, n_agent, n_step_hist] 95 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2] 96 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2] 97 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1] 98 | "sc/agent_acc": [n_scene, n_agent, n_step_hist, 1] 99 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1] 100 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1] 101 | # agent attributes 102 | "sc/agent_type": [n_scene, n_agent, 3] 103 | "sc/agent_role": [n_scene, n_agent, 3] 104 | "sc/agent_size": [n_scene, n_agent, 3] 105 | # map polylines 106 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool 107 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot 108 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32 109 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32 110 | # traffic lights 111 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool 112 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot 113 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y 114 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y 115 | 116 | Returns: add following keys to batch Dict 117 | # agent type: no need to be aggregated. 118 | "input/agent_type": [n_scene, n_agent, 3] 119 | # agent history 120 | if self.pl_aggr: # processed by mlp encoder, c.f. our method. 121 | "input/agent_valid": [n_scene, n_agent], bool 122 | "input/agent_attr": [n_scene, n_agent, agent_attr_dim] in local coordinate wrt. the current step. 123 | "input/agent_pose": [n_scene, n_agent, 3], (x,y,yaw), pos of the current step. 124 | "input/map_valid": [n_scene, n_pl], bool 125 | "input/map_attr": [n_scene, n_pl, map_attr_dim], in local coordinate wrt. the first node. 126 | "input/map_pose": [n_scene, n_pl, 3], (x,y,yaw), polyline starting node in global coordinate. 127 | else: # processed by pointnet encoder, c.f. vectornet. 128 | "input/agent_valid": [n_scene, n_agent, n_step_hist], bool 129 | "input/agent_attr": [n_scene, n_agent, n_step_hist, agent_attr_dim] 130 | "input/agent_pose": [n_scene, n_agent, 3] 131 | "input/map_valid": [n_scene, n_pl, n_pl_node], bool 132 | "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim] 133 | "input/"map_pose"": [n_scene, n_pl, 3] 134 | # traffic lights: stop point, cannot be aggregated, detections are not tracked, singular node polyline. 
135 | if use_current_tl: 136 | "input/tl_valid": [n_scene, 1, n_tl], bool 137 | "input/tl_attr": [n_scene, 1, n_tl, tl_attr_dim] no positional info 138 | "input/tl_pose": [n_scene, 1, n_tl, 3] (x,y,yaw) 139 | else: 140 | "input/tl_valid": [n_scene, n_step_hist, n_tl], bool 141 | "input/tl_attr": [n_scene, n_step_hist, n_tl, tl_attr_dim] 142 | "input/tl_pose": [n_scene, n_step_hist, n_tl, 3] 143 | """ 144 | batch["input/agent_type"] = batch["sc/agent_type"] 145 | batch["input/agent_valid"] = batch["sc/agent_valid"] 146 | batch["input/tl_valid"] = batch["sc/tl_valid"] 147 | batch["input/map_valid"] = batch["sc/map_valid"] 148 | 149 | # ! randomly mask history agent/tl/map 150 | if self.training and (0 < self.dropout_p_history <= 1.0): 151 | prob_mask = torch.ones_like(batch["input/agent_valid"][..., :-1]) * (1 - self.dropout_p_history) 152 | batch["input/agent_valid"][..., :-1] &= torch.bernoulli(prob_mask).bool() 153 | prob_mask = torch.ones_like(batch["input/tl_valid"]) * (1 - self.dropout_p_history) 154 | batch["input/tl_valid"] &= torch.bernoulli(prob_mask).bool() 155 | prob_mask = torch.ones_like(batch["input/map_valid"]) * (1 - self.dropout_p_history) 156 | batch["input/map_valid"] &= torch.bernoulli(prob_mask).bool() 157 | 158 | # ! prepare "input/agent" 159 | # [n_scene, n_agent, 3] 160 | batch["input/agent_pose"] = torch.cat([batch["ref/pos"].squeeze(2), batch["ref/yaw"]], dim=-1) 161 | agent_pos_local = torch_pos2local(batch["sc/agent_pos"], batch["ref/pos"], batch["ref/rot"]) 162 | agent_vel_local = torch_dir2local(batch["sc/agent_vel"], batch["ref/rot"]) 163 | agent_yaw_local = torch_rad2local(batch["sc/agent_yaw_bbox"], batch["ref/yaw"], cast=False) 164 | if self.pl_aggr: # [n_scene, n_agent, agent_attr_dim] 165 | agent_invalid = ~batch["input/agent_valid"].unsqueeze(-1) 166 | agent_invalid_reduced = agent_invalid.all(-2) # [n_scene, n_agent, 1] 167 | batch["input/agent_attr"] = torch.cat( 168 | [ 169 | self.pose_pe_agent(agent_pos_local, agent_yaw_local).masked_fill(agent_invalid, 0).flatten(-2, -1), 170 | agent_vel_local.masked_fill(agent_invalid, 0).flatten(-2, -1), # n_step_hist*2 171 | batch["sc/agent_spd"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 172 | batch["sc/agent_yaw_rate"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 173 | batch["sc/agent_acc"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist 174 | batch["sc/agent_size"].masked_fill(agent_invalid_reduced, 0), # 3 175 | batch["sc/agent_type"].masked_fill(agent_invalid_reduced, 0), # 3 176 | batch["input/agent_valid"], # n_step_hist 177 | ], 178 | dim=-1, 179 | ) 180 | batch["input/agent_valid"] = batch["input/agent_valid"].any(-1) # [n_scene, n_agent] 181 | else: # [n_scene, n_agent, n_step_hist, agent_attr_dim] 182 | batch["input/agent_attr"] = torch.cat( 183 | [ 184 | self.pose_pe_agent(agent_pos_local, agent_yaw_local), 185 | agent_vel_local, # vel xy, 2 186 | batch["sc/agent_spd"], # speed, 1 187 | batch["sc/agent_yaw_rate"], # yaw rate, 1 188 | batch["sc/agent_acc"], # acc, 1 189 | batch["sc/agent_size"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3 190 | batch["sc/agent_type"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3 191 | ], 192 | dim=-1, 193 | ) 194 | 195 | # ! 
prepare "input/map" 196 | # [n_scene, n_pl] 197 | pl_yaw = torch.atan2(batch["sc/map_dir"][:, :, 0, 1], batch["sc/map_dir"][:, :, 0, 0]) 198 | # [n_scene, n_pl, 3] 199 | batch["input/map_pose"] = torch.cat([batch["sc/map_pos"][:, :, 0], pl_yaw.unsqueeze(-1)], dim=-1) 200 | # batch["sc/map_pos"], batch["sc/map_dir"]: [n_scene, n_pl, n_pl_node, 2] 201 | pl_rot = torch_rad2rot(pl_yaw) # [n_scene, n_pl, 2, 2] 202 | map_pos_local = torch_pos2local(batch["sc/map_pos"], batch["sc/map_pos"][:, :, [0]], pl_rot) 203 | map_dir_local = torch_dir2local(batch["sc/map_dir"], pl_rot) 204 | if self.pl_aggr: # [n_scene, n_pl, map_attr_dim] 205 | map_invalid = ~batch["input/map_valid"].unsqueeze(-1) # [n_scene, n_pl, n_pl_node, 1] 206 | map_invalid_reduced = map_invalid.all(-2) # [n_scene, n_pl, 1] 207 | batch["input/map_attr"] = torch.cat( 208 | [ 209 | self.pose_pe_map(map_pos_local, map_dir_local).masked_fill(map_invalid, 0).flatten(-2, -1), 210 | batch["sc/map_type"].masked_fill(map_invalid_reduced, 0), # n_pl_type 211 | batch["input/map_valid"], # n_pl_node 212 | ], 213 | dim=-1, 214 | ) 215 | batch["input/map_valid"] = batch["input/map_valid"].any(-1) # [n_scene, n_pl] 216 | else: # [n_scene, n_pl, n_pl_node, map_attr_dim] 217 | batch["input/map_attr"] = torch.cat( 218 | [ 219 | self.pose_pe_map(map_pos_local, map_dir_local), # pl_dim 220 | batch["sc/map_type"].unsqueeze(-2).expand(-1, -1, self.n_pl_node, -1), # n_pl_type 221 | ], 222 | dim=-1, 223 | ) 224 | 225 | # ! prepare "input/tl_attr": [n_scene, n_step_hist/1, n_tl, tl_attr_dim] 226 | # [n_scene, n_step_hist, n_tl, 2] 227 | tl_pos = batch["sc/tl_pos"] 228 | tl_dir = batch["sc/tl_dir"] 229 | tl_state = batch["sc/tl_state"] 230 | if self.use_current_tl: 231 | tl_pos = tl_pos[:, [-1]] # [n_scene, 1, n_tl, 2] 232 | tl_dir = tl_dir[:, [-1]] # [n_scene, 1, n_tl, 2] 233 | tl_state = tl_state[:, [-1]] # [n_scene, 1, n_tl, 5] 234 | batch["input/tl_valid"] = batch["input/tl_valid"][:, [-1]] # [n_scene, 1, n_tl] 235 | tl_yaw = torch.atan2(tl_dir[..., 1], tl_dir[..., 0]) # [n_scene, n_step_hist/1, n_tl] 236 | batch["input/tl_pose"] = torch.cat([tl_pos, tl_yaw.unsqueeze(-1)], dim=-1) 237 | batch["input/tl_attr"] = tl_state.type_as(batch["sc/map_pos"]) # to float 238 | 239 | # ! add one-hot encoding for sequence (temporal, order of polyline nodes) 240 | if self.add_ohe: 241 | n_scene, n_agent, _ = batch["sc/agent_valid"].shape 242 | n_pl = batch["sc/map_valid"].shape[1] 243 | if not self.pl_aggr: # there is no need to add ohe if self.pl_aggr 244 | batch["input/agent_attr"] = torch.cat( 245 | [ 246 | batch["input/agent_attr"], 247 | self.history_step_ohe[None, None, :, :].expand(n_scene, n_agent, -1, -1), 248 | ], 249 | dim=-1, 250 | ) 251 | batch["input/map_attr"] = torch.cat( 252 | [batch["input/map_attr"], self.pl_node_ohe[None, None, :, :].expand(n_scene, n_pl, -1, -1),], 253 | dim=-1, 254 | ) 255 | 256 | if not self.use_current_tl: # there is no need to add ohe if use_current_tl 257 | n_tl = batch["input/tl_valid"].shape[-1] 258 | batch["input/tl_attr"] = torch.cat( 259 | [batch["input/tl_attr"], self.history_step_ohe[None, :, None, :].expand(n_scene, -1, n_tl, -1)], 260 | dim=-1, 261 | ) 262 | 263 | return batch 264 | --------------------------------------------------------------------------------