├── src
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   └── nll.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── pos_emb.py
│   │       ├── mlp.py
│   │       ├── point_net.py
│   │       ├── multi_modal.py
│   │       ├── rpe.py
│   │       ├── decoder_ensemble.py
│   │       ├── attention.py
│   │       └── transformer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── pose_pe.py
│   │   ├── transform_utils.py
│   │   └── submission.py
│   ├── callbacks
│   │   ├── __init__.py
│   │   └── wandb_callbacks.py
│   ├── data_modules
│   │   ├── __init__.py
│   │   ├── post_processing.py
│   │   ├── scene_centric.py
│   │   ├── data_h5_womd.py
│   │   ├── data_h5_av2.py
│   │   ├── sc_global.py
│   │   └── sc_relative.py
│   ├── pl_modules
│   │   └── __init__.py
│   ├── run.py
│   └── pack_h5_av2.py
├── docs
│   ├── hptr_banner.png
│   ├── hptr_teaser.png
│   ├── hptr_efficiency.png
│   └── ablation_models.md
├── configs
│   ├── resume
│   │   ├── empty.yaml
│   │   ├── sub_av2.yaml
│   │   └── sub_womd.yaml
│   ├── loggers
│   │   └── wandb.yaml
│   ├── datamodule
│   │   ├── h5_av2.yaml
│   │   └── h5_womd.yaml
│   ├── run.yaml
│   ├── trainer
│   │   ├── av2.yaml
│   │   └── womd.yaml
│   ├── callbacks
│   │   └── wandb.yaml
│   └── model
│       ├── acg_womd.yaml
│       ├── scg_womd.yaml
│       ├── scr_av2.yaml
│       └── scr_womd.yaml
├── bash
│   ├── pack_h5.sh
│   ├── train.sh
│   └── submission.sh
├── .gitignore
├── env_av2.yml
├── README.md
└── environment.yml
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/pl_modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/hptr_banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_banner.png
--------------------------------------------------------------------------------
/docs/hptr_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_teaser.png
--------------------------------------------------------------------------------
/configs/resume/empty.yaml:
--------------------------------------------------------------------------------
1 | checkpoint: null
2 | resume_trainer: True
3 | model_overrides: {}
4 |
--------------------------------------------------------------------------------
/docs/hptr_efficiency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhejz/HPTR/HEAD/docs/hptr_efficiency.png
--------------------------------------------------------------------------------
/configs/loggers/wandb.yaml:
--------------------------------------------------------------------------------
1 | wandb:
2 | _target_: pytorch_lightning.loggers.WandbLogger
3 | project: debug
4 | group: null
5 | name: my_name
6 | notes: my_notes
7 | tags: []
8 | job_type: train
9 | entity: YOUR_ENTITY
10 |
--------------------------------------------------------------------------------
/configs/datamodule/h5_av2.yaml:
--------------------------------------------------------------------------------
1 | _target_: data_modules.data_h5_av2.DataH5av2
2 |
3 | data_dir: /cluster/scratch/zhejzhan/h5_av2_hptr
4 | filename_train: training
5 | filename_val: validation
6 | filename_test: testing
7 | n_agent: 64
8 |
9 | batch_size: 3
10 | num_workers: 4
11 |
--------------------------------------------------------------------------------
/configs/datamodule/h5_womd.yaml:
--------------------------------------------------------------------------------
1 | _target_: data_modules.data_h5_womd.DataH5womd
2 |
3 | data_dir: /cluster/scratch/zhejzhan/h5_womd_hptr
4 | filename_train: training
5 | filename_val: validation
6 | filename_test: testing
7 | n_agent: 64
8 |
9 | batch_size: 3
10 | num_workers: 4
11 |
--------------------------------------------------------------------------------
/configs/run.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - trainer: womd
6 | - model: scr_womd
7 | - datamodule: h5_womd
8 | - callbacks: wandb
9 | - loggers: wandb
10 | - resume: empty
11 |
12 | hydra:
13 | run:
14 | dir: logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
15 |
16 | work_dir: ${hydra:runtime.cwd}
17 | seed: 2023
18 | action: fit # fit, validate, test
19 |
--------------------------------------------------------------------------------
/configs/trainer/av2.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | limit_train_batches: 0.5
4 | limit_val_batches: 0.25
5 | limit_test_batches: 1.0
6 | num_sanity_val_steps: 0
7 | max_epochs: null
8 | min_epochs: null
9 | # max_time: 00:120:00:00
10 |
11 | log_every_n_steps: 200
12 |
13 | gradient_clip_val: 0.5
14 | track_grad_norm: 2
15 | gpus: -1
16 | precision: 32
17 | benchmark: False
18 | deterministic: False
19 | sync_batchnorm: False
20 | detect_anomaly: False
21 | accumulate_grad_batches: 1
22 | resume_from_checkpoint: null
23 | enable_progress_bar: True
24 |
--------------------------------------------------------------------------------
/configs/trainer/womd.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | limit_train_batches: 0.25
4 | limit_val_batches: 0.25
5 | limit_test_batches: 1.0
6 | num_sanity_val_steps: 0
7 | max_epochs: null
8 | min_epochs: null
9 | # max_time: 00:120:00:00
10 |
11 | log_every_n_steps: 200
12 |
13 | gradient_clip_val: 0.5
14 | track_grad_norm: 2
15 | gpus: -1
16 | precision: 32
17 | benchmark: False
18 | deterministic: False
19 | sync_batchnorm: False
20 | detect_anomaly: False
21 | accumulate_grad_batches: 1
22 | resume_from_checkpoint: null
23 | enable_progress_bar: True
24 |
--------------------------------------------------------------------------------
/configs/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: callbacks.wandb_callbacks.ModelCheckpointWB
3 | dirpath: "checkpoints/"
4 | filename: "{epoch:02d}"
5 | monitor: "val/loss" # name of the logged metric which determines when model is improving
6 | save_top_k: 1 # save k best models (determined by above metric)
7 | save_last: True # additionally always save model from last epoch
8 | mode: "min" # can be "max" or "min"
9 | verbose: True
10 | save_only_best: True # if True, only save best model according to "monitor" metric
11 |
12 | lr_monitor:
13 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
14 |
15 | stochastic_weight_avg:
16 | _target_: pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging
--------------------------------------------------------------------------------
/configs/resume/sub_av2.yaml:
--------------------------------------------------------------------------------
1 | checkpoint: null
2 | resume_trainer: True
3 | model_overrides:
4 | n_video_batch: 0
5 | post_processing:
6 | to_dict:
7 | _target_: data_modules.post_processing.ToDict
8 | predictions: [pos, cov3, spd, vel, yaw_bbox]
9 | get_cov_mat:
10 | _target_: data_modules.post_processing.GetCovMat
11 | rho_clamp: 5.0
12 | std_min: -1.609
13 | std_max: 5.0
14 | waymo:
15 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
16 | k_pred: 6
17 | use_ade: True
18 | score_temperature: 0.5
19 | mpa_nms_thresh: [] # veh, ped, cyc
20 | gt_in_local: True
21 | sub_womd:
22 | _target_: utils.submission.SubWOMD
23 | activate: False
24 | method_name: METHOD_NAME
25 | authors: [NAME1, NAME2]
26 | affiliation: AFFILIATION
27 | description: scr_womd
28 | method_link: METHOD_LINK
29 | account_name: ACCOUNT_NAME
30 | sub_av2:
31 | _target_: utils.submission.SubAV2
32 | activate: True
33 |
--------------------------------------------------------------------------------
/configs/resume/sub_womd.yaml:
--------------------------------------------------------------------------------
1 | checkpoint: null
2 | resume_trainer: True
3 | model_overrides:
4 | n_video_batch: 0
5 | post_processing:
6 | to_dict:
7 | _target_: data_modules.post_processing.ToDict
8 | predictions: [pos, cov3, spd, vel, yaw_bbox]
9 | get_cov_mat:
10 | _target_: data_modules.post_processing.GetCovMat
11 | rho_clamp: 5.0
12 | std_min: -1.609
13 | std_max: 5.0
14 | waymo:
15 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
16 | k_pred: 6
17 | use_ade: True
18 | score_temperature: -1
19 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc
20 | gt_in_local: True
21 | sub_womd:
22 | _target_: utils.submission.SubWOMD
23 | activate: True
24 | method_name: HPTR
25 | authors: [Zhejun Zhang, Alexander Liniger, Christos Sakaridis, Fisher Yu, Luc Van Gool]
26 | affiliation: "CVL, ETH Zurich"
27 | description: "Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding. NeurIPS 2023. https://github.com/zhejz/HPTR"
28 | method_link: "https://github.com/zhejz/HPTR"
29 | account_name: "YOUR_ACCOUNT_NAME"
30 | sub_av2:
31 | _target_: utils.submission.SubAV2
32 | activate: False
33 |
--------------------------------------------------------------------------------
/bash/pack_h5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=./logs/%j.out
3 | #SBATCH --error=./logs/%j.out
4 | #SBATCH --time=120:00:00
5 | #SBATCH -n 1
6 | #SBATCH --cpus-per-task=12
7 | #SBATCH --mem-per-cpu=5000
8 | #SBATCH --tmp=200000
9 | #SBATCH --open-mode=truncate
10 |
11 | trap "echo sigterm received, exiting!" SIGTERM
12 |
13 | run () {
14 | python -u src/pack_h5_womd.py --dataset=training \
15 | --out-dir=/cluster/scratch/zhejzhan/h5_womd_hptr \
16 | --data-dir=/cluster/scratch/zhejzhan/womd_scenario_v_1_2_0
17 | }
18 |
19 | # ! for validation and testing
20 | # python -u src/pack_h5_womd.py --dataset=validation --rand-pos=-1 --rand-yaw=-1 \
21 | # python -u src/pack_h5_womd.py --dataset=testing --rand-pos=-1 --rand-yaw=-1 \
22 |
23 | # ! for packing av2
24 | # conda activate hptr_av2
25 | # run () {
26 | # python -u src/pack_h5_av2.py --dataset=training \
27 | # --out-dir=/cluster/scratch/zhejzhan/h5_av2_hptr \
28 | # --data-dir=/cluster/scratch/zhejzhan/av2_motion
29 | # }
30 |
31 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh
32 | conda activate hptr # for av2: conda activate hptr_av2
33 |
34 | echo Running on host: `hostname`
35 | echo In directory: `pwd`
36 | echo Starting on: `date`
37 |
38 |
39 | type run
40 | echo START: `date`
41 | run &
42 | wait
43 | echo DONE: `date`
44 |
45 | mkdir -p ./logs/slurm
46 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out
47 |
48 | echo finished at: `date`
49 | exit 0;
--------------------------------------------------------------------------------
/bash/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=./logs/%j.out
3 | #SBATCH --error=./logs/%j.out
4 | #SBATCH --time=120:00:00
5 | #SBATCH -n 1
6 | #SBATCH --cpus-per-task=12
7 | #SBATCH --mem-per-cpu=5000
8 | #SBATCH --tmp=200000
9 | #SBATCH --gpus=rtx_2080_ti:4
10 | #SBATCH --open-mode=truncate
11 |
12 | trap "echo sigterm received, exiting!" SIGTERM
13 |
14 | DATASET_DIR="h5_womd_hptr"
15 | run () {
16 | python -u src/run.py \
17 | trainer=womd \
18 | model=scr_womd \
19 | datamodule=h5_womd \
20 | loggers.wandb.name="hptr_womd" \
21 | loggers.wandb.project="hptr_train" \
22 | loggers.wandb.entity="YOUR_ENTITY" \
23 | datamodule.data_dir=${TMPDIR}/datasets \
24 | hydra.run.dir='/cluster/scratch/zhejzhan/logs/${now:%Y-%m-%d}/${now:%H-%M-%S}'
25 | }
26 |
27 | # ! For AV2 dataset.
28 | # DATASET_DIR="h5_av2_hptr"
29 | # trainer=av2 \
30 | # model=scr_av2 \
31 | # datamodule=h5_av2 \
32 |
33 | # ! To resume training.
34 | # resume.checkpoint=YOUR_WANDB_RUN_NAME:latest \
35 |
36 |
37 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh
38 | conda activate hptr
39 |
40 | echo Running on host: `hostname`
41 | echo In directory: `pwd`
42 | echo Starting on: `date`
43 |
44 | echo START copying data: `date`
45 | mkdir $TMPDIR/datasets
46 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/training.h5 $TMPDIR/datasets/
47 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/validation.h5 $TMPDIR/datasets/
48 | echo DONE copying: `date`
49 |
50 | type run
51 | echo START: `date`
52 | run &
53 | wait
54 | echo DONE: `date`
55 |
56 | mkdir -p ./logs/slurm
57 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out
58 |
59 | echo finished at: `date`
60 | exit 0;
61 |
--------------------------------------------------------------------------------
/bash/submission.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --output=./logs/%j.out
3 | #SBATCH --error=./logs/%j.out
4 | #SBATCH --time=4:00:00
5 | #SBATCH -n 1
6 | #SBATCH --cpus-per-task=8
7 | #SBATCH --mem-per-cpu=5000
8 | #SBATCH --tmp=100000
9 | #SBATCH --gpus=rtx_2080_ti:1
10 | #SBATCH --open-mode=truncate
11 |
12 | trap "echo sigterm received, exiting!" SIGTERM
13 |
14 | DATASET_DIR="h5_womd_hptr"
15 | run () {
16 | python -u src/run.py \
17 | trainer=womd \
18 | model=scr_womd \
19 | datamodule=h5_womd \
20 | resume=sub_womd \
21 | action=validate \
22 | trainer.limit_val_batches=1.0 \
23 | resume.checkpoint=YOUR_WANDB_RUN_NAME:latest \
24 | loggers.wandb.name="hptr_womd_val" \
25 | loggers.wandb.project="hptr_sub" \
26 | loggers.wandb.entity="YOUR_ENTITY" \
27 | datamodule.data_dir=${TMPDIR}/datasets \
28 | hydra.run.dir='/cluster/scratch/zhejzhan/logs/${now:%Y-%m-%d}/${now:%H-%M-%S}'
29 | }
30 |
31 | # ! For AV2 dataset.
32 | # DATASET_DIR="h5_av2_hptr"
33 | # trainer=av2 \
34 | # model=scr_av2 \
35 | # datamodule=h5_av2 \
36 | # resume=sub_av2 \
37 |
38 | # ! For testing.
39 | # action=test \
40 |
41 |
42 | source /cluster/project/cvl/zhejzhan/apps/miniconda3/etc/profile.d/conda.sh
43 | conda activate hptr
44 |
45 | echo Running on host: `hostname`
46 | echo In directory: `pwd`
47 | echo Starting on: `date`
48 |
49 | echo START copying data: `date`
50 | mkdir $TMPDIR/datasets
51 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/validation.h5 $TMPDIR/datasets/
52 | cp /cluster/scratch/zhejzhan/$DATASET_DIR/testing.h5 $TMPDIR/datasets/
53 | echo DONE copying: `date`
54 |
55 | type run
56 | echo START: `date`
57 | run &
58 | wait
59 | echo DONE: `date`
60 |
61 | mkdir -p ./logs/slurm
62 | mv ./logs/$SLURM_JOB_ID.out ./logs/slurm/$SLURM_JOB_ID.out
63 |
64 | echo finished at: `date`
65 | exit 0;
66 |
--------------------------------------------------------------------------------
/src/models/modules/pos_emb.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import torch
3 | from torch import Tensor, nn
4 |
5 |
6 | class PositionalEmbedding(nn.Module):
7 | def __init__(self, dim: int, theta: float = 10000):
8 | super().__init__()
9 | assert dim % 2 == 0
10 | self.dim = dim
11 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
12 | # [dim]
13 | freqs = freqs.repeat_interleave(2, 0)
14 | self.register_buffer("freqs", freqs)
15 |
16 | def forward(self, x: Tensor):
17 | """
18 | Args:
19 | x: [...]
20 | Returns:
21 | pos_enc: [..., dim]
22 | """
23 | # [..., dim]
24 | pos_enc = x.unsqueeze(-1) * self.freqs.view([1] * x.dim() + [-1])
25 | pos_enc = torch.cat([torch.cos(pos_enc[..., ::2]), torch.sin(pos_enc[..., 1::2])], dim=-1)
26 | return pos_enc
27 |
28 |
29 | class PositionalEmbeddingRad(nn.Module):
30 | def __init__(self, dim: int):
31 | """
32 | if dim=2, then just [cos(theta), sin(theta)]
33 | """
34 | super().__init__()
35 | assert dim % 2 == 0
36 | self.dim = dim
37 | # [dim]: [1,1,2,2,4,4,8,8]
38 | # freqs = 2 ** (torch.arange(0, dim // 2).float())
39 | # [dim]: [1,1,2,2,3,3,4,4]
40 | freqs = torch.arange(0, dim // 2) + 1.0
41 | freqs = freqs.repeat_interleave(2, 0)
42 | self.register_buffer("freqs", freqs)
43 |
44 | def forward(self, x: Tensor):
45 | """
46 | Args:
47 | x: [...], in rad
48 | Returns:
49 | pos_enc: [..., dim]
50 | """
51 | # [..., dim]
52 | pos_enc = x.unsqueeze(-1) * self.freqs.view([1] * x.dim() + [-1])
53 | pos_enc = torch.cat([torch.cos(pos_enc[..., ::2]), torch.sin(pos_enc[..., 1::2])], dim=-1)
54 | return pos_enc
55 |
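56 |
57 | if __name__ == "__main__":
58 |     # Quick shape check (illustrative only): embedding a [n_batch, n_token] tensor of scalars
59 |     # appends a trailing axis of size dim, half cosine and half sine features.
60 |     print(PositionalEmbedding(dim=8)(torch.randn(2, 4)).shape)  # torch.Size([2, 4, 8])
61 |     print(PositionalEmbeddingRad(dim=8)(torch.randn(2, 4)).shape)  # torch.Size([2, 4, 8])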
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # my ignore
132 | .vscode/settings.json
133 | logs/
134 | runs/
135 | outputs/
136 | wandb/
137 | data/
138 | *.pyc
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import hydra
3 | import torch
4 | from omegaconf import DictConfig
5 | from typing import List
6 | from pytorch_lightning import seed_everything, LightningDataModule, LightningModule, Trainer, Callback
7 | from pytorch_lightning.loggers import LightningLoggerBase
8 | import os
9 |
10 |
11 | def download_checkpoint(loggers, wb_ckpt) -> None:
12 | if os.environ.get("LOCAL_RANK", 0) == 0:
13 | artifact = loggers[0].experiment.use_artifact(wb_ckpt, type="model")
14 | artifact_dir = artifact.download("ckpt")
15 |
16 |
17 | @hydra.main(config_path="../configs/", config_name="run.yaml")
18 | def main(config: DictConfig) -> None:
19 |
20 | seed_everything(config.seed, workers=True)
21 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
22 |
23 | callbacks: List[Callback] = []
24 | if "callbacks" in config:
25 | for _, cb_conf in config.callbacks.items():
26 | callbacks.append(hydra.utils.instantiate(cb_conf))
27 |
28 | loggers: List[LightningLoggerBase] = []
29 | if "loggers" in config:
30 | for _, lg_conf in config.loggers.items():
31 | loggers.append(hydra.utils.instantiate(lg_conf))
32 |
33 | if config.resume.checkpoint is None:
34 | model: LightningModule = hydra.utils.instantiate(
35 | config.model, data_size=datamodule.tensor_size_train, _recursive_=False
36 | )
37 | else:
38 | download_checkpoint(loggers, config.resume.checkpoint)
39 | ckpt_path = "ckpt/model.ckpt"
40 | modelClass = hydra.utils.get_class(config.model._target_)
41 |
42 | model = modelClass.load_from_checkpoint(
43 | ckpt_path, wb_artifact=config.resume.checkpoint, **config.resume.model_overrides
44 | )
45 |
46 | if config.resume.resume_trainer and config.action == "fit":
47 | config.trainer.resume_from_checkpoint = ckpt_path
48 |
49 | # from pytorch_lightning.plugins import DDPPlugin
50 | # strategy = DDPPlugin(gradient_as_bucket_view=True)
51 | strategy = None
52 | if torch.cuda.device_count() > 1:
53 | strategy = "ddp"
54 | trainer: Trainer = hydra.utils.instantiate(
55 | config.trainer, strategy=strategy, callbacks=callbacks, logger=loggers, _convert_="partial"
56 | )
57 | if config.action == "fit":
58 | trainer.fit(model=model, datamodule=datamodule)
59 | elif config.action == "validate":
60 | trainer.validate(model=model, datamodule=datamodule)
61 | elif config.action == "test":
62 | trainer.test(model=model, datamodule=datamodule)
63 | else:
64 | raise NotImplementedError
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/docs/ablation_models.md:
--------------------------------------------------------------------------------
1 | # Configurations of Ablation Models
2 |
3 | Use the following configurations and adapt [bash/train.sh](../bash/train.sh) to train the ablation models. A complete example command is given at the end of this page.
4 |
5 | ## Input Representation
6 | - Our `HPTR` for the Waymo dataset. The model has 15.2M parameters.
7 | ```
8 | model=scr_womd \
9 | ```
10 | - Our `HPTR` for the AV2 dataset.
11 | ```
12 | trainer=av2 \
13 | model=scr_av2 \
14 | datamodule=h5_av2 \
15 | ```
16 | - Agent-centric baseline [Wayformer](https://arxiv.org/abs/2207.05844), i.e. `WF baseline`.
17 | ```
18 | model=acg_womd \
19 | ```
20 | - Scene-centric baseline [SceneTransformer](https://arxiv.org/abs/2106.08417), i.e. `HPTR SC`.
21 | ```
22 | model=scg_womd \
23 | ```
24 |
25 | ## Hierarchical Architecture
26 |
27 | - `HPTR diag+full` with 15.4M parameters. It needs an RTX 3090 for training.
28 | ```
29 | model.model.intra_class_encoder.n_layer_tf_map=6 \
30 | model.model.intra_class_encoder.n_layer_tf_tl=2 \
31 | model.model.intra_class_encoder.n_layer_tf_agent=2 \
32 | model.model.decoder.tf_n_layer=2 \
33 | model.model.decoder.k_reinforce_tl=-1 \
34 | model.model.decoder.k_reinforce_agent=-1 \
35 | model.model.decoder.k_reinforce_all=1 \
36 | ```
37 | - `HPTR diag` with 15.4M parameters.
38 | ```
39 | model.model.intra_class_encoder.n_layer_tf_map=6 \
40 | model.model.intra_class_encoder.n_layer_tf_tl=3 \
41 | model.model.intra_class_encoder.n_layer_tf_agent=3 \
42 | model.model.decoder.tf_n_layer=2 \
43 | model.model.decoder.k_reinforce_tl=-1 \
44 | model.model.decoder.k_reinforce_agent=-1 \
45 | ```
46 | - `HPTR full` with 15.2M parameters. It needs an RTX 3090 for training.
47 | ```
48 | model.model.intra_class_encoder.n_layer_tf_map=-1 \
49 | model.model.decoder.tf_n_layer=6 \
50 | model.model.decoder.k_reinforce_tl=-1 \
51 | model.model.decoder.k_reinforce_agent=-1 \
52 | model.model.decoder.k_reinforce_all=1 \
53 | ```
54 |
55 | ## Others
56 | - Different polyline embedding.
57 | ```
58 | model.pre_processing.relative.pose_pe.agent=xy_dir \
59 | model.pre_processing.relative.pose_pe.map=xy_dir \
60 | ```
61 | - Attention without bias.
62 | ```
63 | model.model.tf_cfg.bias=False \
64 | ```
65 | - Different RPE mode.
66 | ```
67 | model.model.rpe_mode=xy_dir \
68 | model.model.rpe_mode=pe_xy_dir \
69 | ```
70 | - Apply RPE to query. It needs an RTX 3090 for training.
71 | ```
72 | model.model.tf_cfg.apply_q_rpe=True \
73 | ```
74 | - Without anchor reinforce (17.5M parameters).
75 | ```
76 | model.model.decoder.tf_n_layer=3 \
77 | model.model.decoder.k_reinforce_agent=8 \
78 | model.model.decoder.k_reinforce_anchor=-1 \
79 | ```
80 | - Without anchor reinforce, larger model (23.3M parameters).
81 | ```
82 | model.model.n_tgt_knn=50 \
83 | model.model.intra_class_encoder.n_layer_tf_map=6 \
84 | model.model.decoder.tf_n_layer=4 \
85 | model.model.decoder.k_reinforce_agent=8 \
86 | model.model.decoder.k_reinforce_anchor=-1 \
87 | ```
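88 |
89 | ## Example
90 |
91 | As a concrete sketch, training `HPTR diag` on WOMD amounts to appending its overrides to the `run ()` function of [bash/train.sh](../bash/train.sh). The wandb run name below is just a placeholder; keep the remaining logger, data and hydra arguments of your adapted script unchanged.
92 | ```
93 | run () {
94 | python -u src/run.py \
95 | trainer=womd \
96 | model=scr_womd \
97 | datamodule=h5_womd \
98 | model.model.intra_class_encoder.n_layer_tf_map=6 \
99 | model.model.intra_class_encoder.n_layer_tf_tl=3 \
100 | model.model.intra_class_encoder.n_layer_tf_agent=3 \
101 | model.model.decoder.tf_n_layer=2 \
102 | model.model.decoder.k_reinforce_tl=-1 \
103 | model.model.decoder.k_reinforce_agent=-1 \
104 | loggers.wandb.name="hptr_diag_womd" \
105 | datamodule.data_dir=${TMPDIR}/datasets
106 | }
107 | ```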
--------------------------------------------------------------------------------
/src/models/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import List, Tuple, Union, Optional
3 | from torch import Tensor, nn
4 |
5 |
6 | def _get_activation(activation: str, inplace: bool) -> nn.Module:
7 | if activation == "relu":
8 | return nn.ReLU(inplace=inplace)
9 | elif activation == "gelu":
10 | return nn.GELU()
11 | raise RuntimeError("activation {} not implemented".format(activation))
12 |
13 |
14 | class MLP(nn.Module):
15 | def __init__(
16 | self,
17 | fc_dims: Union[List, Tuple],
18 | dropout_p: Optional[float] = None,
19 | use_layernorm: bool = False,
20 | activation: str = "relu",
21 | end_layer_activation: bool = True,
22 | init_weight_norm: bool = False,
23 | init_bias: Optional[float] = None,
24 | use_batchnorm: bool = False,
25 | ) -> None:
26 | super(MLP, self).__init__()
27 | assert len(fc_dims) >= 2
28 | assert not (use_layernorm and use_batchnorm)
29 | layers: List[nn.Module] = []
30 | for i in range(0, len(fc_dims) - 1):
31 |
32 | fc = nn.Linear(fc_dims[i], fc_dims[i + 1])
33 |
34 | if init_weight_norm:
35 | fc.weight.data *= 1.0 / fc.weight.norm(dim=1, p=2, keepdim=True)
36 | if init_bias is not None and i == len(fc_dims) - 2:
37 | fc.bias.data *= 0
38 | fc.bias.data += init_bias
39 |
40 | layers.append(fc)
41 |
42 | if i < len(fc_dims) - 2:
43 | if use_layernorm:
44 | layers.append(nn.LayerNorm(fc_dims[i + 1]))
45 | elif use_batchnorm:
46 | layers.append(nn.BatchNorm1d(fc_dims[i + 1]))
47 | if dropout_p is not None:
48 | layers.append(nn.Dropout(p=dropout_p))
49 | layers.append(_get_activation(activation, inplace=True))
50 | if i == len(fc_dims) - 2:
51 | if end_layer_activation:
52 | if use_layernorm:
53 | layers.append(nn.LayerNorm(fc_dims[i + 1]))
54 | elif use_batchnorm:
55 | layers.append(nn.BatchNorm1d(fc_dims[i + 1]))
56 | if dropout_p is not None:
57 | layers.append(nn.Dropout(p=dropout_p))
58 | self.end_layer_activation = _get_activation(activation, inplace=True)
59 | else:
60 | self.end_layer_activation = None
61 |
62 | self.input_dim = fc_dims[0]
63 | self.output_dim = fc_dims[-1]
64 | self.fc_layers = nn.Sequential(*layers)
65 |
66 | def forward(self, x: Tensor, valid_mask: Optional[Tensor] = None, fill_invalid: float = 0.0) -> Tensor:
67 | """
68 | Args:
69 | x: [..., input_dim]
70 | valid_mask: [...]
71 | Returns:
72 | x: [..., output_dim]
73 | """
74 | x = self.fc_layers(x.flatten(0, -2)).view(*x.shape[:-1], self.output_dim)
75 | if valid_mask is not None:
76 | x.masked_fill_(~valid_mask.unsqueeze(-1), fill_invalid)
77 | if self.end_layer_activation is not None:
78 | self.end_layer_activation(x)
79 | return x
80 |
--------------------------------------------------------------------------------
/src/models/modules/point_net.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Tuple, List, Optional
3 | import torch
4 | from torch import Tensor, nn
5 | from .mlp import MLP
6 |
7 |
8 | class PointNet(nn.Module):
9 | def __init__(
10 | self,
11 | input_dim: int,
12 | hidden_dim: int,
13 | n_layer: int = 3,
14 | use_layernorm: bool = False,
15 | use_batchnorm: bool = False,
16 | end_layer_activation: bool = True,
17 | dropout_p: Optional[float] = None,
18 | pool_mode: str = "max", # max, mean, first
19 | ) -> None:
20 | super().__init__()
21 | self.pool_mode = pool_mode
22 | self.input_mlp = MLP(
23 | [input_dim, hidden_dim, hidden_dim],
24 | dropout_p=dropout_p,
25 | use_layernorm=use_layernorm,
26 | use_batchnorm=use_batchnorm,
27 | )
28 |
29 | mlp_layers: List[nn.Module] = []
30 | for _ in range(n_layer - 2):
31 | mlp_layers.append(
32 | MLP(
33 | [hidden_dim, hidden_dim // 2],
34 | dropout_p=dropout_p,
35 | use_layernorm=use_layernorm,
36 | use_batchnorm=use_batchnorm,
37 | )
38 | )
39 | mlp_layers.append(
40 | MLP(
41 | [hidden_dim, hidden_dim // 2],
42 | dropout_p=dropout_p,
43 | use_layernorm=use_layernorm,
44 | use_batchnorm=use_batchnorm,
45 | end_layer_activation=end_layer_activation,
46 | )
47 | )
48 | self.mlp_layers = nn.ModuleList(mlp_layers)
49 |
50 | def forward(self, x: Tensor, valid: Tensor) -> Tuple[Tensor, Tensor]:
51 | """c.f. VectorNet and SceneTransformer, Aggregate polyline/track level feature.
52 |
53 | Args:
54 | x: [n_batch, n_pl, n_pl_node, attr_dim]
55 | valid: [n_batch, n_pl, n_pl_node] bool
56 |
57 | Returns:
58 | emb: [n_batch, n_pl, hidden_dim]
59 | emb_valid: [n_batch, n_pl]
60 | """
61 | x = self.input_mlp(x, valid) # [n_batch, n_pl, n_pl_node, hidden_dim]
62 |
63 | for mlp in self.mlp_layers:
64 | feature_encoded = mlp(x, valid, float("-inf")) # [n_batch, n_pl, n_pl_node, hidden_dim//2]
65 | feature_pooled = feature_encoded.amax(dim=2, keepdim=True)
66 | x = torch.cat((feature_encoded, feature_pooled.expand(-1, -1, valid.shape[-1], -1)), dim=-1)
67 |
68 | if self.pool_mode == "max":
69 | x.masked_fill_(~valid.unsqueeze(-1), float("-inf")) # [n_batch, n_pl, n_pl_node, hidden_dim]
70 | emb = x.amax(dim=2, keepdim=False) # [n_batch, n_pl, hidden_dim]
71 | elif self.pool_mode == "first":
72 | emb = x[:, :, 0]
73 | elif self.pool_mode == "mean":
74 | x.masked_fill_(~valid.unsqueeze(-1), 0) # [n_batch, n_pl, n_pl_node, hidden_dim]
75 | emb = x.sum(dim=2, keepdim=False) # [batch_size, n_pl, hidden_dim]
76 | emb = emb / (valid.sum(dim=-1, keepdim=True) + torch.finfo(x.dtype).eps)
77 |
78 | emb_valid = valid.any(-1) # [n_batch, n_pl]
79 | emb = emb.masked_fill(~emb_valid.unsqueeze(-1), 0) # [n_batch, n_pl, hidden_dim]
80 | return emb, emb_valid
81 |
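82 |
83 | if __name__ == "__main__":
84 |     # Quick shape check (illustrative only): 2 scenes, 3 polylines, 4 nodes per polyline,
85 |     # 5 attributes per node, aggregated into one 16-dim embedding per polyline.
86 |     net = PointNet(input_dim=5, hidden_dim=16, n_layer=3)
87 |     emb, emb_valid = net(torch.randn(2, 3, 4, 5), torch.ones(2, 3, 4, dtype=torch.bool))
88 |     print(emb.shape, emb_valid.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3])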
--------------------------------------------------------------------------------
/src/models/modules/multi_modal.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import torch
3 | from torch import Tensor, nn
4 | from .mlp import MLP
5 |
6 |
7 | class MultiModalAnchors(nn.Module):
8 | def __init__(
9 | self,
10 | mode_emb: str,
11 | mode_init: str,
12 | hidden_dim: int,
13 | n_pred: int,
14 | emb_dim: int,
15 | use_agent_type: bool,
16 | scale: float = 1.0,
17 | ) -> None:
18 | super().__init__()
19 | self.n_pred = n_pred
20 | self.use_agent_type = use_agent_type
21 |
22 | self.mode_init = mode_init
23 | n_anchors = 3 if use_agent_type else 1
24 | if self.mode_init == "xavier":
25 | self.anchors = torch.empty((n_anchors, n_pred, hidden_dim))
26 | nn.init.xavier_normal_(self.anchors)
27 | self.anchors = nn.Parameter(self.anchors * scale, requires_grad=True)
28 | elif self.mode_init == "uniform":
29 | self.anchors = torch.empty((n_anchors, n_pred, hidden_dim))
30 | self.anchors.uniform_(-scale, scale)
31 | self.anchors = nn.Parameter(self.anchors, requires_grad=True)
32 | elif self.mode_init == "randn":
33 | self.anchors = nn.Parameter(torch.randn([n_anchors, n_pred, hidden_dim]) * scale, requires_grad=True)
34 | else:
35 | raise NotImplementedError
36 |
37 | self.mode_emb = mode_emb
38 | if self.mode_emb == "linear":
39 | self.mlp_anchor = nn.Linear(self.anchors.shape[-1] + emb_dim, hidden_dim, bias=False)
40 | elif self.mode_emb == "mlp":
41 | self.mlp_anchor = MLP([self.anchors.shape[-1] + emb_dim] + [hidden_dim] * 2, end_layer_activation=False)
42 | elif self.mode_emb == "add" or self.mode_emb == "none":
43 | assert emb_dim == hidden_dim
44 | if self.anchors.shape[-1] != hidden_dim:
45 | self.mlp_anchor = nn.Linear(self.anchors.shape[-1], hidden_dim, bias=False)
46 | else:
47 | self.mlp_anchor = None
48 | else:
49 | raise NotImplementedError
50 |
51 | def forward(self, valid: Tensor, emb: Tensor, agent_type: Tensor) -> Tensor:
52 | """
53 | Args:
54 | valid: [n_scene*n_agent]
55 | emb: [n_scene*n_agent, in_dim]
56 | agent_type: [n_scene*n_agent, 3]
57 |
58 | Returns:
59 | mm_emb: [n_scene*n_agent, n_pred, out_dim]
60 | """
61 | # [n_scene*n_agent, n_pred, emb_dim]
62 | if self.use_agent_type:
63 | anchors = (self.anchors.unsqueeze(0) * agent_type[:, :, None, None]).sum(1)
64 | else:
65 | anchors = self.anchors.expand(valid.shape[0], -1, -1)
66 |
67 | if self.mode_emb == "linear" or self.mode_emb == "mlp":
68 | # [n_scene*n_agent, n_pred, hidden_dim + emb_dim]
69 | mm_emb = torch.cat([emb.unsqueeze(1).expand(-1, self.n_pred, -1), anchors], dim=-1)
70 | mm_emb = self.mlp_anchor(mm_emb)
71 | elif self.mode_emb == "add":
72 | if self.mlp_anchor is not None:
73 | anchors = self.mlp_anchor(anchors) # [n_scene*n_agent, n_pred, hidden_dim]
74 | mm_emb = emb.unsqueeze(1) + anchors
75 | elif self.mode_emb == "none":
76 | if self.mlp_anchor is not None:
77 | anchors = self.mlp_anchor(anchors) # [n_scene*n_agent, n_pred, hidden_dim]
78 | mm_emb = anchors
79 | return mm_emb.masked_fill(~valid[:, None, None], 0)
80 |
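81 |
82 | if __name__ == "__main__":
83 |     # Quick shape check (illustrative only): 6 learnable anchors of dim 64 broadcast to
84 |     # 4 flattened scene*agent tokens; agent_type is ignored since use_agent_type=False.
85 |     mm = MultiModalAnchors(mode_emb="mlp", mode_init="xavier", hidden_dim=64, n_pred=6, emb_dim=64, use_agent_type=False)
86 |     out = mm(torch.ones(4, dtype=torch.bool), torch.randn(4, 64), torch.zeros(4, 3))
87 |     print(out.shape)  # torch.Size([4, 6, 64])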
--------------------------------------------------------------------------------
/src/models/modules/rpe.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional, Tuple, Union
3 | import torch
4 | from torch import Tensor
5 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_rad2local
6 |
7 |
8 | @torch.no_grad()
9 | def get_rel_pose(pose: Tensor, invalid: Tensor) -> Tuple[Tensor, Tensor]:
10 | """
11 | Args:
12 | pose: [n_scene, n_emb, 3], (x,y,yaw), in global coordinate
13 | invalid: [n_scene, n_emb]
14 |
15 | Returns:
16 | rel_pose: [n_scene, n_emb, n_emb, 3] (x,y,yaw)
17 | rel_dist: [n_scene, n_emb, n_emb]
18 | """
19 | xy = pose[:, :, :2] # [n_scene, n_emb, 2]
20 | yaw = pose[:, :, -1] # [n_scene, n_emb]
21 | rel_pose = torch.cat(
22 | [
23 | torch_pos2local(xy.unsqueeze(1), xy.unsqueeze(2), torch_rad2rot(yaw)),
24 | torch_rad2local(yaw.unsqueeze(1), yaw, cast=False).unsqueeze(-1),
25 | ],
26 | dim=-1,
27 | ) # [n_scene, n_emb, n_emb, 3]
28 | rel_dist = torch.norm(rel_pose[..., :2], dim=-1) # [n_scene, n_emb, n_emb]
29 | rel_dist.masked_fill_(invalid.unsqueeze(1) | invalid.unsqueeze(2), float("inf"))
30 | return rel_pose, rel_dist
31 |
32 |
33 | @torch.no_grad()
34 | def get_rel_dist(xy: Tensor, invalid: Tensor) -> Tensor:
35 | """
36 | Args:
37 | xy: [n_scene, n_emb, 2], in global coordinate
38 | invalid: [n_scene, n_emb]
39 |
40 | Returns:
41 | rel_dist: [n_scene, n_emb, n_emb]
42 | """
43 | rel_dist = torch.norm(xy.unsqueeze(1) - xy.unsqueeze(2), dim=-1) # [n_scene, n_emb, n_emb]
44 | rel_dist.masked_fill_(invalid.unsqueeze(1) | invalid.unsqueeze(2), float("inf"))
45 | return rel_dist
46 |
47 |
48 | @torch.no_grad()
49 | def get_tgt_knn_idx(
50 | tgt_invalid: Tensor, rel_pose: Optional[Tensor], rel_dist: Tensor, n_tgt_knn: int, dist_limit: Union[float, Tensor],
51 | ) -> Tuple[Optional[Tensor], Tensor, Optional[Tensor]]:
52 | """
53 | Args:
54 | tgt_invalid: [n_scene, n_tgt]
55 | rel_pose: [n_scene, n_src, n_tgt, 3]
56 | rel_dist: [n_scene, n_src, n_tgt]
57 | n_tgt_knn: int, set to <=0 to skip knn, i.e. n_tgt_knn=n_tgt
58 | dist_limit: float, or Tensor [n_scene, n_tgt, 1]
59 |
60 | Returns:
61 | idx_tgt: [n_scene, n_src, n_tgt_knn], or None
62 | tgt_invalid_knn: [n_scene, n_src, n_tgt_knn]
63 | rpe: [n_scene, n_src, n_tgt_knn, 3]
64 | """
65 | n_scene, n_src, _ = rel_dist.shape
66 | idx_scene = torch.arange(n_scene)[:, None, None] # [n_scene, 1, 1]
67 | idx_src = torch.arange(n_src)[None, :, None] # [1, n_src, 1]
68 |
69 | if 0 < n_tgt_knn < tgt_invalid.shape[1]:
70 | # [n_scene, n_src, n_tgt_knn]
71 | dist_knn, idx_tgt = torch.topk(rel_dist, n_tgt_knn, dim=-1, largest=False, sorted=False)
72 | # [n_scene, n_src, n_tgt_knn]
73 | tgt_invalid_knn = tgt_invalid.unsqueeze(1).expand(-1, n_src, -1)[idx_scene, idx_src, idx_tgt]
74 | # [n_batch, n_src, n_tgt_knn, 3]
75 | if rel_pose is None:
76 | rpe = None
77 | else:
78 | rpe = rel_pose[idx_scene, idx_src, idx_tgt]
79 | else:
80 | dist_knn = rel_dist
81 | tgt_invalid_knn = tgt_invalid.unsqueeze(1).expand(-1, n_src, -1) # [n_scene, n_src, n_tgt]
82 | rpe = rel_pose
83 | idx_tgt = None
84 |
85 | tgt_invalid_knn = tgt_invalid_knn | (dist_knn > dist_limit)
86 | if rpe is not None:
87 | rpe = rpe.masked_fill(tgt_invalid_knn.unsqueeze(-1), 0)
88 |
89 | return idx_tgt, tgt_invalid_knn, rpe
90 |
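91 |
92 | if __name__ == "__main__":
93 |     # Quick shape check (illustrative only): 2 scenes, 5 tokens, keep the 3 nearest targets.
94 |     # With rel_pose=None the returned rpe is None; only the knn indices and mask are produced.
95 |     xy = torch.randn(2, 5, 2)
96 |     invalid = torch.zeros(2, 5, dtype=torch.bool)
97 |     rel_dist = get_rel_dist(xy, invalid)
98 |     idx_tgt, invalid_knn, rpe = get_tgt_knn_idx(invalid, None, rel_dist, n_tgt_knn=3, dist_limit=100.0)
99 |     print(rel_dist.shape, idx_tgt.shape, invalid_knn.shape)  # [2, 5, 5], [2, 5, 3], [2, 5, 3]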
--------------------------------------------------------------------------------
/env_av2.yml:
--------------------------------------------------------------------------------
1 | name: hptr_av2
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - asttokens=2.2.1=pyhd8ed1ab_0
9 | - backcall=0.2.0=pyh9f0ad1d_0
10 | - backports=1.0=pyhd8ed1ab_3
11 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
12 | - ca-certificates=2022.12.7=ha878542_0
13 | - certifi=2022.12.7=pyhd8ed1ab_0
14 | - debugpy=1.5.1=py39h295c915_0
15 | - decorator=5.1.1=pyhd8ed1ab_0
16 | - entrypoints=0.4=pyhd8ed1ab_0
17 | - executing=1.2.0=pyhd8ed1ab_0
18 | - ipykernel=6.15.0=pyh210e3f2_0
19 | - ipython=8.11.0=pyh41d4057_0
20 | - jedi=0.18.2=pyhd8ed1ab_0
21 | - jupyter_client=7.0.6=pyhd8ed1ab_0
22 | - jupyter_core=5.2.0=py39hf3d152e_0
23 | - ld_impl_linux-64=2.38=h1181459_1
24 | - libffi=3.4.2=h6a678d5_6
25 | - libgcc-ng=11.2.0=h1234567_1
26 | - libgomp=11.2.0=h1234567_1
27 | - libsodium=1.0.18=h36c2ea0_1
28 | - libstdcxx-ng=11.2.0=h1234567_1
29 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
30 | - ncurses=6.4=h6a678d5_0
31 | - nest-asyncio=1.5.6=pyhd8ed1ab_0
32 | - openssl=1.1.1t=h7f8727e_0
33 | - packaging=23.0=pyhd8ed1ab_0
34 | - parso=0.8.3=pyhd8ed1ab_0
35 | - pexpect=4.8.0=pyh1a96a4e_2
36 | - pickleshare=0.7.5=py_1003
37 | - pip=23.0.1=py39h06a4308_0
38 | - platformdirs=3.1.1=pyhd8ed1ab_0
39 | - prompt-toolkit=3.0.38=pyha770c72_0
40 | - prompt_toolkit=3.0.38=hd8ed1ab_0
41 | - psutil=5.9.0=py39h5eee18b_0
42 | - ptyprocess=0.7.0=pyhd3deb0d_0
43 | - pure_eval=0.2.2=pyhd8ed1ab_0
44 | - pygments=2.14.0=pyhd8ed1ab_0
45 | - python=3.9.16=h7a1cb2a_2
46 | - python-dateutil=2.8.2=pyhd8ed1ab_0
47 | - python_abi=3.9=2_cp39
48 | - pyzmq=19.0.2=py39hb69f2a1_2
49 | - readline=8.2=h5eee18b_0
50 | - setuptools=65.6.3=py39h06a4308_0
51 | - six=1.16.0=pyh6c4a22f_0
52 | - sqlite=3.40.1=h5082296_0
53 | - stack_data=0.6.2=pyhd8ed1ab_0
54 | - tk=8.6.12=h1ccaba5_0
55 | - tornado=6.1=py39hb9d737c_3
56 | - traitlets=5.9.0=pyhd8ed1ab_0
57 | - typing-extensions=4.5.0=hd8ed1ab_0
58 | - typing_extensions=4.5.0=pyha770c72_0
59 | - tzdata=2022g=h04d1e81_0
60 | - wcwidth=0.2.6=pyhd8ed1ab_0
61 | - wheel=0.38.4=py39h06a4308_0
62 | - xz=5.2.10=h5eee18b_1
63 | - zeromq=4.3.4=h9c3ff4c_1
64 | - zlib=1.2.13=h5eee18b_0
65 | - pip:
66 | - argcomplete==2.1.1
67 | - av==10.0.0
68 | - av2==0.2.1
69 | - black==23.1.0
70 | - click==8.1.3
71 | - colorlog==6.7.0
72 | - contourpy==1.0.7
73 | - cycler==0.11.0
74 | - distlib==0.3.6
75 | - filelock==3.9.0
76 | - fonttools==4.39.0
77 | - fsspec==2023.3.0
78 | - h5py==3.8.0
79 | - importlib-resources==5.12.0
80 | - joblib==1.2.0
81 | - kiwisolver==1.4.4
82 | - llvmlite==0.39.1
83 | - markdown-it-py==2.2.0
84 | - matplotlib==3.7.1
85 | - mdurl==0.1.2
86 | - mypy-extensions==1.0.0
87 | - nox==2022.11.21
88 | - numba==0.56.4
89 | - numpy==1.23.5
90 | - nvidia-cublas-cu11==11.10.3.66
91 | - nvidia-cuda-nvrtc-cu11==11.7.99
92 | - nvidia-cuda-runtime-cu11==11.7.99
93 | - nvidia-cudnn-cu11==8.5.0.96
94 | - opencv-python==4.7.0.72
95 | - pandas==1.5.3
96 | - pathspec==0.11.1
97 | - pillow==9.4.0
98 | - pyarrow==11.0.0
99 | - pyparsing==3.0.9
100 | - pyproj==3.4.1
101 | - pytz==2022.7.1
102 | - rich==13.3.2
103 | - scipy==1.10.1
104 | - tomli==2.0.1
105 | - torch==1.13.1
106 | - tqdm==4.65.0
107 | - transforms3d==0.4.1
108 | - upath==1.0
109 | - virtualenv==20.21.0
110 | - zipp==3.15.0
111 |
--------------------------------------------------------------------------------
/src/utils/pose_pe.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import torch
3 | from torch import Tensor, nn
4 | from models.modules.pos_emb import PositionalEmbedding, PositionalEmbeddingRad
5 |
6 |
7 | class PosePE(nn.Module):
8 | def __init__(self, mode: str, pe_dim: int = 256, theta_xy: float = 1e3, theta_cs: float = 1e1):
9 | super().__init__()
10 | self.mode = mode
11 | if self.mode == "xy_dir":
12 | self.out_dim = 4
13 | elif self.mode == "mpa_pl":
14 | self.out_dim = 7
15 | elif self.mode == "pe_xy_dir":
16 | self.out_dim = pe_dim
17 | self.pe_xy = PositionalEmbedding(dim=pe_dim // 4, theta=theta_xy)
18 | self.pe_dir = PositionalEmbedding(dim=pe_dim // 4, theta=theta_cs)
19 | elif self.mode == "pe_xy_yaw":
20 | self.out_dim = pe_dim
21 | self.pe_xy = PositionalEmbedding(dim=pe_dim // 4, theta=theta_xy)
22 | self.pe_yaw = PositionalEmbeddingRad(dim=pe_dim // 2)
23 | else:
24 | raise NotImplementedError
25 |
26 | def forward(self, xy: Tensor, dir: Tensor):
27 | """
28 | Args: input either dir or yaw.
29 | xy: [..., 2]
30 | dir: cos/sin [..., 2] or yaw [..., 1]
31 |
32 | Returns:
33 | pos_out: [..., self.out_dim]
34 | """
35 | if self.mode == "xy_dir":
36 | if dir.shape[-1] == 1:
37 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1)
38 | pos_out = torch.cat([xy, dir], dim=-1)
39 | elif self.mode == "mpa_pl":
40 | if dir.shape[-1] == 1:
41 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1)
42 | pos_out = self.encode_polyline(xy, dir)
43 | elif self.mode == "pe_xy_dir":
44 | if dir.shape[-1] == 1:
45 | dir = torch.cat([dir.cos(), dir.sin()], dim=-1)
46 | pos_out = torch.cat(
47 | [self.pe_xy(xy[..., 0]), self.pe_xy(xy[..., 1]), self.pe_dir(dir[..., 0]), self.pe_dir(dir[..., 1])],
48 | dim=-1,
49 | )
50 | elif self.mode == "pe_xy_yaw":
51 | if dir.shape[-1] == 1:
52 | dir = dir.squeeze(-1)
53 | else:
54 | dir = torch.atan2(dir[..., 1], dir[..., 0])
55 | pos_out = torch.cat([self.pe_xy(xy[..., 0]), self.pe_xy(xy[..., 1]), self.pe_yaw(dir)], dim=-1)
56 | return pos_out
57 |
58 | @staticmethod
59 | def encode_polyline(pos: Tensor, dir: Tensor) -> Tensor:
60 | """
61 | Args: pos and dir with respect to the agent
62 | pos: [..., 2]
63 | dir: [..., 2]
64 |
65 | Returns:
66 | pl_feature: [..., 7]
67 | """
68 | eps = torch.finfo(pos.dtype).eps
69 | # [n_scene, n_target, n_map, n_pl_node, 2]
70 | segments_start = pos
71 | segment_vec = dir
72 | # [n_scene, n_target, n_map, n_pl_node]
73 | segment_proj = (-segments_start * segment_vec).sum(-1) / ((segment_vec * segment_vec).sum(-1) + eps)
74 | # [n_scene, n_target, n_map, n_pl_node, 2]
75 | closest_points = segments_start + torch.clamp(segment_proj, min=0, max=1).unsqueeze(-1) * segment_vec
76 | # [n_scene, n_target, n_map, n_pl_node, 1]
77 | r_norm = torch.norm(closest_points, dim=-1, keepdim=True)
78 | segment_vec_norm = torch.norm(segment_vec, dim=-1, keepdim=True)
79 | pl_feature = torch.cat(
80 | [
81 | r_norm, # 1
82 | closest_points / (r_norm + eps), # 2
83 | segment_vec / (segment_vec_norm + eps), # 2
84 | segment_vec_norm, # 1
85 | torch.norm(segments_start + segment_vec - closest_points, dim=-1, keepdim=True), # 1
86 | ],
87 | dim=-1,
88 | )
89 | return pl_feature
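90 |
91 |
92 | if __name__ == "__main__":
93 |     # Quick shape check (illustrative only): "pe_xy_yaw" encodes (x, y) and yaw with
94 |     # sinusoidal embeddings into a single vector of size pe_dim.
95 |     pose_pe = PosePE(mode="pe_xy_yaw", pe_dim=256)
96 |     out = pose_pe(torch.randn(2, 7, 2), torch.randn(2, 7, 1))
97 |     print(out.shape, pose_pe.out_dim)  # torch.Size([2, 7, 256]) 256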
--------------------------------------------------------------------------------
/configs/model/acg_womd.yaml:
--------------------------------------------------------------------------------
1 | _target_: pl_modules.waymo_motion.WaymoMotion
2 |
3 | time_step_current: 10
4 | time_step_end: 90
5 | n_video_batch: 3
6 | interactive_challenge: False
7 | inference_cache_map: False
8 | inference_repeat_n: 1
9 |
10 | train_metric:
11 | _target_: models.metrics.nll.NllMetrics
12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd
13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2
14 | p_rand_train_agent: -1 # 0.2
15 | n_step_add_train_agent: [-1, -1, -1] # -1 to turn off
16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off
17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc
18 | w_pos: [1.0, 1.0, 1.0]
19 | w_yaw: [1.0, 1.0, 1.0]
20 | w_vel: [1.0, 1.0, 1.0]
21 | w_spd: [0, 0, 0]
22 |
23 | waymo_metric:
24 | _target_: models.metrics.waymo.WaymoMetrics
25 | n_max_pred_agent: 8
26 |
27 | pre_processing:
28 | agent_centric:
29 | _target_: data_modules.agent_centric.AgentCentricPreProcessing
30 | mask_invalid: False
31 | n_target: 8
32 | n_other: 48
33 | n_map: 512
34 | n_tl: 24
35 | ac_global:
36 | _target_: data_modules.ac_global.AgentCentricGlobal
37 | dropout_p_history: 0.15
38 | use_current_tl: False
39 | add_ohe: True
40 | pl_aggr: False
41 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw
42 | agent: xy_dir
43 | map: mpa_pl
44 | tl: mpa_pl
45 |
46 | post_processing:
47 | to_dict:
48 | _target_: data_modules.post_processing.ToDict
49 | predictions: ${...model.decoder.mlp_head.predictions}
50 | get_cov_mat:
51 | _target_: data_modules.post_processing.GetCovMat
52 | rho_clamp: 5.0
53 | std_min: -1.609
54 | std_max: 5.0
55 | waymo:
56 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
57 | k_pred: 6
58 | use_ade: True
59 | score_temperature: -1
60 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc
61 | gt_in_local: True
62 |
63 | model:
64 | _target_: models.ac_global.AgentCentricGlobal
65 | hidden_dim: 256
66 | n_decoders: 1
67 | tf_cfg:
68 | n_head: 4
69 | dropout_p: 0.1
70 | norm_first: True
71 | bias: True
72 | intra_class_encoder:
73 | add_learned_pe: True
74 | use_point_net: False
75 | n_layer_mlp: 3
76 | mlp_cfg:
77 | end_layer_activation: True
78 | use_layernorm: False
79 | use_batchnorm: False
80 | dropout_p: null
81 | n_layer_tf: -1
82 | decoder:
83 | _target_: models.ac_global.Decoder
84 | n_latent_query: 192
85 | n_layer_tf_all2all: 6
86 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n
87 | n_layer_tf_anchor: 8
88 | n_pred: 6
89 | latent_query:
90 | use_agent_type: False
91 | mode_emb: none # linear, mlp, add, none
92 | mode_init: xavier # uniform, xavier
93 | scale: 5.0
94 | multi_modal_anchors:
95 | use_agent_type: True
96 | mode_emb: none # linear, mlp, add, none
97 | mode_init: xavier # uniform, xavier
98 | scale: 5.0
99 | anchor_self_attn: True
100 | mlp_head:
101 | predictions: [pos, cov3] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox
102 | use_agent_type: False
103 | flatten_conf_head: False
104 | out_mlp_layernorm: False
105 | out_mlp_batchnorm: False
106 | n_step_future: 80
107 | use_vmap: True
108 |
109 | # * optimizer
110 | optimizer:
111 | _target_: torch.optim.AdamW
112 | lr: 2e-4 # 2e-4, 1e-5
113 | lr_scheduler:
114 | _target_: torch.optim.lr_scheduler.StepLR
115 | gamma: 0.5
116 | step_size: 20
117 |
118 | sub_womd:
119 | _target_: utils.submission.SubWOMD
120 | activate: False
121 | method_name: METHOD_NAME
122 | authors: [NAME1, NAME2]
123 | affiliation: AFFILIATION
124 | description: scr_womd
125 | method_link: METHOD_LINK
126 | account_name: ACCOUNT_NAME
127 | sub_av2:
128 | _target_: utils.submission.SubAV2
129 | activate: False
130 |
--------------------------------------------------------------------------------
/src/data_modules/post_processing.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import List, Dict
3 | import torch
4 | from torch import nn, Tensor
5 |
6 |
7 | class ToDict(nn.Module):
8 | def __init__(self, predictions: List[str]) -> None:
9 | super().__init__()
10 | self.dims = {"pred_pos": None, "pred_spd": None, "pred_vel": None, "pred_yaw_bbox": None, "pred_cov": None}
11 |
12 | pred_dim = 0
13 | if "pos" in predictions:
14 | self.dims["pred_pos"] = (pred_dim, pred_dim + 2)
15 | pred_dim += 2
16 | if "spd" in predictions:
17 | self.dims["pred_spd"] = (pred_dim, pred_dim + 1)
18 | pred_dim += 1
19 | if "vel" in predictions:
20 | self.dims["pred_vel"] = (pred_dim, pred_dim + 2)
21 | pred_dim += 2
22 | if "yaw_bbox" in predictions:
23 | self.dims["pred_yaw_bbox"] = (pred_dim, pred_dim + 1)
24 | pred_dim += 1
25 | if "cov1" in predictions:
26 | self.dims["pred_cov"] = (pred_dim, pred_dim + 1)
27 | pred_dim += 1
28 | elif "cov2" in predictions:
29 | self.dims["pred_cov"] = (pred_dim, pred_dim + 2)
30 | pred_dim += 2
31 | elif "cov3" in predictions:
32 | self.dims["pred_cov"] = (pred_dim, pred_dim + 3)
33 | pred_dim += 3
34 |
35 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
36 | """
37 | Inputs:
38 | valid: [n_scene, n_target]
39 | conf: [n_decoders, n_scene, n_target, n_pred], not normalized!
40 | pred: [n_decoders, n_scene, n_target, n_pred, n_step_future, pred_dim]
41 | """
42 | for k, v in self.dims.items():
43 | if v is None:
44 | pred_dict[k] = None
45 | else:
46 | pred_dict[k] = pred_dict["pred"][..., v[0] : v[1]]
47 | # del pred_dict["pred"]
48 | return pred_dict
49 |
50 |
51 | class GetCovMat(nn.Module):
52 | def __init__(self, rho_clamp: float, std_min: float, std_max: float) -> None:
53 | super().__init__()
54 | self.rho_clamp = rho_clamp
55 | self.std_min = std_min
56 | self.std_max = std_max
57 |
58 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
59 | """
60 | Inputs:
61 | pred_cov: [n_decoders, n_scene, n_target, n_pred, n_step_future, 1/2/3]
62 |
63 | Outputs:
64 | pred_cov: [n_decoders, n_scene, n_target, n_pred, n_step_future, 2, 2], in tril form.
65 | """
66 | if pred_dict["pred_cov"] is not None:
67 | cov_shape = pred_dict["pred_cov"].shape
68 | cov_dof = cov_shape[-1]
69 | if cov_dof == 3:
70 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp()
71 | b = torch.clamp(pred_dict["pred_cov"][..., 1], min=self.std_min, max=self.std_max).exp()
72 | c = torch.clamp(pred_dict["pred_cov"][..., 2], min=-self.rho_clamp, max=self.rho_clamp)
73 | elif cov_dof == 2:
74 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp()
75 | b = torch.clamp(pred_dict["pred_cov"][..., 1], min=self.std_min, max=self.std_max).exp()
76 | c = torch.zeros_like(a)
77 | elif cov_dof == 1:
78 | a = torch.clamp(pred_dict["pred_cov"][..., 0], min=self.std_min, max=self.std_max).exp()
79 | b = a
80 | c = torch.zeros_like(a)
81 |
82 | pred_dict["pred_cov"] = torch.stack([a, torch.zeros_like(a), c, b], dim=-1).view(*cov_shape[:-1], 2, 2)
83 |
84 | return pred_dict
85 |
86 |
87 | class OffsetToKmeans(nn.Module):
88 | def __init__(self, *args, **kwargs) -> None:
89 | super().__init__()
90 |
91 | def forward(self, pred_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
92 | # not used, for back compatibility
93 | return pred_dict
94 |
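95 |
96 | if __name__ == "__main__":
97 |     # Quick shape check (illustrative only): 1 decoder, 2 scenes, 3 targets, 6 modes,
98 |     # 80 future steps. With predictions=[pos, cov3] the raw output dim is 2 + 3 = 5;
99 |     # GetCovMat turns the 3 raw cov values into a lower-triangular 2x2 matrix.
100 |     pred_dict = {"pred": torch.randn(1, 2, 3, 6, 80, 5)}
101 |     pred_dict = ToDict(predictions=["pos", "cov3"])(pred_dict)
102 |     pred_dict = GetCovMat(rho_clamp=5.0, std_min=-1.609, std_max=5.0)(pred_dict)
103 |     print(pred_dict["pred_pos"].shape)  # torch.Size([1, 2, 3, 6, 80, 2])
104 |     print(pred_dict["pred_cov"].shape)  # torch.Size([1, 2, 3, 6, 80, 2, 2])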
--------------------------------------------------------------------------------
/configs/model/scg_womd.yaml:
--------------------------------------------------------------------------------
1 | _target_: pl_modules.waymo_motion.WaymoMotion
2 |
3 | time_step_current: 10
4 | time_step_end: 90
5 | n_video_batch: 3
6 | interactive_challenge: False
7 | inference_cache_map: False
8 | inference_repeat_n: 1
9 |
10 | train_metric:
11 | _target_: models.metrics.nll.NllMetrics
12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd
13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2
14 | p_rand_train_agent: -1 # 0.2
15 | n_step_add_train_agent: [-1, 40, 40] # -1 to turn off
16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off
17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc
18 | w_pos: [1.0, 1.0, 1.0]
19 | w_yaw: [1.0, 1.0, 1.0]
20 | w_vel: [1.0, 1.0, 1.0]
21 | w_spd: [1.0, 1.0, 1.0]
22 |
23 | waymo_metric:
24 | _target_: models.metrics.waymo.WaymoMetrics
25 | n_max_pred_agent: 8
26 |
27 | pre_processing:
28 | scene_centric:
29 | _target_: data_modules.scene_centric.SceneCentricPreProcessing
30 | gt_in_local: True
31 | mask_invalid: False
32 | global:
33 | _target_: data_modules.sc_global.SceneCentricGlobal
34 | dropout_p_history: -1
35 | use_current_tl: True
36 | add_ohe: True
37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet
38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw
39 | agent: pe_xy_dir
40 | map: pe_xy_dir
41 | tl: pe_xy_dir
42 |
43 | post_processing:
44 | to_dict:
45 | _target_: data_modules.post_processing.ToDict
46 | predictions: ${...model.decoder.mlp_head.predictions}
47 | get_cov_mat:
48 | _target_: data_modules.post_processing.GetCovMat
49 | rho_clamp: 5.0
50 | std_min: -1.609
51 | std_max: 5.0
52 | waymo:
53 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
54 | k_pred: 6
55 | use_ade: True
56 | score_temperature: -1
57 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc
58 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local}
59 |
60 | model:
61 | _target_: models.sc_global.SceneCentricGlobal
62 | hidden_dim: 256
63 | n_tgt_knn: 36
64 | dist_limit_map: 1500
65 | dist_limit_tl: 1000
66 | dist_limit_agent: [1500, 500, 1000]
67 | decoder_remove_ego_agent: False
68 | n_decoders: 1
69 | tf_cfg:
70 | n_head: 4
71 | dropout_p: 0.1
72 | norm_first: True
73 | bias: True
74 | intra_class_encoder:
75 | n_layer_mlp: 3
76 | mlp_cfg:
77 | end_layer_activation: True
78 | use_layernorm: False
79 | use_batchnorm: False
80 | dropout_p: null
81 | n_layer_tf_map: 6
82 | n_layer_tf_tl: -1
83 | n_layer_tf_agent: -1
84 | decoder:
85 | _target_: models.sc_global.Decoder
86 | n_pred: 6
87 | tf_n_layer: 2
88 | k_reinforce_tl: 2
89 | k_reinforce_agent: 4
90 | k_reinforce_all: -1
91 | k_reinforce_anchor: 10
92 | n_latent_query: -1
93 | latent_query:
94 | use_agent_type: False
95 | mode_emb: linear # linear, mlp, add, none
96 | mode_init: xavier # uniform, xavier
97 | scale: 5.0
98 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n
99 | multi_modal_anchors:
100 | use_agent_type: True
101 | mode_emb: linear # linear, mlp, add, none
102 | mode_init: xavier # uniform, xavier
103 | scale: 5.0
104 | anchor_self_attn: True
105 | mlp_head:
106 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox
107 | use_agent_type: False
108 | flatten_conf_head: False
109 | out_mlp_layernorm: False
110 | out_mlp_batchnorm: False
111 | n_step_future: 80
112 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor
113 | use_vmap: True
114 |
115 | # * optimizer
116 | optimizer:
117 | _target_: torch.optim.AdamW
118 | lr: 1e-4 # 2e-4, 1e-5
119 | lr_scheduler:
120 | _target_: torch.optim.lr_scheduler.StepLR
121 | gamma: 0.5
122 | step_size: 25
123 |
124 | sub_womd:
125 | _target_: utils.submission.SubWOMD
126 | activate: False
127 | method_name: METHOD_NAME
128 | authors: [NAME1, NAME2]
129 | affiliation: AFFILIATION
130 | description: scr_womd
131 | method_link: METHOD_LINK
132 | account_name: ACCOUNT_NAME
133 | sub_av2:
134 | _target_: utils.submission.SubAV2
135 | activate: False
136 |
--------------------------------------------------------------------------------
/configs/model/scr_av2.yaml:
--------------------------------------------------------------------------------
1 | _target_: pl_modules.waymo_motion.WaymoMotion
2 |
3 | time_step_current: 49
4 | time_step_end: 109
5 | n_video_batch: 3
6 | interactive_challenge: False
7 | inference_cache_map: False
8 | inference_repeat_n: 1
9 |
10 | train_metric:
11 | _target_: models.metrics.nll.NllMetrics
12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd
13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2
14 | p_rand_train_agent: -1 # 0.2
15 | n_step_add_train_agent: [-1, -1, -1] # -1 to turn off
16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off
17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc
18 | w_pos: [1.0, 1.0, 1.0]
19 | w_yaw: [1.0, 1.0, 1.0]
20 | w_vel: [1.0, 1.0, 1.0]
21 | w_spd: [1.0, 1.0, 1.0]
22 |
23 | waymo_metric:
24 | _target_: models.metrics.waymo.WaymoMetrics
25 | n_max_pred_agent: 1
26 |
27 | pre_processing:
28 | scene_centric:
29 | _target_: data_modules.scene_centric.SceneCentricPreProcessing
30 | gt_in_local: True
31 | mask_invalid: False
32 | relative:
33 | _target_: data_modules.sc_relative.SceneCentricRelative
34 | dropout_p_history: -1
35 | use_current_tl: True
36 | add_ohe: True
37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet
38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw
39 | agent: mpa_pl
40 | map: mpa_pl
41 |
42 | post_processing:
43 | to_dict:
44 | _target_: data_modules.post_processing.ToDict
45 | predictions: ${...model.decoder.mlp_head.predictions}
46 | get_cov_mat:
47 | _target_: data_modules.post_processing.GetCovMat
48 | rho_clamp: 5.0
49 | std_min: -1.609
50 | std_max: 5.0
51 | waymo:
52 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
53 | k_pred: 6
54 | use_ade: True
55 | score_temperature: 0.5
56 | mpa_nms_thresh: [] # veh, ped, cyc
57 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local}
58 |
59 | model:
60 | _target_: models.sc_relative.SceneCentricRelative
61 | hidden_dim: 256
62 | rpe_mode: pe_xy_yaw # xy_dir, pe_xy_dir, pe_xy_yaw
63 | n_tgt_knn: 36
64 | dist_limit_map: 1500
65 | dist_limit_tl: 1000
66 | dist_limit_agent: [1500, 500, 1000]
67 | decoder_remove_ego_agent: False
68 | n_decoders: 1
69 | tf_cfg:
70 | n_head: 4
71 | dropout_p: 0.1
72 | norm_first: True
73 | apply_q_rpe: False
74 | bias: True
75 | intra_class_encoder:
76 | n_layer_mlp: 3
77 | mlp_cfg:
78 | end_layer_activation: True
79 | use_layernorm: False
80 | use_batchnorm: False
81 | dropout_p: null
82 | n_layer_tf_map: 6
83 | n_layer_tf_tl: -1
84 | n_layer_tf_agent: -1
85 | decoder:
86 | _target_: models.sc_relative.Decoder
87 | n_pred: 6
88 | tf_n_layer: 2
89 | k_reinforce_tl: 2
90 | k_reinforce_agent: 4
91 | k_reinforce_all: -1
92 | k_reinforce_anchor: 10
93 | n_latent_query: -1
94 | latent_query:
95 | use_agent_type: False
96 | mode_emb: linear # linear, mlp, add, none
97 | mode_init: randn # uniform, xavier
98 | scale: 5.0
99 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n
100 | multi_modal_anchors:
101 | use_agent_type: True
102 | mode_emb: linear # linear, mlp, add, none
103 | mode_init: randn # uniform, xavier
104 | scale: 5.0
105 | anchor_self_attn: True
106 | mlp_head:
107 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox
108 | use_agent_type: False
109 | flatten_conf_head: False
110 | out_mlp_layernorm: False
111 | out_mlp_batchnorm: False
112 | n_step_future: 60
113 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor
114 | use_vmap: True
115 |
116 | # * optimizer
117 | optimizer:
118 | _target_: torch.optim.AdamW
119 | lr: 1e-4 # 2e-4, 1e-5
120 | lr_scheduler:
121 | _target_: torch.optim.lr_scheduler.StepLR
122 | gamma: 0.5
123 | step_size: 25
124 |
125 | sub_womd:
126 | _target_: utils.submission.SubWOMD
127 | activate: False
128 | method_name: METHOD_NAME
129 | authors: [NAME1, NAME2]
130 | affiliation: AFFILIATION
131 | description: scr_womd
132 | method_link: METHOD_LINK
133 | account_name: ACCOUNT_NAME
134 | sub_av2:
135 | _target_: utils.submission.SubAV2
136 | activate: False
137 |
--------------------------------------------------------------------------------
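The model configs above rely on OmegaConf relative interpolations such as `${...pre_processing.scene_centric.gt_in_local}`: each extra leading dot walks one level further up from the node that contains the reference. A minimal sketch of how such a reference resolves (illustrative only, not repository code; in the composed Hydra config the file content additionally sits under its config group):

```python
from omegaconf import OmegaConf

# Sketch (not from this repo): from post_processing.waymo, the three dots in
# ${...} walk up to the config root, where pre_processing lives.
cfg = OmegaConf.create(
    {
        "pre_processing": {"scene_centric": {"gt_in_local": True}},
        "post_processing": {
            "waymo": {"gt_in_local": "${...pre_processing.scene_centric.gt_in_local}"}
        },
    }
)
print(cfg.post_processing.waymo.gt_in_local)  # True
```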
/configs/model/scr_womd.yaml:
--------------------------------------------------------------------------------
1 | _target_: pl_modules.waymo_motion.WaymoMotion
2 |
3 | time_step_current: 10
4 | time_step_end: 90
5 | n_video_batch: 3
6 | interactive_challenge: False
7 | inference_cache_map: False
8 | inference_repeat_n: 1
9 |
10 | train_metric:
11 | _target_: models.metrics.nll.NllMetrics
12 | winner_takes_all: hard1 # none, or (joint) + hard + (1-6), or cmd
13 | l_pos: nll_torch # nll_torch, nll_mtr, huber, l2
14 | p_rand_train_agent: -1 # 0.2
15 | n_step_add_train_agent: [-1, 40, 40] # -1 to turn off
16 | focal_gamma_conf: [0.0, 0.0, 0.0] # 0.0 to turn off
17 | w_conf: [1.0, 1.0, 1.0] # veh, ped, cyc
18 | w_pos: [1.0, 1.0, 1.0]
19 | w_yaw: [1.0, 1.0, 1.0]
20 | w_vel: [1.0, 1.0, 1.0]
21 | w_spd: [1.0, 1.0, 1.0]
22 |
23 | waymo_metric:
24 | _target_: models.metrics.waymo.WaymoMetrics
25 | n_max_pred_agent: 8
26 |
27 | pre_processing:
28 | scene_centric:
29 | _target_: data_modules.scene_centric.SceneCentricPreProcessing
30 | gt_in_local: True
31 | mask_invalid: False
32 | relative:
33 | _target_: data_modules.sc_relative.SceneCentricRelative
34 | dropout_p_history: -1
35 | use_current_tl: True
36 | add_ohe: True
37 | pl_aggr: False # True for using our MLP as polyline subnet, False for using VectorNet PointNet as polyline subnet
38 | pose_pe: # xy_dir, mpa_pl, pe_xy_dir, pe_xy_yaw
39 | agent: mpa_pl
40 | map: mpa_pl
41 |
42 | post_processing:
43 | to_dict:
44 | _target_: data_modules.post_processing.ToDict
45 | predictions: ${...model.decoder.mlp_head.predictions}
46 | get_cov_mat:
47 | _target_: data_modules.post_processing.GetCovMat
48 | rho_clamp: 5.0
49 | std_min: -1.609
50 | std_max: 5.0
51 | waymo:
52 | _target_: data_modules.waymo_post_processing.WaymoPostProcessing
53 | k_pred: 6
54 | use_ade: True
55 | score_temperature: -1
56 | mpa_nms_thresh: [2.5, 1.0, 2.0] # veh, ped, cyc
57 | gt_in_local: ${...pre_processing.scene_centric.gt_in_local}
58 |
59 | model:
60 | _target_: models.sc_relative.SceneCentricRelative
61 | hidden_dim: 256
62 | rpe_mode: pe_xy_yaw # xy_dir, pe_xy_dir, pe_xy_yaw
63 | n_tgt_knn: 36
64 | dist_limit_map: 1500
65 | dist_limit_tl: 1000
66 | dist_limit_agent: [1500, 500, 1000]
67 | decoder_remove_ego_agent: False
68 | n_decoders: 1
69 | tf_cfg:
70 | n_head: 4
71 | dropout_p: 0.1
72 | norm_first: True
73 | apply_q_rpe: False
74 | bias: True
75 | intra_class_encoder:
76 | n_layer_mlp: 3
77 | mlp_cfg:
78 | end_layer_activation: True
79 | use_layernorm: False
80 | use_batchnorm: False
81 | dropout_p: null
82 | n_layer_tf_map: 6
83 | n_layer_tf_tl: -1
84 | n_layer_tf_agent: -1
85 | decoder:
86 | _target_: models.sc_relative.Decoder
87 | n_pred: 6
88 | tf_n_layer: 2
89 | k_reinforce_tl: 2
90 | k_reinforce_agent: 4
91 | k_reinforce_all: -1
92 | k_reinforce_anchor: 10
93 | n_latent_query: -1
94 | latent_query:
95 | use_agent_type: False
96 | mode_emb: linear # linear, mlp, add, none
97 | mode_init: xavier # uniform, xavier
98 | scale: 5.0
99 | latent_query_use_tf_decoder: False # True: (cross+self)*n, False: cross+self*n
100 | multi_modal_anchors:
101 | use_agent_type: True
102 | mode_emb: linear # linear, mlp, add, none
103 | mode_init: xavier # uniform, xavier
104 | scale: 5.0
105 | anchor_self_attn: True
106 | mlp_head:
107 | predictions: [pos, cov3, spd, vel, yaw_bbox] # keywords: pos, cov1/2/3, spd, vel, yaw_bbox
108 | use_agent_type: False
109 | flatten_conf_head: False
110 | out_mlp_layernorm: False
111 | out_mlp_batchnorm: False
112 | n_step_future: 80
113 | use_attr_for_multi_modal: False # use agent_attr instead of agent_emb for learnable latent/anchor
114 | use_vmap: True
115 |
116 | # * optimizer
117 | optimizer:
118 | _target_: torch.optim.AdamW
119 | lr: 1e-4 # 2e-4, 1e-5
120 | lr_scheduler:
121 | _target_: torch.optim.lr_scheduler.StepLR
122 | gamma: 0.5
123 | step_size: 25
124 |
125 | sub_womd:
126 | _target_: utils.submission.SubWOMD
127 | activate: False
128 | method_name: METHOD_NAME
129 | authors: [NAME1, NAME2]
130 | affiliation: AFFILIATION
131 | description: scr_womd
132 | method_link: METHOD_LINK
133 | account_name: ACCOUNT_NAME
134 | sub_av2:
135 | _target_: utils.submission.SubAV2
136 | activate: False
137 |
--------------------------------------------------------------------------------
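All `_target_` entries in these configs are turned into objects via `hydra.utils.instantiate`, as in `DecoderEnsemble` below. A minimal sketch of the mechanism, using the optimizer entry as an example (illustrative only, not repository code):

```python
import torch
import hydra
from omegaconf import OmegaConf

# Sketch (not from this repo): instantiate imports the class named in `_target_`
# and calls it with the remaining config keys plus any extra keyword arguments.
optim_cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 1e-4})
model = torch.nn.Linear(8, 8)  # stand-in for the actual network
optimizer = hydra.utils.instantiate(optim_cfg, params=model.parameters())
print(type(optimizer).__name__)  # AdamW
```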
/src/models/modules/decoder_ensemble.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Tuple, List
3 | import hydra
4 | import torch
5 | from torch import nn, Tensor
6 | from omegaconf import DictConfig
7 | from functorch import combine_state_for_ensemble, vmap
8 | from .mlp import MLP
9 |
10 |
11 | class DecoderEnsemble(nn.Module):
12 | def __init__(self, n_decoders: int, decoder_cfg: DictConfig) -> None:
13 | super().__init__()
14 | self.use_vmap = decoder_cfg["use_vmap"]
15 | self.n_decoders = n_decoders
16 | if self.use_vmap and self.n_decoders > 1:
17 | _decoders = [hydra.utils.instantiate(decoder_cfg) for _ in range(n_decoders)]
18 | fmodel_decoders, params_decoders, buffers_decoders = combine_state_for_ensemble(_decoders)
19 | assert buffers_decoders == ()
20 | self.v_model = vmap(fmodel_decoders, randomness="different")
21 | [p.requires_grad_() for p in params_decoders]
22 | self.params_decoders = nn.ParameterList(params_decoders)
23 | else:
24 | self._decoders = nn.ModuleList([hydra.utils.instantiate(decoder_cfg) for _ in range(n_decoders)])
25 |
26 | def forward(self, **kwargs) -> Tuple[Tensor, Tensor]:
27 | """
28 | Returns:
29 | conf: [n_decoders, n_scene, n_agent, n_pred]
30 | pred: [n_decoders, n_scene, n_agent, n_pred, n_step_future, pred_dim]
31 | """
32 | if self.use_vmap and self.n_decoders > 1:
33 | conf, pred = self.v_model(tuple(self.params_decoders), (), **kwargs)
34 | else:
35 | conf, pred = [], []
36 | for decoder in self._decoders:
37 | c, p = decoder(**kwargs)
38 | conf.append(c)
39 | pred.append(p)
40 | conf = torch.stack(conf, dim=0)
41 | pred = torch.stack(pred, dim=0)
42 | return conf, pred
43 |
44 |
45 | class MLPHead(nn.Module):
46 | def __init__(
47 | self,
48 | hidden_dim: int,
49 | use_vmap: bool,
50 | n_step_future: int,
51 | out_mlp_layernorm: bool,
52 | out_mlp_batchnorm: bool,
53 | use_agent_type: bool,
54 | predictions: List[str],
55 | **kwargs,
56 | ) -> None:
57 | super().__init__()
58 | self.use_agent_type = use_agent_type
59 | self.n_step_future = n_step_future
60 |
61 | self.pred_dim = 0
62 | if "pos" in predictions:
63 | self.pred_dim += 2
64 | if "spd" in predictions:
65 | self.pred_dim += 1
66 | if "vel" in predictions:
67 | self.pred_dim += 2
68 | if "yaw_bbox" in predictions:
69 | self.pred_dim += 1
70 | if "cov1" in predictions:
71 | self.pred_dim += 1
72 | elif "cov2" in predictions:
73 | self.pred_dim += 2
74 | elif "cov3" in predictions:
75 | self.pred_dim += 3
76 |
77 | _d = hidden_dim * 2
78 | cfg_mlp_pred = {
79 | "fc_dims": [hidden_dim, _d, _d, self.n_step_future * self.pred_dim],
80 | "end_layer_activation": False,
81 | "use_layernorm": out_mlp_layernorm,
82 | "use_batchnorm": out_mlp_batchnorm,
83 | }
84 | cfg_mlp_conf = {
85 | "end_layer_activation": False,
86 | "use_layernorm": out_mlp_layernorm,
87 | "use_batchnorm": out_mlp_batchnorm,
88 | }
89 | n_mlp_head = 3 if use_agent_type else 1
90 | self.mlp_pred = MLPEnsemble(n_decoders=n_mlp_head, decoder_cfg=cfg_mlp_pred, use_vmap=use_vmap)
91 |
92 | cfg_mlp_conf["fc_dims"] = [hidden_dim, _d, _d, 1]
93 | self.mlp_conf = MLPEnsemble(n_decoders=n_mlp_head, decoder_cfg=cfg_mlp_conf, use_vmap=use_vmap)
94 |
95 | def forward(self, valid: Tensor, emb: Tensor, agent_type: Tensor) -> Tuple[Tensor, Tensor]:
96 | """
97 | Args:
98 | valid: [n_scene, n_agent]
99 | emb: [n_scene, n_agent, n_pred, hidden_dim]
100 | agent_type: [n_scene, n_agent, 3]
101 |
102 | Returns:
103 | conf: [n_scene, n_agent, n_pred]
104 | pred: [n_scene, n_agent, n_pred, n_step_future, pred_dim]
105 | """
106 |         pred = self.mlp_pred(x=emb, valid_mask=valid.unsqueeze(-1))  # [1/3, n_scene, n_agent, n_pred, n_step_future * pred_dim]
107 | conf = self.mlp_conf(x=emb, valid_mask=valid.unsqueeze(-1)).squeeze(-1) # [1/3, n_scene, n_agent, n_pred]
108 |
109 | if self.use_agent_type:
110 | _type = agent_type.movedim(-1, 0).unsqueeze(-1) # [3, n_scene, n_agent, 1]
111 | pred = (pred * _type.unsqueeze(-1)).sum(0)
112 | conf = (conf * _type).sum(0)
113 | else:
114 | pred = pred.squeeze(0)
115 | conf = conf.squeeze(0)
116 |
117 | n_scene, n_agent, n_pred = conf.shape
118 | return conf, pred.view(n_scene, n_agent, n_pred, self.n_step_future, self.pred_dim)
119 |
120 |
121 | class MLPEnsemble(nn.Module):
122 | def __init__(self, n_decoders: int, decoder_cfg: DictConfig, use_vmap: bool) -> None:
123 | super().__init__()
124 | self.use_vmap = use_vmap
125 | self.n_decoders = n_decoders
126 | if self.use_vmap and self.n_decoders > 1:
127 | _decoders = [MLP(**decoder_cfg) for _ in range(n_decoders)]
128 | fmodel_decoders, params_decoders, buffers_decoders = combine_state_for_ensemble(_decoders)
129 | assert buffers_decoders == ()
130 | self.v_model = vmap(fmodel_decoders, randomness="different")
131 | [p.requires_grad_() for p in params_decoders]
132 | self.params_decoders = nn.ParameterList(params_decoders)
133 | else:
134 | self._decoders = nn.ModuleList([MLP(**decoder_cfg) for _ in range(n_decoders)])
135 |
136 | def forward(self, **kwargs) -> Tensor:
137 | if self.use_vmap and self.n_decoders > 1:
138 | out = self.v_model(tuple(self.params_decoders), (), **kwargs)
139 | else:
140 | out = []
141 | for decoder in self._decoders:
142 | x = decoder(**kwargs)
143 | out.append(x)
144 | out = torch.stack(out, dim=0)
145 | return out
146 |
--------------------------------------------------------------------------------
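A minimal usage sketch for the `MLPHead` defined above (not repository code; shapes are illustrative and the import assumes `src` is on `PYTHONPATH`):

```python
import torch
from models.modules.decoder_ensemble import MLPHead

# Sketch: with these prediction keywords, pred_dim = 2 (pos) + 3 (cov3) + 1 (spd)
# + 2 (vel) + 1 (yaw_bbox) = 9.
head = MLPHead(
    hidden_dim=256,
    use_vmap=False,
    n_step_future=80,
    out_mlp_layernorm=False,
    out_mlp_batchnorm=False,
    use_agent_type=False,
    predictions=["pos", "cov3", "spd", "vel", "yaw_bbox"],
)
valid = torch.ones(2, 8, dtype=torch.bool)  # [n_scene, n_agent]
emb = torch.randn(2, 8, 6, 256)             # [n_scene, n_agent, n_pred, hidden_dim]
agent_type = torch.zeros(2, 8, 3)
agent_type[..., 0] = 1                      # all vehicles
conf, pred = head(valid=valid, emb=emb, agent_type=agent_type)
print(conf.shape, pred.shape)  # [2, 8, 6] and [2, 8, 6, 80, 9]
```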
/README.md:
--------------------------------------------------------------------------------
1 | # HPTR
2 |
3 |
4 |
5 |
HPTR enables real-time and on-board motion prediction without sacrificing performance.
To efficiently predict the multi-modal future of numerous agents (a), HPTR minimizes the computational overhead by: (b) Sharing contexts among target agents. (c) Reusing static contexts during online inference. (d) Avoiding expensive post-processing and ensembling.
6 |
7 |
8 | > **Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding**
9 | > [Zhejun Zhang](https://zhejz.github.io/), [Alexander Liniger](https://alexliniger.github.io/), [Christos Sakaridis](https://people.ee.ethz.ch/~csakarid/), Fisher Yu and [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjcxLC0xOTcxNDY1MTc4.html).
10 | >
11 | > [NeurIPS 2023](https://neurips.cc/virtual/2023/poster/71285)
12 | > [Project Website](https://zhejz.github.io/hptr)
13 | > [arXiv Paper](https://arxiv.org/abs/2310.12970)
14 |
15 | ```bibtex
16 | @inproceedings{zhang2023hptr,
17 | title = {Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding},
18 | booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
19 | author = {Zhang, Zhejun and Liniger, Alexander and Sakaridis, Christos and Yu, Fisher and Van Gool, Luc},
20 | year = {2023},
21 | }
22 | ```
23 |
24 | ## Updates
25 | - The model checkpoint for Argoverse 2 is available as the wandb artifact `zhejun/hptr_av2_ckpt/av2_ckpt:v0` ([wandb project](https://wandb.ai/zhejun/hptr_av2_ckpt)).
26 | - HPTR ranks 1st in minADE and 2nd in minFDE on the [WOMD Motion Prediction Leaderboard 2023](https://waymo.com/open/challenges/2023/motion-prediction/).
27 |
28 |
29 | ## Setup Environment
30 | - Create the main [conda](https://docs.conda.io/en/latest/miniconda.html) environment by running `conda env create -f environment.yml`.
31 | - Install the [Waymo Open Dataset](https://github.com/waymo-research/waymo-open-dataset) API manually because the pip installation of version 1.5.2 is not supported on some Linux distributions, e.g. CentOS. Run:
32 | ```
33 | conda activate hptr
34 | wget https://files.pythonhosted.org/packages/85/1d/4cdd31fc8e88c3d689a67978c41b28b6e242bd4fe6b080cf8c99663b77e4/waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
35 | mv waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl
36 | pip install --no-deps waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl
37 | rm waymo_open_dataset_tf_2_11_0-1.5.2-py3-none-any.whl
38 | ```
39 | - Create the conda environment for packing [Argoverse 2 Motion Forecasting Dataset](https://www.argoverse.org/av2.html#forecasting-link) by running `conda env create -f env_av2.yml`.
40 | - We use [WandB](https://wandb.ai/) for logging. You can register an account for free.
41 | - Be aware:
42 |   - We use 4 *NVIDIA RTX 2080Ti* GPUs for training and a single 2080Ti for evaluation. The training takes at least 5 days to converge.
43 |   - This repo contains the experiments for the [Waymo Motion Prediction Challenge](https://waymo.com/open/challenges/2023/motion-prediction/) and the [Argoverse 2: Motion Forecasting Competition](https://eval.ai/web/challenges/challenge-page/1719/submission).
44 | - We cannot share pre-trained models according to the [terms](https://waymo.com/open/terms) of the Waymo Open Motion Dataset.
45 |
46 | ## Prepare Datasets
47 | - Waymo Open Motion Dataset (WOMD):
48 | - Download the [Waymo Open Motion Dataset](https://waymo.com/open/data/motion/). We use v1.2.
49 | - Run `python src/pack_h5_womd.py` or use [bash/pack_h5.sh](bash/pack_h5.sh) to pack the dataset into h5 files to accelerate data loading during the training and evaluation.
50 | - You should pack three datasets: `training`, `validation` and `testing`. Packing the `training` dataset takes around 2 days. For `validation` and `testing` it should take a few hours.
51 | - Argoverse 2 Motion Forecasting Dataset (AV2):
52 | - Download the [Argoverse 2 Motion Forecasting Dataset](https://www.argoverse.org/av2.html#download-link).
53 | - Run `python src/pack_h5_av2.py` or use [bash/pack_h5.sh](bash/pack_h5.sh) to pack the dataset into h5 files to accelerate data loading during the training and evaluation.
54 | - You should pack three datasets: `training`, `validation` and `testing`. Each dataset should take a few hours.
55 |
56 |
57 | ## Training, Validation, Testing and Submission
58 | Please refer to [bash/train.sh](bash/train.sh) for the training commands.
59 |
60 | Once the training converges, you can use the saved checkpoints (WandB artifacts) for validation and testing; please refer to [bash/submission.sh](bash/submission.sh) for more details.
61 |
62 | Once the validation/testing is finished, download the file `womd_K6.tar.gz` from WandB and submit it to the [Waymo Motion Prediction Leaderboard](https://waymo.com/open/challenges/2023/motion-prediction/). For AV2, download the file `av2_K6.parquet` from WandB and submit it to the [Argoverse 2 Motion Forecasting Competition](https://eval.ai/web/challenges/challenge-page/1719/submission).
63 |
64 |
65 | ## Performance
66 |
67 | Our submission to the [WOMD leaderboard](https://waymo.com/open/challenges/2023/motion-prediction/) can be found [here](https://waymo.com/open/challenges/entry/?challenge=MOTION_PREDICTION&challengeId=MOTION_PREDICTION_2023&emailId=5ea7a3eb-7337&timestamp=1684068775971677).
68 |
69 | Our submission to the [AV2 leaderboard](https://eval.ai/web/challenges/challenge-page/1719/overview) can be found [here](https://eval.ai/web/challenges/challenge-page/1719/leaderboard/4098).
70 |
71 | ## Ablation Models
72 |
73 | Please refer to [docs/ablation_models.md](docs/ablation_models.md) for the configurations of ablation models.
74 |
75 | Specifically, you can find the [Wayformer](https://arxiv.org/abs/2207.05844) and [SceneTransformer](https://arxiv.org/abs/2106.08417) variants based on our backbone. You can also try out different hierarchical architectures.
76 |
77 | ## License
78 |
79 | This software is made available for non-commercial use under a Creative Commons [license](LICENSE). You can find a summary of the license [here](https://creativecommons.org/licenses/by-nc/4.0/).
80 |
81 | ## Acknowledgement
82 |
83 | This work is funded by Toyota Motor Europe via the research project [TRACE-Zurich](https://trace.ethz.ch) (Toyota Research on Automated Cars Europe).
--------------------------------------------------------------------------------
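The Updates section of the README above points to the AV2 checkpoint as a WandB artifact. A minimal sketch for fetching it (assuming the public `wandb` Python API; not repository code, and the checkpoint file name inside the artifact is an assumption based on `ModelCheckpointWB` below):

```python
import wandb

# Sketch: download the AV2 checkpoint artifact to a local directory.
api = wandb.Api()
artifact = api.artifact("zhejun/hptr_av2_ckpt/av2_ckpt:v0")
ckpt_dir = artifact.download()  # directory expected to contain model.ckpt
print(ckpt_dir)
```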
/src/callbacks/wandb_callbacks.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional
3 | from pathlib import Path
4 | import wandb
5 | from pytorch_lightning import Callback, Trainer
6 | from pytorch_lightning.callbacks import ModelCheckpoint
7 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger
8 | from pytorch_lightning.utilities import rank_zero_only
9 |
10 |
11 | def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:
12 | """Safely get Weights&Biases logger from Trainer."""
13 |
14 | if isinstance(trainer.logger, WandbLogger):
15 | return trainer.logger
16 |
17 | if isinstance(trainer.logger, LoggerCollection):
18 | for logger in trainer.logger:
19 | if isinstance(logger, WandbLogger):
20 | return logger
21 |
22 | return None
23 | # raise Exception("You are using wandb related callback, but WandbLogger was not found for some reason...")
24 |
25 |
26 | class ModelCheckpointWB(ModelCheckpoint):
27 | def __init__(self, save_only_best=False, *args, **kwargs):
28 | super().__init__(*args, **kwargs)
29 | self.save_only_best = save_only_best
30 |
31 | def save_checkpoint(self, trainer) -> None:
32 | super().save_checkpoint(trainer)
33 | if not hasattr(self, "_logged_model_time"):
34 | self._logged_model_time = {}
35 | logger = get_wandb_logger(trainer)
36 | if self.current_score is None:
37 | self.current_score = trainer.callback_metrics.get(self.monitor)
38 | if logger is not None:
39 | self._scan_and_log_checkpoints(logger)
40 |
41 | @rank_zero_only
42 | def _scan_and_log_checkpoints(self, wb_logger: WandbLogger) -> None:
43 | if self.save_only_best:
44 | self._log_best_checkpoint(wb_logger)
45 | else:
46 | self._log_all_checkpoints(wb_logger)
47 |
48 | def _log_all_checkpoints(self, wb_logger: WandbLogger) -> None:
49 | # adapted from pytorch_lightning 1.4.0: loggers/wandb.py
50 | checkpoints = {
51 | self.last_model_path: self.current_score,
52 | self.best_model_path: self.best_model_score,
53 | }
54 | checkpoints = sorted(
55 | (Path(p).stat().st_mtime, p, s)
56 | for p, s in checkpoints.items()
57 | if Path(p).is_file()
58 | )
59 | checkpoints = [
60 | c
61 | for c in checkpoints
62 | if c[1] not in self._logged_model_time.keys()
63 | or self._logged_model_time[c[1]] < c[0]
64 | ]
65 | # log iteratively all new checkpoints
66 | for t, p, s in checkpoints:
67 | metadata = {
68 | "score": s.item(),
69 | "original_filename": Path(p).name,
70 | "ModelCheckpoint": {
71 | k: getattr(self, k)
72 | for k in [
73 | "monitor",
74 | "mode",
75 | "save_last",
76 | "save_top_k",
77 | "save_weights_only",
78 | "_every_n_train_steps",
79 | "_every_n_val_epochs",
80 | ]
81 | # ensure it does not break if `ModelCheckpoint` args change
82 | if hasattr(self, k)
83 | },
84 | }
85 | artifact = wandb.Artifact(
86 | name=wb_logger.experiment.id, type="model", metadata=metadata
87 | )
88 | artifact.add_file(p, name="model.ckpt")
89 | aliases = ["latest", "best"] if p == self.best_model_path else ["latest"]
90 | wb_logger.experiment.log_artifact(artifact, aliases=aliases)
91 |             # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
92 | self._logged_model_time[p] = t
93 |
94 | def _log_best_checkpoint(self, wb_logger: WandbLogger) -> None:
95 | # Only consider the best model checkpoint for logging
96 | if Path(self.best_model_path).is_file():
97 | best_model_mtime = Path(self.best_model_path).stat().st_mtime
98 | # Check if the best model checkpoint is new or has been updated
99 | if (
100 | self.best_model_path not in self._logged_model_time
101 | or self._logged_model_time[self.best_model_path] < best_model_mtime
102 | ):
103 | # Attempt to delete the previous best artifact if it exists
104 | try:
105 | api = wandb.Api()
106 | runs = api.run(
107 | f"{wb_logger.experiment.entity}/{wb_logger.experiment.project}/{wb_logger.experiment.id}"
108 | )
109 | for artifact in runs.logged_artifacts():
110 | if "best" in artifact.aliases:
111 | artifact.delete(delete_aliases=True)
112 | print("Deleted previous best artifact.")
113 | break
114 | else:
115 | print("No previous best artifact found to delete.")
116 | except Exception as e:
117 | print(f"Could not delete previous best artifact: {e}")
118 | # Log the best model checkpoint
119 | metadata = {
120 | "score": self.best_model_score.item(),
121 | "original_filename": Path(self.best_model_path).name,
122 | "ModelCheckpoint": {
123 | k: getattr(self, k)
124 | for k in [
125 | "monitor",
126 | "mode",
127 | "save_last",
128 | "save_top_k",
129 | "save_weights_only",
130 | "_every_n_train_steps",
131 | "_every_n_val_epochs",
132 | ]
133 | if hasattr(self, k)
134 | },
135 | }
136 | artifact = wandb.Artifact(
137 | name=wb_logger.experiment.id, type="model", metadata=metadata
138 | )
139 | artifact.add_file(self.best_model_path, name="model.ckpt")
140 | wb_logger.experiment.log_artifact(artifact, aliases=["best"])
141 | # Update the log timestamp for this model checkpoint
142 | self._logged_model_time[self.best_model_path] = best_model_mtime
143 |
144 |
145 | class WatchModel(Callback):
146 | """Make wandb watch model at the beginning of the run."""
147 |
148 | def __init__(self, log: str = "gradients", log_freq: int = 100):
149 | self._log = log
150 | self._log_freq = log_freq
151 |
152 | def on_train_start(self, trainer, pl_module):
153 | logger = get_wandb_logger(trainer)
154 | logger.watch(model=trainer.model, log=self._log, log_freq=self._log_freq)
155 |
--------------------------------------------------------------------------------
/src/utils/transform_utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import numpy as np
3 | import torch
4 | from torch import Tensor
5 | import transforms3d
6 | from typing import Union
7 |
8 |
9 | def cast_rad(angle: Union[float, np.ndarray, Tensor]) -> Union[float, np.ndarray, Tensor]:
10 |     """Cast angles so that they are always in the [-pi, pi) range."""
11 | return (angle + np.pi) % (2 * np.pi) - np.pi
12 |
13 |
14 | def _rotation33_as_yaw(rotation: np.ndarray) -> float:
15 | """Compute the yaw component of given 3x3 rotation matrix.
16 |
17 | Args:
18 | rotation (np.ndarray): 3x3 rotation matrix (np.float64 dtype recommended)
19 |
20 | Returns:
21 | float: yaw rotation in radians
22 | """
23 | return transforms3d.euler.mat2euler(rotation)[2]
24 |
25 |
26 | def _yaw_as_rotation33(yaw: float) -> np.ndarray:
27 | """Create a 3x3 rotation matrix from given yaw.
28 | The rotation is counter-clockwise and it is equivalent to:
29 | [cos(yaw), -sin(yaw), 0.0],
30 | [sin(yaw), cos(yaw), 0.0],
31 | [0.0, 0.0, 1.0],
32 |
33 | Args:
34 | yaw (float): yaw rotation in radians
35 |
36 | Returns:
37 | np.ndarray: 3x3 rotation matrix
38 | """
39 | return transforms3d.euler.euler2mat(0, 0, yaw)
40 |
41 |
42 | def get_so2_from_se2(transform_se3: np.ndarray) -> np.ndarray:
43 | """Gets rotation component in SO(2) from transformation in SE(2).
44 |
45 | Args:
46 | transform_se3: se2 transformation.
47 |
48 | Returns:
49 | rotation component in so2
50 | """
51 | rotation = np.eye(3, dtype=np.float64)
52 | rotation[:2, :2] = transform_se3[:2, :2]
53 | return rotation
54 |
55 |
56 | def get_yaw_from_se2(transform_se3: np.ndarray) -> float:
57 | """Gets yaw from transformation in SE(2).
58 |
59 | Args:
60 | transform_se3: se2 transformation.
61 |
62 | Returns:
63 | yaw
64 | """
65 | return _rotation33_as_yaw(get_so2_from_se2(transform_se3))
66 |
67 |
68 | def transform_points(points: np.ndarray, transf_matrix: np.ndarray) -> np.ndarray:
69 | """Transform points using transformation matrix.
70 | Note this function assumes points.shape[1] == matrix.shape[1] - 1, which means that the last
71 | row in the matrix does not influence the final result.
72 | For 2D points only the first 2x3 part of the matrix will be used.
73 |
74 | Args:
75 | points (np.ndarray): Input points (Nx2) or (Nx3).
76 | transf_matrix (np.ndarray): np.float64, 3x3 or 4x4 transformation matrix for 2D and 3D input respectively
77 |
78 | Returns:
79 | np.ndarray: array of shape (N,2) for 2D input points, or (N,3) points for 3D input points
80 | """
81 | assert len(points.shape) == len(transf_matrix.shape) == 2, (
82 | f"dimensions mismatch, both points ({points.shape}) and "
83 |         f"transf_matrix ({transf_matrix.shape}) need to be 2D numpy ndarrays."
84 | )
85 | assert (
86 | transf_matrix.shape[0] == transf_matrix.shape[1]
87 | ), f"transf_matrix ({transf_matrix.shape}) should be a square matrix."
88 |
89 | if points.shape[1] not in [2, 3]:
90 | raise AssertionError(f"Points input should be (N, 2) or (N, 3) shape, received {points.shape}")
91 |
92 | assert points.shape[1] == transf_matrix.shape[1] - 1, "points dim should be one less than matrix dim"
93 |
94 | points_transformed = (points @ transf_matrix.T[:-1, :-1]) + transf_matrix[:-1, -1]
95 |
96 | return points_transformed.astype(points.dtype)
97 |
98 |
99 | def get_transformation_matrix(agent_translation_m: np.ndarray, agent_yaw: float) -> np.ndarray:
100 | """Get transformation matrix from world to vehicle frame
101 |
102 | Args:
103 | agent_translation_m (np.ndarray): (x, y) position of the vehicle in world frame
104 | agent_yaw (float): rotation of the vehicle in the world frame
105 |
106 | Returns:
107 | (np.ndarray) transformation matrix from world to vehicle
108 | """
109 |
110 | # Translate world to ego by applying the negative ego translation.
111 | world_to_agent_in_2d = np.eye(3, dtype=np.float64)
112 | world_to_agent_in_2d[0:2, 2] = -agent_translation_m[0:2]
113 |
114 | # Rotate counter-clockwise by negative yaw to align world such that ego faces right.
115 | world_to_agent_in_2d = _yaw_as_rotation33(-agent_yaw) @ world_to_agent_in_2d
116 |
117 | return world_to_agent_in_2d
118 |
119 |
120 | # transformation for torch
121 | def torch_rad2rot(rad: Tensor) -> Tensor:
122 | """
123 | Args:
124 | rad: [n_batch] or [n_scene, n_agent] or etc.
125 |
126 | Returns:
127 | rot_mat: [{rad.shape}, 2, 2]
128 | """
129 | _cos = torch.cos(rad)
130 | _sin = torch.sin(rad)
131 | return torch.stack([torch.stack([_cos, -_sin], dim=-1), torch.stack([_sin, _cos], dim=-1)], dim=-2)
132 |
133 |
134 | def torch_sincos2rot(in_sin: Tensor, in_cos: Tensor) -> Tensor:
135 | """
136 | Args:
137 | in_sin: [n_batch] or [n_scene, n_agent] or etc.
138 | in_cos: [n_batch] or [n_scene, n_agent] or etc.
139 |
140 | Returns:
141 | rot_mat: [{in_sin.shape}, 2, 2]
142 | """
143 | return torch.stack([torch.stack([in_cos, -in_sin], dim=-1), torch.stack([in_sin, in_cos], dim=-1)], dim=-2)
144 |
145 |
146 | def torch_pos2local(in_pos: Tensor, local_pos: Tensor, local_rot: Tensor) -> Tensor:
147 | """Transform M position to the local coordinates.
148 |
149 | Args:
150 | in_pos: [..., M, 2]
151 | local_pos: [..., 1, 2]
152 | local_rot: [..., 2, 2]
153 |
154 | Returns:
155 | out_pos: [..., M, 2]
156 | """
157 | return torch.matmul(in_pos - local_pos, local_rot)
158 |
159 |
160 | def torch_pos2global(in_pos: Tensor, local_pos: Tensor, local_rot: Tensor) -> Tensor:
161 | """Reverse torch_pos2local
162 |
163 | Args:
164 | in_pos: [..., M, 2]
165 | local_pos: [..., 1, 2]
166 | local_rot: [..., 2, 2]
167 |
168 | Returns:
169 | out_pos: [..., M, 2]
170 | """
171 | return torch.matmul(in_pos.double(), local_rot.transpose(-1, -2).double()) + local_pos.double()
172 |
173 |
174 | def torch_dir2local(in_dir: Tensor, local_rot: Tensor) -> Tensor:
175 | """Transform M dir to the local coordinates.
176 |
177 | Args:
178 | in_dir: [..., M, 2]
179 | local_rot: [..., 2, 2]
180 |
181 | Returns:
182 | out_dir: [..., M, 2]
183 | """
184 | return torch.matmul(in_dir, local_rot)
185 |
186 |
187 | def torch_dir2global(in_dir: Tensor, local_rot: Tensor) -> Tensor:
188 | """Reverse torch_dir2local
189 |
190 | Args:
191 | in_dir: [..., M, 2]
192 | local_rot: [..., 2, 2]
193 |
194 | Returns:
195 | out_dir: [..., M, 2]
196 | """
197 | return torch.matmul(in_dir, local_rot.transpose(-1, -2))
198 |
199 |
200 | def torch_rad2local(in_rad: Tensor, local_rad: Tensor, cast: bool = True) -> Tensor:
201 | """Transform M rad angles to the local coordinates.
202 |
203 | Args:
204 | in_rad: [..., M]
205 | local_rad: [...]
206 |
207 | Returns:
208 | out_rad: [..., M]
209 | """
210 | out_rad = in_rad - local_rad.unsqueeze(-1)
211 | if cast:
212 | out_rad = cast_rad(out_rad)
213 | return out_rad
214 |
215 |
216 | def torch_rad2global(in_rad: Tensor, local_rad: Tensor) -> Tensor:
217 | """Reverse torch_rad2local
218 |
219 | Args:
220 | in_rad: [..., M]
221 | local_rad: [...]
222 |
223 | Returns:
224 | out_rad: [..., M]
225 | """
226 | return cast_rad(in_rad + local_rad.unsqueeze(-1))
227 |
--------------------------------------------------------------------------------
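A small round-trip check for the local/global transforms above (a sketch, not repository code): mapping points into a local frame and back should recover the originals, since `local_rot` is a pure rotation.

```python
import torch
from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_pos2global

pos = torch.randn(4, 10, 2)     # [n_batch, M, 2] points in the global frame
ref_pos = torch.randn(4, 1, 2)  # [n_batch, 1, 2] local frame origin
ref_rot = torch_rad2rot(torch.rand(4) * 6.28 - 3.14)  # [n_batch, 2, 2]

local = torch_pos2local(pos, ref_pos, ref_rot)         # global -> local
recovered = torch_pos2global(local, ref_pos, ref_rot)  # local -> global (float64)
assert torch.allclose(recovered, pos.double(), atol=1e-5)
```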
/src/data_modules/scene_centric.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Dict
3 | from omegaconf import DictConfig
4 | import torch
5 | from torch import nn, Tensor
6 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_dir2local, torch_rad2local
7 |
8 |
9 | class SceneCentricPreProcessing(nn.Module):
10 | def __init__(self, time_step_current: int, data_size: DictConfig, gt_in_local: bool, mask_invalid: bool) -> None:
11 | super().__init__()
12 | self.step_current = time_step_current
13 | self.n_step_hist = time_step_current + 1
14 | self.gt_in_local = gt_in_local
15 | self.mask_invalid = mask_invalid
16 | self.model_kwargs = {"gt_in_local": gt_in_local, "agent_centric": False}
17 |
18 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
19 | """
20 | Args: scene-centric Dict
21 | # agent states
22 | "agent/valid": [n_scene, n_step, n_agent], bool,
23 | "agent/pos": [n_scene, n_step, n_agent, 2], float32
24 | "agent/vel": [n_scene, n_step, n_agent, 2], float32, v_x, v_y
25 | "agent/spd": [n_scene, n_step, n_agent, 1], norm of vel, signed using yaw_bbox and vel_xy
26 | "agent/acc": [n_scene, n_step, n_agent, 1], m/s2, acc[t] = (spd[t]-spd[t-1])/dt
27 | "agent/yaw_bbox": [n_scene, n_step, n_agent, 1], float32, yaw of the bbox heading
28 | "agent/yaw_rate": [n_scene, n_step, n_agent, 1], rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt
29 | # agent attributes
30 | "agent/type": [n_scene, n_agent, 3], bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
31 | "agent/role": [n_scene, n_agent, 3], bool [sdc=0, interest=1, predict=2]
32 | "agent/size": [n_scene, n_agent, 3], float32: [length, width, height]
33 | # map polylines
34 | "map/valid": [n_scene, n_pl, n_pl_node], bool
35 | "map/type": [n_scene, n_pl, 11], bool one_hot
36 | "map/pos": [n_scene, n_pl, n_pl_node, 2], float32
37 | "map/dir": [n_scene, n_pl, n_pl_node, 2], float32
38 | # traffic lights
39 | "tl_stop/valid": [n_scene, n_step, n_tl_stop], bool
40 | "tl_stop/state": [n_scene, n_step, n_tl_stop, 5], bool one_hot
41 | "tl_stop/pos": [n_scene, n_step, n_tl_stop, 2], x,y
42 | "tl_stop/dir": [n_scene, n_step, n_tl_stop, 2], x,y
43 |
44 | Returns: scene-centric Dict, masked according to valid
45 | # (ref) reference information for transform back to global coordinate and submission to waymo
46 | "ref/pos": [n_scene, n_agent, 1, 2]
47 | "ref/yaw": [n_scene, n_agent, 1]
48 | "ref/rot": [n_scene, n_agent, 2, 2]
49 | "ref/role": [n_scene, n_agent, 3]
50 | "ref/type": [n_scene, n_agent, 3]
51 | # (gt) ground-truth agent future for training, not available for testing
52 | "gt/valid": [n_scene, n_agent, n_step_future], bool
53 | "gt/pos": [n_scene, n_agent, n_step_future, 2]
54 | "gt/spd": [n_scene, n_agent, n_step_future, 1]
55 | "gt/vel": [n_scene, n_agent, n_step_future, 2]
56 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1]
57 | "gt/cmd": [n_scene, n_agent, 8]
58 | # (sc) scene-centric agents states
59 | "sc/agent_valid": [n_scene, n_agent, n_step_hist]
60 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2]
61 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2]
62 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1]
63 | "sc/agent_acc": [n_scene, n_agent, n_step_hist, 1]
64 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1]
65 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1]
66 | # agents attributes
67 | "sc/agent_type": [n_scene, n_agent, 3]
68 | "sc/agent_role": [n_scene, n_agent, 3]
69 | "sc/agent_size": [n_scene, n_agent, 3]
70 | # map polylines
71 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool
72 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot
73 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32
74 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32
75 | # traffic lights
76 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool
77 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot
78 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y
79 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y
80 | """
81 | prefix = "" if self.training else "history/"
82 |
83 | # ! prepare "ref/"
84 | batch["ref/type"] = batch[prefix + "agent/type"]
85 | batch["ref/role"] = batch[prefix + "agent/role"]
86 |
87 | last_valid_step = (
88 | self.step_current - batch[prefix + "agent/valid"][:, : self.step_current + 1].flip(1).max(1)[1]
89 | ) # [n_scene, n_agent]
90 | i_scene = torch.arange(batch["ref/type"].shape[0]).unsqueeze(1) # [n_scene, 1]
91 | i_agent = torch.arange(batch["ref/type"].shape[1]).unsqueeze(0) # [1, n_agent]
92 | ref_pos = batch[prefix + "agent/pos"][i_scene, last_valid_step, i_agent].unsqueeze(-2).contiguous()
93 | ref_yaw = batch[prefix + "agent/yaw_bbox"][i_scene, last_valid_step, i_agent]
94 | ref_rot = torch_rad2rot(ref_yaw.squeeze(-1))
95 | batch["ref/pos"] = ref_pos
96 | batch["ref/yaw"] = ref_yaw
97 | batch["ref/rot"] = ref_rot
98 |
99 | # ! prepare agents states
100 | # [n_scene, n_step, n_agent, ...] -> [n_scene, n_agent, n_step_hist, ...]
101 | for k in ("valid", "pos", "vel", "spd", "acc", "yaw_bbox", "yaw_rate"):
102 | batch[f"sc/agent_{k}"] = batch[f"{prefix}agent/{k}"][:, : self.n_step_hist].transpose(1, 2)
103 |
104 | # ! prepare agents attributes
105 | for k in ("type", "role", "size"):
106 | batch[f"sc/agent_{k}"] = batch[f"{prefix}agent/{k}"]
107 |
108 | # ! training/validation time, prepare "gt/" for losses
109 | if "agent/valid" in batch.keys():
110 | batch["gt/cmd"] = batch["agent/cmd"]
111 | for k in ("valid", "spd", "pos", "vel", "yaw_bbox"):
112 | batch[f"gt/{k}"] = batch[f"agent/{k}"][:, self.n_step_hist :].transpose(1, 2).contiguous()
113 |
114 | if self.gt_in_local:
115 |                 # [n_scene, n_agent, n_step_future, 2]
116 | batch["gt/pos"] = torch_pos2local(batch["gt/pos"], ref_pos, ref_rot)
117 | batch["gt/vel"] = torch_dir2local(batch["gt/vel"], ref_rot)
118 |                 # [n_scene, n_agent, n_step_future, 1]
119 | batch["gt/yaw_bbox"] = torch_rad2local(batch["gt/yaw_bbox"], ref_yaw, cast=False)
120 |
121 | # ! prepare map polylines
122 | for k in ("valid", "type", "pos", "dir"):
123 | batch[f"sc/map_{k}"] = batch[f"map/{k}"]
124 |
125 | # ! prepare traffic lights
126 | for k in ("valid", "state", "pos", "dir"):
127 | batch[f"sc/tl_{k}"] = batch[f"{prefix}tl_stop/{k}"][:, : self.n_step_hist]
128 |
129 | if self.mask_invalid:
130 | self.zero_mask_invalid(batch)
131 | return batch
132 |
133 | @staticmethod
134 | def zero_mask_invalid(batch: Dict[str, Tensor]):
135 |
136 | agent_invalid = ~batch["sc/agent_valid"].unsqueeze(-1)
137 | for k in ["pos", "vel", "spd", "acc", "yaw_bbox", "yaw_rate"]:
138 | _key = f"sc/agent_{k}"
139 | batch[_key] = batch[_key].masked_fill(agent_invalid, 0)
140 |
141 | agent_invalid = ~(batch["sc/agent_valid"].any(-1, keepdim=True))
142 | for k in ["type", "role", "size"]:
143 | _key = f"sc/agent_{k}"
144 | batch[_key] = batch[_key].masked_fill(agent_invalid, 0)
145 |
146 | map_invalid = ~batch["sc/map_valid"].unsqueeze(-1)
147 | batch["sc/map_pos"] = batch["sc/map_pos"].masked_fill(map_invalid, 0)
148 | batch["sc/map_dir"] = batch["sc/map_dir"].masked_fill(map_invalid, 0)
149 | map_invalid = ~(batch["sc/map_valid"].any(-1, keepdim=True))
150 | batch["sc/map_type"] = batch["sc/map_type"].masked_fill(map_invalid, 0)
151 |
152 | tl_invalid = ~batch["sc/tl_valid"].unsqueeze(-1)
153 | for k in ["state", "pos", "dir"]:
154 | _key = f"sc/tl_{k}"
155 | batch[_key] = batch[_key].masked_fill(tl_invalid, 0)
156 |
157 | gt_invalid = ~batch["gt/valid"].unsqueeze(-1)
158 | for k in ["pos", "spd", "vel", "yaw_bbox"]:
159 | _key = f"gt/{k}"
160 | batch[_key] = batch[_key].masked_fill(gt_invalid, 0)
161 |
--------------------------------------------------------------------------------
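The `last_valid_step` computation in `SceneCentricPreProcessing.forward` uses a flip-and-argmax trick to find the most recent valid history step per agent. A minimal illustration with made-up values (a sketch, not repository code):

```python
import torch

# Simplified to [n_agent, n_step] instead of [n_scene, n_step, n_agent] as above.
valid = torch.tensor([[True, True, False],
                      [False, True, True]])
step_current = 2
# Flip along time and take the argmax of the bool mask: the first True in the
# flipped order corresponds to the last valid step in the original order.
last_valid_step = step_current - valid[:, : step_current + 1].flip(1).max(1)[1]
print(last_valid_step)  # tensor([1, 2])
```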
/src/models/modules/attention.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional, Tuple
3 | import math
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.nn import functional as F
7 |
8 | # KNARPE
9 | class AttentionRPE(nn.Module):
10 | def __init__(
11 | self,
12 | d_model: int,
13 | n_head: int,
14 | dropout_p: float = 0.0,
15 | bias: bool = True,
16 | d_rpe: int = -1,
17 | apply_q_rpe: bool = False,
18 | ) -> None:
19 | """
20 | Always batch first. Always src and tgt have the same d_model.
21 | """
22 | super(AttentionRPE, self).__init__()
23 |
24 | self.d_model = d_model
25 | self.n_head = n_head
26 | self.d_head = d_model // n_head
27 | self.apply_q_rpe = apply_q_rpe
28 | self.d_rpe = d_rpe
29 |
30 | assert self.d_head * n_head == d_model, "d_model must be divisible by n_head"
31 |
32 | if self.d_rpe > 0:
33 | n_project_rpe = 3 if apply_q_rpe else 2
34 | self.mlp_rpe = nn.Linear(d_rpe, n_project_rpe * d_model, bias=bias)
35 |
36 | self.in_proj_weight = nn.Parameter(torch.empty((3 * d_model, d_model)))
37 | self.out_proj_weight = nn.Parameter(torch.empty((d_model, d_model)))
38 | if bias:
39 | self.in_proj_bias = nn.Parameter(torch.empty(3 * d_model))
40 | self.out_proj_bias = nn.Parameter(torch.empty(d_model))
41 | else:
42 | self.register_parameter("in_proj_bias", None)
43 | self.register_parameter("out_proj_bias", None)
44 |
45 | self.dropout = nn.Dropout(p=dropout_p, inplace=False) if dropout_p > 0 else None
46 |
47 | self._reset_parameters()
48 |
49 | def _reset_parameters(self):
50 | nn.init.xavier_uniform_(self.in_proj_weight)
51 | nn.init.xavier_uniform_(self.out_proj_weight)
52 | if self.in_proj_bias is not None:
53 | nn.init.constant_(self.in_proj_bias, 0.0)
54 | if self.out_proj_bias is not None:
55 | nn.init.constant_(self.out_proj_bias, 0.0)
56 |
57 | def forward(
58 | self,
59 | src: Tensor,
60 | tgt: Optional[Tensor] = None,
61 | tgt_padding_mask: Optional[Tensor] = None,
62 | attn_mask: Optional[Tensor] = None,
63 | rpe: Optional[Tensor] = None,
64 | need_weights=False,
65 | ) -> Tuple[Tensor, Optional[Tensor]]:
66 | """
67 | Args:
68 | src: [n_batch, n_src, d_model]
69 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe.
70 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe.
71 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt.
72 | rpe: [n_batch, n_src, n_tgt, d_rpe]
73 |
74 | Returns:
75 | out: [n_batch, n_src, d_model]
76 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None
77 |
78 | Remarks:
79 |             absolute_pe should be already added to src/tgt.
80 | if for a batch entry all tgt are invalid, then returns 0 for that batch entry.
81 | """
82 | n_batch, n_src, _ = src.shape
83 | if tgt is None:
84 | n_tgt = n_src
85 | # self-attention
86 | qkv = F.linear(src, self.in_proj_weight, self.in_proj_bias)
87 | q, k, v = qkv.chunk(3, dim=-1)
88 | else:
89 | n_tgt = tgt.shape[-2]
90 | # encoder-decoder attention
91 | w_src, w_tgt = self.in_proj_weight.split([self.d_model, self.d_model * 2])
92 | b_src, b_tgt = None, None
93 | if self.in_proj_bias is not None:
94 | b_src, b_tgt = self.in_proj_bias.split([self.d_model, self.d_model * 2])
95 | q = F.linear(src, w_src, b_src)
96 | kv = F.linear(tgt, w_tgt, b_tgt)
97 | k, v = kv.chunk(2, dim=-1)
98 | # q: [n_batch, n_src, d_model], k,v: [n_batch, (n_src), n_tgt, d_model]
99 |
100 | attn_invalid_mask = None # [n_batch, n_src, n_tgt]
101 | if tgt_padding_mask is not None: # [n_batch, n_tgt], bool
102 | attn_invalid_mask = tgt_padding_mask
103 | if attn_invalid_mask.dim() == 2:
104 | attn_invalid_mask = attn_invalid_mask.unsqueeze(1).expand(-1, n_src, -1)
105 | if attn_mask is not None: # [n_batch, n_src, n_tgt], bool
106 | if attn_invalid_mask is None:
107 | attn_invalid_mask = attn_mask
108 | else:
109 | attn_invalid_mask = attn_invalid_mask | attn_mask
110 |
111 | mask_no_tgt_valid = None # [n_batch, n_src]
112 | if attn_invalid_mask is not None:
113 | mask_no_tgt_valid = attn_invalid_mask.all(-1)
114 | if mask_no_tgt_valid.any():
115 | attn_invalid_mask = attn_invalid_mask & (~mask_no_tgt_valid.unsqueeze(-1)) # to avoid softmax nan
116 | else:
117 | mask_no_tgt_valid = None
118 |
119 | # get attn: [n_batch, n_head, n_src, n_tgt]
120 | if rpe is None:
121 | if k.dim() == 3:
122 | # ! normal attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_tgt, d_model]
123 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).contiguous()
124 | k = k.view(n_batch, n_tgt, self.n_head, self.d_head).transpose(1, 2).contiguous()
125 | v = v.view(n_batch, n_tgt, self.n_head, self.d_head).transpose(1, 2).contiguous()
126 | attn = torch.matmul(q, k.transpose(-2, -1)) # [n_batch, n_head, n_src, n_tgt]
127 | # q: [n_batch, n_head, n_src, d_head], k,v: [n_batch, n_head, n_tgt, d_head]
128 | else:
129 | # ! KNN attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_src, n_tgt, d_model]
130 | # k,v: [n_batch, n_src, n_tgt, d_model] -> [n_batch, n_head, n_src, n_tgt_knn, d_head]
131 | k = k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
132 | v = v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
133 | # [n_batch, n_src, d_model] -> [n_batch, n_head, n_src, 1, d_head]
134 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).unsqueeze(3)
135 | attn = torch.sum(q * k, dim=-1) # [n_batch, n_head, n_src, n_tgt_knn]
136 | else:
137 |             # ! rpe attention; q: [n_batch, n_src, d_model], k,v: [n_batch, n_src, n_tgt, d_model]
138 | assert self.d_rpe > 0
139 | # k,v: [n_batch, n_src, n_tgt, d_model] -> [n_batch, n_head, n_src, n_tgt_knn, d_head]
140 | k = k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
141 | v = v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
142 | # [n_batch, n_src, d_model] -> [n_batch, n_head, n_src, 1, d_head]
143 | q = q.view(n_batch, n_src, self.n_head, self.d_head).transpose(1, 2).unsqueeze(3)
144 |
145 | # project rpe to rpe_q, rpe_k, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head]
146 | rpe = self.mlp_rpe(rpe)
147 | if self.apply_q_rpe:
148 | rpe_q, rpe_k, rpe_v = rpe.chunk(3, dim=-1)
149 | rpe_q = rpe_q.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
150 | else:
151 | rpe_k, rpe_v = rpe.chunk(2, dim=-1)
152 | rpe_k = rpe_k.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
153 | rpe_v = rpe_v.view(n_batch, n_src, n_tgt, self.n_head, self.d_head).movedim(3, 1)
154 |
155 | # get attn: [n_batch, n_head, n_src, n_tgt]
156 | if self.apply_q_rpe:
157 | attn = torch.sum((q + rpe_q) * (k + rpe_k), dim=-1)
158 | # attn = torch.sum((q + rpe_q) * (k + rpe_k) - rpe_q * rpe_k, dim=-1)
159 | else:
160 | attn = torch.sum(q * (k + rpe_k), dim=-1)
161 | # q: [n_batch, n_head, n_src, 1, d_head]
162 | # k,v: [n_batch, n_head, n_src, n_tgt, d_head]
163 | # rpe_q, rpe_k, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head]
164 |
165 | if attn_invalid_mask is not None:
166 | # attn_invalid_mask: [n_batch, n_src, n_tgt], attn: [n_batch, n_head, n_src, n_tgt]
167 | attn = attn.masked_fill(attn_invalid_mask.unsqueeze(1), float("-inf"))
168 |
169 | attn = torch.softmax(attn / math.sqrt(self.d_head), dim=-1)
170 | if self.dropout is not None:
171 | attn = self.dropout(attn)
172 |
173 | # attn: [n_batch, n_head, n_src, n_tgt]
174 | if rpe is None:
175 | if v.dim() == 4:
176 | out = torch.matmul(attn, v) # v, [n_batch, n_head, n_tgt, d_head]
177 | else:
178 | out = torch.sum(v * attn.unsqueeze(-1), dim=3) # v: [n_batch, n_head, n_src, n_tgt, d_head]
179 | else:
180 | # v, rpe_v: [n_batch, n_head, n_src, n_tgt, d_head]
181 | out = torch.sum((v + rpe_v) * attn.unsqueeze(-1), dim=3)
182 |
183 | # out: [n_batch, n_head, n_src, d_head]
184 | out = out.transpose(1, 2).flatten(2, 3) # [n_batch, n_src, d_model]
185 | out = F.linear(out, self.out_proj_weight, self.out_proj_bias)
186 |
187 | if mask_no_tgt_valid is not None:
188 | # mask_no_tgt_valid: [n_batch, n_src], out: [n_batch, n_src, d_model]
189 | out = out.masked_fill(mask_no_tgt_valid.unsqueeze(-1), 0)
190 |
191 | if need_weights:
192 | attn_weights = attn.mean(1) # [n_batch, n_src, n_tgt]
193 | if mask_no_tgt_valid is not None:
194 | attn_weights = attn_weights.masked_fill(mask_no_tgt_valid.unsqueeze(-1), 0)
195 | return out, attn_weights
196 | else:
197 | return out, None
198 |
--------------------------------------------------------------------------------
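A minimal usage sketch for `AttentionRPE` with the KNN + RPE inputs described in its docstring (shapes are illustrative; not repository code, and the import assumes `src` is on `PYTHONPATH`):

```python
import torch
from models.modules.attention import AttentionRPE

attn = AttentionRPE(d_model=256, n_head=4, d_rpe=64)
src = torch.randn(2, 10, 256)      # [n_batch, n_src, d_model] queries
tgt = torch.randn(2, 10, 16, 256)  # [n_batch, n_src, n_tgt, d_model] K nearest neighbors per query
rpe = torch.randn(2, 10, 16, 64)   # [n_batch, n_src, n_tgt, d_rpe] relative pose encodings
out, _ = attn(src=src, tgt=tgt, rpe=rpe)
print(out.shape)  # torch.Size([2, 10, 256])
```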
/src/utils/submission.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import List
3 | from omegaconf import ListConfig
4 | from pathlib import Path
5 | import tarfile
6 | import os
7 | import pandas as pd
8 | from torch import Tensor
9 | from waymo_open_dataset.protos import motion_submission_pb2
10 | from pytorch_lightning.loggers import WandbLogger
11 | from .transform_utils import torch_pos2global, torch_rad2rot
12 |
13 | # ! single GPU only
14 |
15 |
16 | class SubWOMD:
17 | def __init__(
18 | self,
19 | k_futures: int,
20 | wb_artifact: str,
21 | interactive_challenge: bool,
22 | activate: bool,
23 | method_name: str,
24 | authors: ListConfig[str],
25 | affiliation: str,
26 | description: str,
27 | method_link: str,
28 | account_name: str,
29 | ) -> None:
30 | self.activate = activate
31 | if activate:
32 | self.submissions = {}
33 | for _k in range(1, k_futures + 1):
34 | self.submissions[_k] = motion_submission_pb2.MotionChallengeSubmission()
35 | self.submissions[_k].account_name = account_name
36 | self.submissions[_k].unique_method_name = f"{method_name}_K{_k}"
37 | self.submissions[_k].authors.extend(list(authors))
38 | self.submissions[_k].affiliation = affiliation
39 | self.submissions[_k].description = description
40 | self.submissions[_k].method_link = method_link
41 | if interactive_challenge:
42 | self.submissions[_k].submission_type = 2
43 | else:
44 | self.submissions[_k].submission_type = 1
45 |
46 | def add_to_submissions(
47 | self,
48 | waymo_trajs: Tensor,
49 | waymo_scores: Tensor,
50 | mask_pred: Tensor,
51 | object_id: Tensor,
52 | scenario_center: Tensor,
53 | scenario_yaw: Tensor,
54 | scenario_id: List[str],
55 | ) -> None:
56 | """
57 | Args:
58 | waymo_trajs: [n_batch, step_start+1...step_end, n_agent, K, 2]
59 | waymo_scores: [n_batch, n_agent, K]
60 | mask_pred: [n_batch, n_agent] bool
61 | object_id: [n_batch, n_agent] int
62 | scenario_center: [n_batch, 2]
63 | scenario_yaw: [n_batch]
64 | scenario_id: list of str
65 | """
66 | if not self.activate:
67 | return
68 |
69 | waymo_trajs = waymo_trajs[:, 4::5].permute(0, 2, 3, 1, 4) # [n_batch, n_agent, K, n_step, 2]
70 | waymo_trajs = torch_pos2global(
71 | waymo_trajs.flatten(1, 3), scenario_center.unsqueeze(1), torch_rad2rot(scenario_yaw)
72 | ).view(waymo_trajs.shape)
73 |
74 | waymo_trajs = waymo_trajs.cpu().numpy()
75 | waymo_scores = waymo_scores.cpu().numpy()
76 | mask_pred = mask_pred.cpu().numpy()
77 | object_id = object_id.cpu().numpy()
78 |
79 | for i_batch in range(waymo_trajs.shape[0]):
80 | agent_pos = waymo_trajs[i_batch, mask_pred[i_batch]] # [n_agent_pred, K, n_step, 2]
81 | agent_id = object_id[i_batch, mask_pred[i_batch]] # [n_agent_pred]
82 | agent_score = waymo_scores[i_batch, mask_pred[i_batch]] # [n_agent_pred, K]
83 |
84 | for n_K, submission in self.submissions.items():
85 | scenario_prediction = motion_submission_pb2.ChallengeScenarioPredictions()
86 | scenario_prediction.scenario_id = scenario_id[i_batch]
87 |
88 | if submission.submission_type == 1:
89 | # single prediction
90 | for i_track in range(agent_pos.shape[0]):
91 | prediction = motion_submission_pb2.SingleObjectPrediction()
92 | prediction.object_id = agent_id[i_track]
93 | for _k in range(n_K):
94 | scored_trajectory = motion_submission_pb2.ScoredTrajectory()
95 | scored_trajectory.confidence = agent_score[i_track, _k]
96 | scored_trajectory.trajectory.center_x.extend(agent_pos[i_track, _k, :, 0])
97 | scored_trajectory.trajectory.center_y.extend(agent_pos[i_track, _k, :, 1])
98 | prediction.trajectories.append(scored_trajectory)
99 | scenario_prediction.single_predictions.predictions.append(prediction)
100 | else:
101 | # joint prediction
102 | for _k in range(n_K):
103 | scored_joint_trajectory = motion_submission_pb2.ScoredJointTrajectory()
104 | scored_joint_trajectory.confidence = agent_score[:, _k].sum(0)
105 | for i_track in range(agent_pos.shape[0]):
106 | object_trajectory = motion_submission_pb2.ObjectTrajectory()
107 | object_trajectory.object_id = agent_id[i_track]
108 | object_trajectory.trajectory.center_x.extend(agent_pos[i_track, _k, :, 0])
109 | object_trajectory.trajectory.center_y.extend(agent_pos[i_track, _k, :, 1])
110 | scored_joint_trajectory.trajectories.append(object_trajectory)
111 | scenario_prediction.joint_prediction.joint_trajectories.append(scored_joint_trajectory)
112 |
113 | submission.scenario_predictions.append(scenario_prediction)
114 |
115 | def save_sub_files(self, logger: WandbLogger) -> List[str]:
116 | if not self.activate:
117 | return []
118 |
119 | print(f"saving womd submission files to {os.getcwd()}")
120 | file_paths = []
121 | for k, submission in self.submissions.items():
122 | submission_dir = Path(f"womd_K{k}")
123 | submission_dir.mkdir(exist_ok=True)
124 | f = open(submission_dir / f"womd_K{k}.bin", "wb")
125 | f.write(submission.SerializeToString())
126 | f.close()
127 | tar_file_name = submission_dir.as_posix() + ".tar.gz"
128 | with tarfile.open(tar_file_name, "w:gz") as tar:
129 | tar.add(submission_dir, arcname=submission_dir.name)
130 | if isinstance(logger, WandbLogger):
131 | logger.experiment.save(tar_file_name)
132 | else:
133 | file_paths.append(tar_file_name)
134 | return file_paths
135 |
136 |
137 | class SubAV2:
138 | def __init__(self, k_futures: int, activate: bool) -> None:
139 | self._SUBMISSION_COL_NAMES = [
140 | "scenario_id",
141 | "track_id",
142 | "probability",
143 | "predicted_trajectory_x",
144 | "predicted_trajectory_y",
145 | ]
146 | self.activate = activate
147 | if activate:
148 | self.submissions = {}
149 | for _k in range(1, k_futures + 1):
150 | self.submissions[_k] = []
151 |
152 | def add_to_submissions(
153 | self,
154 | waymo_trajs: Tensor,
155 | waymo_scores: Tensor,
156 | mask_pred: Tensor,
157 | object_id: Tensor,
158 | scenario_center: Tensor,
159 | scenario_yaw: Tensor,
160 | scenario_id: List[str],
161 | ) -> None:
162 | """
163 | Args:
164 | waymo_trajs: [n_batch, step_start+1...step_end, n_agent, K, 2]
165 | waymo_scores: [n_batch, n_agent, K]
166 | mask_pred: [n_batch, n_agent] bool
167 | object_id: [n_batch, n_agent] int
168 | scenario_center: [n_batch, 2]
169 | scenario_yaw: [n_batch]
170 | scenario_id: list of str
171 | """
172 | if not self.activate:
173 | return
174 |
175 | waymo_trajs = waymo_trajs.permute(0, 2, 3, 1, 4) # [n_batch, n_agent, K, n_step, 2]
176 | waymo_trajs = torch_pos2global(
177 | waymo_trajs.flatten(1, 3), scenario_center.unsqueeze(1), torch_rad2rot(scenario_yaw)
178 | ).view(waymo_trajs.shape)
179 |
180 | waymo_trajs = waymo_trajs.cpu().numpy()
181 | waymo_scores = waymo_scores.cpu().numpy()
182 | mask_pred = mask_pred.cpu().numpy()
183 | object_id = object_id.cpu().numpy()
184 |
185 | for i_batch in range(waymo_trajs.shape[0]):
186 | agent_pos = waymo_trajs[i_batch, mask_pred[i_batch]] # [n_agent_pred, K, n_step, 2]
187 | agent_id = object_id[i_batch, mask_pred[i_batch]] # [n_agent_pred]
188 | agent_score = waymo_scores[i_batch, mask_pred[i_batch]] # [n_agent_pred, K]
189 |
190 | for i_agent in range(agent_id.shape[0]):
191 | for n_K in self.submissions.keys():
192 | confidence = agent_score[i_agent, :n_K] / agent_score[i_agent, :n_K].sum()
193 | for _k in range(n_K):
194 | self.submissions[n_K].append(
195 | (
196 | scenario_id[i_batch],
197 | str(agent_id[i_agent]),
198 | confidence[_k],
199 | agent_pos[i_agent, _k, :, 0],
200 | agent_pos[i_agent, _k, :, 1],
201 | )
202 | )
203 |
204 | def save_sub_files(self, logger: WandbLogger) -> List[str]:
205 | if not self.activate:
206 | return []
207 |
208 | print(f"saving av2 submission files to {os.getcwd()}")
209 | file_paths = []
210 | for k, prediction_rows in self.submissions.items():
211 | parquet_file_path = Path(f"av2_K{k}.parquet").as_posix()
212 | submission_df = pd.DataFrame(prediction_rows, columns=self._SUBMISSION_COL_NAMES)
213 | submission_df.to_parquet(parquet_file_path)
214 | if isinstance(logger, WandbLogger):
215 | logger.experiment.save(parquet_file_path)
216 | else:
217 | file_paths.append(parquet_file_path)
218 | return file_paths
--------------------------------------------------------------------------------
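The snippet below is a minimal, self-contained sketch (not part of the repository) showing how the SubAV2 helper defined above could be exercised with random tensors at test time. All shapes follow the docstring of add_to_submissions; the scenario ids, agent count and horizon are made-up placeholder values, and passing logger=None simply makes save_sub_files return the parquet path instead of uploading it.

import torch

n_batch, n_step, n_agent, K = 2, 60, 4, 6
waymo_trajs = torch.randn(n_batch, n_step, n_agent, K, 2)      # predicted positions
waymo_scores = torch.rand(n_batch, n_agent, K)                 # unnormalized per-mode scores
mask_pred = torch.zeros(n_batch, n_agent, dtype=torch.bool)
mask_pred[:, 0] = True                                         # submit only the first agent
object_id = torch.arange(n_agent).repeat(n_batch, 1)
scenario_center = torch.zeros(n_batch, 2)
scenario_yaw = torch.zeros(n_batch)
scenario_id = ["scenario_a", "scenario_b"]                     # placeholder ids

sub = SubAV2(k_futures=K, activate=True)
sub.add_to_submissions(
    waymo_trajs, waymo_scores, mask_pred, object_id, scenario_center, scenario_yaw, scenario_id
)
# sub.save_sub_files(logger=None) would then write av2_K1.parquet ... av2_K6.parquet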
/src/data_modules/data_h5_womd.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional, Dict, Any, Tuple
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader, Dataset
5 | import numpy as np
6 | import h5py
7 |
8 |
9 | class DatasetBase(Dataset[Dict[str, np.ndarray]]):
10 | def __init__(self, filepath: str, tensor_size: Dict[str, Tuple]) -> None:
11 | super().__init__()
12 | self.tensor_size = tensor_size
13 | self.filepath = filepath
14 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
15 | self.dataset_len = int(hf.attrs["data_len"])
16 |
17 | def __len__(self) -> int:
18 | return self.dataset_len
19 |
20 |
21 | class DatasetTrain(DatasetBase):
22 | """
23 | The waymo 9-sec training.h5 is repetitive: episodes start at {0, 2, 4, 5, 6, 8, 10} seconds within the 20-sec episode.
24 | Always train with the whole training.h5 dataset.
25 | limit_train_batches is only used to control the validation frequency.
26 | """
27 |
28 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
29 | idx = np.random.randint(self.dataset_len)
30 | idx_key = str(idx)
31 | out_dict = {"episode_idx": idx}
32 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
33 | for k in self.tensor_size.keys():
34 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k])
35 | return out_dict
36 |
37 |
38 | class DatasetVal(DatasetBase):
39 | # for validation.h5 and testing.h5
40 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
41 | idx_key = str(idx)
42 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
43 | out_dict = {
44 | "episode_idx": idx,
45 | "scenario_id": hf[idx_key].attrs["scenario_id"],
46 | "scenario_center": hf[idx_key].attrs["scenario_center"],
47 | "scenario_yaw": hf[idx_key].attrs["scenario_yaw"],
48 | "with_map": hf[idx_key].attrs["with_map"], # some epidosdes in the testing dataset do not have map.
49 | }
50 | for k, _size in self.tensor_size.items():
51 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k])
52 | if out_dict[k].shape != _size:
53 | assert "agent" in k
54 | out_dict[k] = np.ones(_size, dtype=out_dict[k].dtype)
55 | return out_dict
56 |
57 |
58 | class DataH5womd(LightningDataModule):
59 | def __init__(
60 | self,
61 | data_dir: str,
62 | filename_train: str = "training",
63 | filename_val: str = "validation",
64 | filename_test: str = "testing",
65 | batch_size: int = 3,
66 | num_workers: int = 4,
67 | n_agent: int = 64, # if different from the h5 dataset, dummy agents are used (for scalability tests).
68 | ) -> None:
69 | super().__init__()
70 | self.interactive_challenge = "interactive" in filename_val or "interactive" in filename_test
71 |
72 | self.path_train_h5 = f"{data_dir}/{filename_train}.h5"
73 | self.path_val_h5 = f"{data_dir}/{filename_val}.h5"
74 | self.path_test_h5 = f"{data_dir}/{filename_test}.h5"
75 | self.batch_size = batch_size
76 | self.num_workers = num_workers
77 |
78 | n_step = 91
79 | n_step_history = 11
80 | n_agent_no_sim = 256
81 | n_pl = 1024
82 | n_tl = 100
83 | n_tl_stop = 40
84 | n_pl_node = 20
85 | self.tensor_size_train = {
86 | # agent states
87 | "agent/valid": (n_step, n_agent), # bool,
88 | "agent/pos": (n_step, n_agent, 2), # float32
89 | # v[1] = p[1]-p[0]. if p[1] invalid, v[1] also invalid, v[2]=v[3]
90 | "agent/vel": (n_step, n_agent, 2), # float32, v_x, v_y
91 | "agent/spd": (n_step, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy
92 | "agent/acc": (n_step, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt
93 | "agent/yaw_bbox": (n_step, n_agent, 1), # float32, yaw of the bbox heading
94 | "agent/yaw_rate": (n_step, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt
95 | # agent attributes
96 | "agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
97 | "agent/cmd": (n_agent, 8), # bool one_hot
98 | "agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2]
99 | "agent/size": (n_agent, 3), # float32: [length, width, height]
100 | "agent/goal": (n_agent, 4), # float32: [x, y, theta, v]
101 | "agent/dest": (n_agent,), # int64: index to map n_pl
102 | # map polylines
103 | "map/valid": (n_pl, n_pl_node), # bool
104 | "map/type": (n_pl, 11), # bool one_hot
105 | "map/pos": (n_pl, n_pl_node, 2), # float32
106 | "map/dir": (n_pl, n_pl_node, 2), # float32
107 | "map/boundary": (4,), # xmin, xmax, ymin, ymax
108 | # traffic lights
109 | "tl_lane/valid": (n_step, n_tl), # bool
110 | "tl_lane/state": (n_step, n_tl, 5), # bool one_hot
111 | "tl_lane/idx": (n_step, n_tl), # int, -1 means not valid
112 | "tl_stop/valid": (n_step, n_tl_stop), # bool
113 | "tl_stop/state": (n_step, n_tl_stop, 5), # bool one_hot
114 | "tl_stop/pos": (n_step, n_tl_stop, 2), # x,y
115 | "tl_stop/dir": (n_step, n_tl_stop, 2), # x,y
116 | }
117 |
118 | self.tensor_size_test = {
119 | # object_id for waymo metrics
120 | "history/agent/object_id": (n_agent,),
121 | "history/agent_no_sim/object_id": (n_agent_no_sim,),
122 | # agent_sim
123 | "history/agent/valid": (n_step_history, n_agent), # bool,
124 | "history/agent/pos": (n_step_history, n_agent, 2), # float32
125 | "history/agent/vel": (n_step_history, n_agent, 2), # float32, v_x, v_y
126 | "history/agent/spd": (n_step_history, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy
127 | "history/agent/acc": (n_step_history, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt
128 | "history/agent/yaw_bbox": (n_step_history, n_agent, 1), # float32, yaw of the bbox heading
129 | "history/agent/yaw_rate": (n_step_history, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt
130 | "history/agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
131 | "history/agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2]
132 | "history/agent/size": (n_agent, 3), # float32: [length, width, height]
133 | # agent_no_sim not used by the models currently
134 | "history/agent_no_sim/valid": (n_step_history, n_agent_no_sim),
135 | "history/agent_no_sim/pos": (n_step_history, n_agent_no_sim, 2),
136 | "history/agent_no_sim/vel": (n_step_history, n_agent_no_sim, 2),
137 | "history/agent_no_sim/spd": (n_step_history, n_agent_no_sim, 1),
138 | "history/agent_no_sim/yaw_bbox": (n_step_history, n_agent_no_sim, 1),
139 | "history/agent_no_sim/type": (n_agent_no_sim, 3),
140 | "history/agent_no_sim/size": (n_agent_no_sim, 3),
141 | # map
142 | "map/valid": (n_pl, n_pl_node), # bool
143 | "map/type": (n_pl, 11), # bool one_hot
144 | "map/pos": (n_pl, n_pl_node, 2), # float32
145 | "map/dir": (n_pl, n_pl_node, 2), # float32
146 | "map/boundary": (4,), # xmin, xmax, ymin, ymax
147 | # traffic_light
148 | "history/tl_lane/valid": (n_step_history, n_tl), # bool
149 | "history/tl_lane/state": (n_step_history, n_tl, 5), # bool one_hot
150 | "history/tl_lane/idx": (n_step_history, n_tl), # int, -1 means not valid
151 | "history/tl_stop/valid": (n_step_history, n_tl_stop), # bool
152 | "history/tl_stop/state": (n_step_history, n_tl_stop, 5), # bool one_hot
153 | "history/tl_stop/pos": (n_step_history, n_tl_stop, 2), # x,y
154 | "history/tl_stop/dir": (n_step_history, n_tl_stop, 2), # dx,dy
155 | }
156 |
157 | self.tensor_size_val = {
158 | "agent/object_id": (n_agent,),
159 | "agent_no_sim/object_id": (n_agent_no_sim,),
160 | # agent_no_sim
161 | "agent_no_sim/valid": (n_step, n_agent_no_sim), # bool,
162 | "agent_no_sim/pos": (n_step, n_agent_no_sim, 2), # float32
163 | "agent_no_sim/vel": (n_step, n_agent_no_sim, 2), # float32, v_x, v_y
164 | "agent_no_sim/spd": (n_step, n_agent_no_sim, 1), # norm of vel, signed using yaw_bbox and vel_xy
165 | "agent_no_sim/yaw_bbox": (n_step, n_agent_no_sim, 1), # float32, yaw of the bbox heading
166 | "agent_no_sim/type": (n_agent_no_sim, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
167 | "agent_no_sim/size": (n_agent_no_sim, 3), # float32: [length, width, height]
168 | }
169 |
170 | self.tensor_size_val = self.tensor_size_val | self.tensor_size_train | self.tensor_size_test
171 |
172 | def setup(self, stage: Optional[str] = None) -> None:
173 | if stage == "fit" or stage is None:
174 | self.train_dataset = DatasetTrain(self.path_train_h5, self.tensor_size_train)
175 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val)
176 | elif stage == "validate":
177 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val)
178 | elif stage == "test":
179 | self.test_dataset = DatasetVal(self.path_test_h5, self.tensor_size_test)
180 |
181 | def train_dataloader(self) -> DataLoader[Any]:
182 | return self._get_dataloader(self.train_dataset, self.batch_size, self.num_workers)
183 |
184 | def val_dataloader(self) -> DataLoader[Any]:
185 | return self._get_dataloader(self.val_dataset, self.batch_size, self.num_workers)
186 |
187 | def test_dataloader(self) -> DataLoader[Any]:
188 | return self._get_dataloader(self.test_dataset, self.batch_size, self.num_workers)
189 |
190 | @staticmethod
191 | def _get_dataloader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader[Any]:
192 | return DataLoader(
193 | ds,
194 | batch_size=batch_size,
195 | num_workers=num_workers,
196 | pin_memory=True,
197 | shuffle=False,
198 | drop_last=False,
199 | persistent_workers=True,
200 | )
201 |
--------------------------------------------------------------------------------
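As a quick illustration of the h5 layout that DatasetTrain/DatasetVal above expect, the standalone sketch below (not repository code) builds a tiny training file with a top-level data_len attribute and one group per episode index, then reads a sample back. The file name toy_training.h5 and the reduced tensor_size dict are made up for illustration; the real files are produced by the pack_h5 scripts. The file is written with libver="latest" to match the SWMR read mode used by DatasetBase.

import h5py
import numpy as np

tensor_size = {"agent/valid": (91, 64), "agent/pos": (91, 64, 2)}
with h5py.File("toy_training.h5", "w", libver="latest") as hf:
    hf.attrs["data_len"] = 2
    for idx in range(2):
        g = hf.create_group(str(idx))
        g.create_dataset("agent/valid", data=np.zeros((91, 64), dtype=bool))
        g.create_dataset("agent/pos", data=np.zeros((91, 64, 2), dtype=np.float32))

ds = DatasetTrain("toy_training.h5", tensor_size)
sample = ds[0]                    # returns a random episode, see __getitem__ above
print(sample["agent/pos"].shape)  # (91, 64, 2)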
/src/models/modules/transformer.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional, Tuple
3 | from torch import Tensor, nn
4 | from torch.nn import functional as F
5 | from .attention import AttentionRPE
6 |
7 |
8 | def _get_activation_fn(activation):
9 | if activation == "relu":
10 | return F.relu
11 | elif activation == "gelu":
12 | return F.gelu
13 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
14 |
15 |
16 | class TransformerBlock(nn.Module):
17 | __constants__ = ["norm"]
18 |
19 | def __init__(
20 | self,
21 | d_model: int,
22 | n_head: int = 2,
23 | d_feedforward: int = 256,
24 | dropout_p: float = 0.1,
25 | activation: str = "relu",
26 | n_layer: int = 1,
27 | norm_first: bool = True,
28 | decoder_self_attn: bool = False,
29 | bias: bool = True,
30 | d_rpe: int = -1,
31 | apply_q_rpe: bool = False,
32 | ) -> None:
33 | super(TransformerBlock, self).__init__()
34 | self.layers = nn.ModuleList(
35 | [
36 | TransformerCrossAttention(
37 | d_model=d_model,
38 | n_head=n_head,
39 | d_feedforward=d_feedforward,
40 | dropout_p=dropout_p,
41 | activation=activation,
42 | norm_first=norm_first,
43 | decoder_self_attn=decoder_self_attn,
44 | bias=bias,
45 | d_rpe=d_rpe,
46 | apply_q_rpe=apply_q_rpe,
47 | )
48 | for _ in range(n_layer)
49 | ]
50 | )
51 |
52 | # self.layers = _get_clones(encoder_layer, n_layer)
53 | # self.n_layer = n_layer
54 | # self.norm = nn.LayerNorm(d_model) if norm_first else None
55 |
56 | def forward(
57 | self,
58 | src: Tensor,
59 | src_padding_mask: Optional[Tensor] = None,
60 | tgt: Optional[Tensor] = None,
61 | tgt_padding_mask: Optional[Tensor] = None,
62 | rpe: Optional[Tensor] = None,
63 | decoder_tgt: Optional[Tensor] = None,
64 | decoder_tgt_padding_mask: Optional[Tensor] = None,
65 | decoder_rpe: Optional[Tensor] = None,
66 | attn_mask: Optional[Tensor] = None,
67 | need_weights: bool = False,
68 | ) -> Tuple[Tensor, Optional[Tensor]]:
69 | """
70 | Args:
71 | src: [n_batch, n_src, d_model]
72 | src_padding_mask: [n_batch, n_src], bool, if True, src is invalid.
73 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe.
74 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe.
75 | rpe: [n_batch, n_src, n_tgt, d_rpe]
76 | decoder_tgt: [n_batch, (n_src), n_tgt_decoder, d_model], (n_src) if using rpe.
77 | decoder_tgt_padding_mask: [n_batch, (n_src), n_tgt_decoder], (n_src) if using rpe.
78 | decoder_rpe: [n_batch, n_src, n_tgt_decoder, d_rpe]
79 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt.
80 |
81 | Returns:
82 | src: [n_batch, n_src, d_model]
83 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None
84 |
85 | Remarks:
86 | absolute_pe should already be added to src/tgt.
87 | """
88 | attn_weights = None
89 | for mod in self.layers:
90 | src, attn_weights = mod(
91 | src=src,
92 | src_padding_mask=src_padding_mask,
93 | tgt=tgt,
94 | tgt_padding_mask=tgt_padding_mask,
95 | rpe=rpe,
96 | decoder_tgt=decoder_tgt,
97 | decoder_tgt_padding_mask=decoder_tgt_padding_mask,
98 | decoder_rpe=decoder_rpe,
99 | attn_mask=attn_mask,
100 | need_weights=need_weights,
101 | )
102 | # if self.norm is not None:
103 | # src = self.norm(src)
104 | return src, attn_weights
105 |
106 |
107 | class TransformerCrossAttention(nn.Module):
108 | def __init__(
109 | self,
110 | d_model: int,
111 | n_head: int,
112 | d_feedforward: int,
113 | dropout_p: float,
114 | activation: str,
115 | norm_first: bool,
116 | decoder_self_attn: bool,
117 | bias: bool,
118 | d_rpe: int = -1,
119 | apply_q_rpe: bool = False,
120 | ) -> None:
121 | super(TransformerCrossAttention, self).__init__()
122 | self.norm_first = norm_first
123 | self.d_feedforward = d_feedforward
124 | self.decoder_self_attn = decoder_self_attn
125 | inplace = False
126 |
127 | self.dropout = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None
128 | self.activation = _get_activation_fn(activation)
129 | self.norm1 = nn.LayerNorm(d_model)
130 |
131 | if self.decoder_self_attn:
132 | self.attn_src = AttentionRPE(
133 | d_model=d_model, n_head=n_head, dropout_p=dropout_p, bias=bias, d_rpe=d_rpe, apply_q_rpe=apply_q_rpe
134 | )
135 | self.norm_src = nn.LayerNorm(d_model)
136 | self.dropout_src = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None
137 |
138 | if self.norm_first:
139 | self.norm_tgt = nn.LayerNorm(d_model)
140 |
141 | self.attn = AttentionRPE(
142 | d_model=d_model, n_head=n_head, dropout_p=dropout_p, bias=bias, d_rpe=d_rpe, apply_q_rpe=apply_q_rpe
143 | )
144 | if self.d_feedforward > 0:
145 | self.linear1 = nn.Linear(d_model, d_feedforward)
146 | self.linear2 = nn.Linear(d_feedforward, d_model)
147 | self.norm2 = nn.LayerNorm(d_model)
148 | self.dropout1 = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None
149 | self.dropout2 = nn.Dropout(p=dropout_p, inplace=inplace) if dropout_p > 0 else None
150 |
151 | def forward(
152 | self,
153 | src: Tensor,
154 | src_padding_mask: Optional[Tensor] = None,
155 | tgt: Optional[Tensor] = None,
156 | tgt_padding_mask: Optional[Tensor] = None,
157 | rpe: Optional[Tensor] = None,
158 | decoder_tgt: Optional[Tensor] = None,
159 | decoder_tgt_padding_mask: Optional[Tensor] = None,
160 | decoder_rpe: Optional[Tensor] = None,
161 | attn_mask: Optional[Tensor] = None,
162 | need_weights: bool = False,
163 | ) -> Tuple[Tensor, Optional[Tensor]]:
164 | """
165 | Args:
166 | src: [n_batch, n_src, d_model]
167 | src_padding_mask: [n_batch, n_src], bool, if True, src is invalid.
168 | tgt: [n_batch, (n_src), n_tgt, d_model], None for self attention, (n_src) if using rpe.
169 | tgt_padding_mask: [n_batch, (n_src), n_tgt], bool, if True, tgt is invalid, (n_src) if using rpe.
170 | rpe: [n_batch, n_src, n_tgt, d_rpe]
171 | decoder_tgt: [n_batch, n_src, n_tgt_decoder, d_model], when use decoder_rpe
172 | decoder_tgt_padding_mask: [n_batch, n_src, n_tgt_decoder], when use decoder_rpe
173 | decoder_rpe: [n_batch, n_src, n_tgt_decoder, d_rpe]
174 | attn_mask: [n_batch, n_src, n_tgt], bool, if True, attn is disabled for that pair of src/tgt.
175 |
176 | Returns:
177 | out: [n_batch, n_src, d_model]
178 | attn_weights: [n_batch, n_src, n_tgt] if need_weights else None
179 |
180 | Remarks:
181 | absolute_pe should already be added to src/tgt.
182 | """
183 | if self.decoder_self_attn:
184 | # transformer decoder
185 | if self.norm_first:
186 | _s = self.norm_src(src)
187 | if decoder_tgt is None:
188 | _s = self.attn_src(_s, tgt_padding_mask=src_padding_mask)[0]
189 | else:
190 | decoder_tgt = self.norm_src(decoder_tgt)
191 | _s = self.attn_src(_s, decoder_tgt, tgt_padding_mask=decoder_tgt_padding_mask, rpe=decoder_rpe)[0]
192 |
193 | if self.dropout_src is None:
194 | src = src + _s
195 | else:
196 | src = src + self.dropout_src(_s)
197 | else:
198 | if decoder_tgt is None:
199 | _s = self.attn_src(src, tgt_padding_mask=src_padding_mask)[0]
200 | else:
201 | _s = self.attn_src(src, decoder_tgt, tgt_padding_mask=decoder_tgt_padding_mask, rpe=decoder_rpe)[0]
202 |
203 | if self.dropout_src is None:
204 | src = self.norm_src(src + _s)
205 | else:
206 | src = self.norm_src(src + self.dropout_src(_s))
207 |
208 | if tgt is None:
209 | tgt_padding_mask = src_padding_mask
210 |
211 | if self.norm_first:
212 | src2 = self.norm1(src)
213 | if tgt is not None:
214 | tgt = self.norm_tgt(tgt)
215 | else:
216 | src2 = src
217 |
218 | # [n_batch, n_src, d_model]
219 | src2, attn_weights = self.attn(
220 | src=src2,
221 | tgt=tgt,
222 | tgt_padding_mask=tgt_padding_mask,
223 | attn_mask=attn_mask,
224 | rpe=rpe,
225 | need_weights=need_weights,
226 | )
227 |
228 | if self.d_feedforward > 0:
229 | if self.dropout1 is None:
230 | src = src + src2
231 | else:
232 | src = src + self.dropout1(src2)
233 |
234 | if self.norm_first:
235 | src2 = self.norm2(src)
236 | else:
237 | src = self.norm1(src)
238 | src2 = src
239 |
240 | src2 = self.activation(self.linear1(src2))
241 | if self.dropout is None:
242 | src2 = self.linear2(src2)
243 | else:
244 | src2 = self.linear2(self.dropout(src2))
245 |
246 | if self.dropout2 is None:
247 | src = src + src2
248 | else:
249 | src = src + self.dropout2(src2)
250 |
251 | if not self.norm_first:
252 | src = self.norm2(src)
253 | else:
254 | # densetnt vectornet
255 | src2 = self.activation(src2)
256 | if self.dropout is None:
257 | src = src + src2
258 | else:
259 | src = src + self.dropout(src2)
260 | if not self.norm_first:
261 | src = self.norm1(src)
262 |
263 | if src_padding_mask is not None:
264 | src.masked_fill_(src_padding_mask.unsqueeze(-1), 0.0)
265 | if need_weights:
266 | attn_weights.masked_fill_(src_padding_mask.unsqueeze(-1), 0.0)
267 | return src, attn_weights
268 |
--------------------------------------------------------------------------------
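Below is a minimal usage sketch (not repository code) that runs the TransformerBlock above in self-attention mode on random inputs, assuming AttentionRPE from .attention behaves as its usage in this file suggests. The sizes and the padding pattern are arbitrary illustration values; shapes follow the forward() docstring.

import torch

block = TransformerBlock(d_model=64, n_head=4, d_feedforward=128, n_layer=2)
src = torch.randn(3, 10, 64)                              # [n_batch, n_src, d_model]
src_padding_mask = torch.zeros(3, 10, dtype=torch.bool)
src_padding_mask[:, -2:] = True                           # mark the last two tokens as invalid

out, _ = block(src, src_padding_mask=src_padding_mask)    # tgt=None -> self attention
print(out.shape)                                          # torch.Size([3, 10, 64])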
/src/data_modules/data_h5_av2.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Optional, Dict, Any, Tuple
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader, Dataset
5 | import numpy as np
6 | import h5py
7 |
8 |
9 | class DatasetBase(Dataset[Dict[str, np.ndarray]]):
10 | def __init__(self, filepath: str, tensor_size: Dict[str, Tuple]) -> None:
11 | super().__init__()
12 | self.tensor_size = tensor_size
13 | self.filepath = filepath
14 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
15 | self.dataset_len = int(hf.attrs["data_len"])
16 |
17 | def __len__(self) -> int:
18 | return self.dataset_len
19 |
20 |
21 | class DatasetTrain(DatasetBase):
22 | """
23 | Always train with the whole training.h5 dataset.
24 | limit_train_batches is only used to control the validation frequency.
25 | """
26 |
27 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
28 | idx = np.random.randint(self.dataset_len)
29 | idx_key = str(idx)
30 | out_dict = {"episode_idx": idx}
31 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
32 | for k, _size in self.tensor_size.items():
33 | if k in hf[idx_key]:
34 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k])
35 | else:
36 | if "/valid" in k or "/state" in k:
37 | _dtype = np.bool_  # np.bool is deprecated in numpy; use the scalar type np.bool_
38 | elif "/idx" in k:
39 | _dtype = np.int64
40 | else:
41 | _dtype = np.float32
42 | out_dict[k] = np.zeros(_size, dtype=_dtype)
43 | return out_dict
44 |
45 |
46 | class DatasetVal(DatasetBase):
47 | # for validation.h5 and testing.h5
48 | def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
49 | idx_key = str(idx)
50 | with h5py.File(self.filepath, "r", libver="latest", swmr=True) as hf:
51 | out_dict = {
52 | "episode_idx": idx,
53 | "scenario_id": hf[idx_key].attrs["scenario_id"],
54 | "scenario_center": hf[idx_key].attrs["scenario_center"],
55 | "scenario_yaw": hf[idx_key].attrs["scenario_yaw"],
56 | "with_map": hf[idx_key].attrs["with_map"], # some epidosdes in the testing dataset do not have map.
57 | }
58 | for k, _size in self.tensor_size.items():
59 | if k in hf[idx_key]:
60 | out_dict[k] = np.ascontiguousarray(hf[idx_key][k])
61 | else:
62 | if "/valid" in k or "/state" in k:
63 | _dtype = np.bool_  # np.bool is deprecated in numpy; use the scalar type np.bool_
64 | elif "/idx" in k:
65 | _dtype = np.int64
66 | else:
67 | _dtype = np.float32
68 | out_dict[k] = np.zeros(_size, dtype=_dtype)
69 | if out_dict[k].shape != _size:
70 | assert "agent" in k
71 | out_dict[k] = np.ones(_size, dtype=out_dict[k].dtype)
72 | return out_dict
73 |
74 |
75 | class DataH5av2(LightningDataModule):
76 | def __init__(
77 | self,
78 | data_dir: str,
79 | filename_train: str = "training",
80 | filename_val: str = "validation",
81 | filename_test: str = "testing",
82 | batch_size: int = 3,
83 | num_workers: int = 4,
84 | n_agent: int = 64, # if different from the h5 dataset, dummy agents are used (for scalability tests).
85 | ) -> None:
86 | super().__init__()
87 | self.interactive_challenge = False
88 |
89 | self.path_train_h5 = f"{data_dir}/{filename_train}.h5"
90 | self.path_val_h5 = f"{data_dir}/{filename_val}.h5"
91 | self.path_test_h5 = f"{data_dir}/{filename_test}.h5"
92 | self.batch_size = batch_size
93 | self.num_workers = num_workers
94 |
95 | n_step = 110
96 | n_step_history = 50
97 | n_agent_no_sim = 256
98 | n_pl = 1024
99 | n_pl_node = 20
100 |
101 | n_tl = 1
102 | n_tl_stop = 1
103 | self.tensor_size_train = {
104 | # agent states
105 | "agent/valid": (n_step, n_agent), # bool,
106 | "agent/pos": (n_step, n_agent, 2), # float32
107 | # v[1] = p[1]-p[0]. if p[1] invalid, v[1] also invalid, v[2]=v[3]
108 | "agent/vel": (n_step, n_agent, 2), # float32, v_x, v_y
109 | "agent/spd": (n_step, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy
110 | "agent/acc": (n_step, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt
111 | "agent/yaw_bbox": (n_step, n_agent, 1), # float32, yaw of the bbox heading
112 | "agent/yaw_rate": (n_step, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt
113 | # agent attributes
114 | "agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
115 | "agent/cmd": (n_agent, 8), # bool one_hot
116 | "agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2]
117 | "agent/size": (n_agent, 3), # float32: [length, width, height]
118 | "agent/goal": (n_agent, 4), # float32: [x, y, theta, v]
119 | "agent/dest": (n_agent,), # int64: index to map n_pl
120 | # map polylines
121 | "map/valid": (n_pl, n_pl_node), # bool
122 | "map/type": (n_pl, 11), # bool one_hot
123 | "map/pos": (n_pl, n_pl_node, 2), # float32
124 | "map/dir": (n_pl, n_pl_node, 2), # float32
125 | "map/boundary": (4,), # xmin, xmax, ymin, ymax
126 | # dummy traffic lights
127 | "tl_lane/valid": (n_step, n_tl), # bool
128 | "tl_lane/state": (n_step, n_tl, 5), # bool one_hot
129 | "tl_lane/idx": (n_step, n_tl), # int, -1 means not valid
130 | "tl_stop/valid": (n_step, n_tl_stop), # bool
131 | "tl_stop/state": (n_step, n_tl_stop, 5), # bool one_hot
132 | "tl_stop/pos": (n_step, n_tl_stop, 2), # x,y
133 | "tl_stop/dir": (n_step, n_tl_stop, 2), # x,y
134 | }
135 |
136 | self.tensor_size_test = {
137 | # object_id for waymo metrics
138 | "history/agent/object_id": (n_agent,),
139 | "history/agent_no_sim/object_id": (n_agent_no_sim,),
140 | # agent_sim
141 | "history/agent/valid": (n_step_history, n_agent), # bool,
142 | "history/agent/pos": (n_step_history, n_agent, 2), # float32
143 | "history/agent/vel": (n_step_history, n_agent, 2), # float32, v_x, v_y
144 | "history/agent/spd": (n_step_history, n_agent, 1), # norm of vel, signed using yaw_bbox and vel_xy
145 | "history/agent/acc": (n_step_history, n_agent, 1), # m/s2, acc[t] = (spd[t]-spd[t-1])/dt
146 | "history/agent/yaw_bbox": (n_step_history, n_agent, 1), # float32, yaw of the bbox heading
147 | "history/agent/yaw_rate": (n_step_history, n_agent, 1), # rad/s, yaw_rate[t] = (yaw[t]-yaw[t-1])/dt
148 | "history/agent/type": (n_agent, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
149 | "history/agent/role": (n_agent, 3), # bool [sdc=0, interest=1, predict=2]
150 | "history/agent/size": (n_agent, 3), # float32: [length, width, height]
151 | # agent_no_sim not used by the models currently
152 | "history/agent_no_sim/valid": (n_step_history, n_agent_no_sim),
153 | "history/agent_no_sim/pos": (n_step_history, n_agent_no_sim, 2),
154 | "history/agent_no_sim/vel": (n_step_history, n_agent_no_sim, 2),
155 | "history/agent_no_sim/spd": (n_step_history, n_agent_no_sim, 1),
156 | "history/agent_no_sim/yaw_bbox": (n_step_history, n_agent_no_sim, 1),
157 | "history/agent_no_sim/type": (n_agent_no_sim, 3),
158 | "history/agent_no_sim/size": (n_agent_no_sim, 3),
159 | # map
160 | "map/valid": (n_pl, n_pl_node), # bool
161 | "map/type": (n_pl, 11), # bool one_hot
162 | "map/pos": (n_pl, n_pl_node, 2), # float32
163 | "map/dir": (n_pl, n_pl_node, 2), # float32
164 | "map/boundary": (4,), # xmin, xmax, ymin, ymax
165 | # dummy traffic_light
166 | "history/tl_lane/valid": (n_step_history, n_tl), # bool
167 | "history/tl_lane/state": (n_step_history, n_tl, 5), # bool one_hot
168 | "history/tl_lane/idx": (n_step_history, n_tl), # int, -1 means not valid
169 | "history/tl_stop/valid": (n_step_history, n_tl_stop), # bool
170 | "history/tl_stop/state": (n_step_history, n_tl_stop, 5), # bool one_hot
171 | "history/tl_stop/pos": (n_step_history, n_tl_stop, 2), # x,y
172 | "history/tl_stop/dir": (n_step_history, n_tl_stop, 2), # dx,dy
173 | }
174 |
175 | self.tensor_size_val = {
176 | "agent/object_id": (n_agent,),
177 | "agent_no_sim/object_id": (n_agent_no_sim,),
178 | # agent_no_sim
179 | "agent_no_sim/valid": (n_step, n_agent_no_sim), # bool,
180 | "agent_no_sim/pos": (n_step, n_agent_no_sim, 2), # float32
181 | "agent_no_sim/vel": (n_step, n_agent_no_sim, 2), # float32, v_x, v_y
182 | "agent_no_sim/spd": (n_step, n_agent_no_sim, 1), # norm of vel, signed using yaw_bbox and vel_xy
183 | "agent_no_sim/yaw_bbox": (n_step, n_agent_no_sim, 1), # float32, yaw of the bbox heading
184 | "agent_no_sim/type": (n_agent_no_sim, 3), # bool one_hot [Vehicle=0, Pedestrian=1, Cyclist=2]
185 | "agent_no_sim/size": (n_agent_no_sim, 3), # float32: [length, width, height]
186 | }
187 |
188 | self.tensor_size_val = self.tensor_size_val | self.tensor_size_train | self.tensor_size_test
189 |
190 | def setup(self, stage: Optional[str] = None) -> None:
191 | if stage == "fit" or stage is None:
192 | self.train_dataset = DatasetTrain(self.path_train_h5, self.tensor_size_train)
193 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val)
194 | elif stage == "validate":
195 | self.val_dataset = DatasetVal(self.path_val_h5, self.tensor_size_val)
196 | elif stage == "test":
197 | self.test_dataset = DatasetVal(self.path_test_h5, self.tensor_size_test)
198 |
199 | def train_dataloader(self) -> DataLoader[Any]:
200 | return self._get_dataloader(self.train_dataset, self.batch_size, self.num_workers)
201 |
202 | def val_dataloader(self) -> DataLoader[Any]:
203 | return self._get_dataloader(self.val_dataset, self.batch_size, self.num_workers)
204 |
205 | def test_dataloader(self) -> DataLoader[Any]:
206 | return self._get_dataloader(self.test_dataset, self.batch_size, self.num_workers)
207 |
208 | @staticmethod
209 | def _get_dataloader(ds: Dataset, batch_size: int, num_workers: int) -> DataLoader[Any]:
210 | return DataLoader(
211 | ds,
212 | batch_size=batch_size,
213 | num_workers=num_workers,
214 | pin_memory=True,
215 | shuffle=False,
216 | drop_last=False,
217 | persistent_workers=True,
218 | )
219 |
--------------------------------------------------------------------------------
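The AV2 datasets above zero-fill any key that is missing from an episode, picking the dtype from the key suffix. The standalone snippet below (not repository code, with made-up keys) spells out that rule.

import numpy as np

def dummy_array(key: str, size: tuple) -> np.ndarray:
    # boolean masks / one-hot states -> bool, indices -> int64, everything else -> float32
    if "/valid" in key or "/state" in key:
        dtype = np.bool_
    elif "/idx" in key:
        dtype = np.int64
    else:
        dtype = np.float32
    return np.zeros(size, dtype=dtype)

print(dummy_array("tl_lane/valid", (110, 1)).dtype)   # bool
print(dummy_array("tl_lane/idx", (110, 1)).dtype)     # int64
print(dummy_array("tl_stop/pos", (110, 1, 2)).dtype)  # float32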
/environment.yml:
--------------------------------------------------------------------------------
1 | name: hptr
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - absl-py=1.4.0=pyhd8ed1ab_0
11 | - aiohttp=3.7.0=py39h07f9747_0
12 | - antlr-python-runtime=4.8=pyhd8ed1ab_3
13 | - anyio=3.5.0=py39h06a4308_0
14 | - argon2-cffi=21.3.0=pyhd3eb1b0_0
15 | - argon2-cffi-bindings=21.2.0=py39h7f8727e_0
16 | - asttokens=2.0.5=pyhd3eb1b0_0
17 | - async-timeout=3.0.1=py_1000
18 | - attrs=22.1.0=pyh71513ae_1
19 | - babel=2.9.1=pyhd3eb1b0_0
20 | - backcall=0.2.0=pyhd3eb1b0_0
21 | - beautifulsoup4=4.11.1=py39h06a4308_0
22 | - blas=1.0=mkl
23 | - bleach=4.1.0=pyhd3eb1b0_0
24 | - blinker=1.5=pyhd8ed1ab_0
25 | - brotli=1.0.9=h5eee18b_7
26 | - brotli-bin=1.0.9=h5eee18b_7
27 | - brotlipy=0.7.0=py39h27cfd23_1003
28 | - bzip2=1.0.8=h7b6447c_0
29 | - c-ares=1.18.1=h7f98852_0
30 | - ca-certificates=2022.9.24=ha878542_0
31 | - cachetools=5.2.0=pyhd8ed1ab_0
32 | - certifi=2022.9.24=pyhd8ed1ab_0
33 | - cffi=1.15.1=py39h5eee18b_2
34 | - chardet=3.0.4=py39h079e4ff_1008
35 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
36 | - colorama=0.4.6=pyhd8ed1ab_0
37 | - cryptography=38.0.1=py39h9ce1e76_0
38 | - cuda=11.7.1=0
39 | - cuda-cccl=11.7.91=0
40 | - cuda-command-line-tools=11.7.1=0
41 | - cuda-compiler=11.7.1=0
42 | - cuda-cudart=11.7.99=0
43 | - cuda-cudart-dev=11.7.99=0
44 | - cuda-cuobjdump=11.7.91=0
45 | - cuda-cupti=11.7.101=0
46 | - cuda-cuxxfilt=11.7.91=0
47 | - cuda-demo-suite=11.8.86=0
48 | - cuda-documentation=11.8.86=0
49 | - cuda-driver-dev=11.7.99=0
50 | - cuda-gdb=11.8.86=0
51 | - cuda-libraries=11.7.1=0
52 | - cuda-libraries-dev=11.7.1=0
53 | - cuda-memcheck=11.8.86=0
54 | - cuda-nsight=11.8.86=0
55 | - cuda-nsight-compute=11.8.0=0
56 | - cuda-nvcc=11.7.99=0
57 | - cuda-nvdisasm=11.8.86=0
58 | - cuda-nvml-dev=11.7.91=0
59 | - cuda-nvprof=11.8.87=0
60 | - cuda-nvprune=11.7.91=0
61 | - cuda-nvrtc=11.7.99=0
62 | - cuda-nvrtc-dev=11.7.99=0
63 | - cuda-nvtx=11.7.91=0
64 | - cuda-nvvp=11.8.87=0
65 | - cuda-runtime=11.7.1=0
66 | - cuda-sanitizer-api=11.8.86=0
67 | - cuda-toolkit=11.7.1=0
68 | - cuda-tools=11.7.1=0
69 | - cuda-visual-tools=11.7.1=0
70 | - cudatoolkit=11.3.1=h2bc3f7f_2
71 | - cudnn=8.2.1=cuda11.3_0
72 | - cycler=0.11.0=pyhd3eb1b0_0
73 | - dbus=1.13.18=hb2f20db_0
74 | - debugpy=1.5.1=py39h295c915_0
75 | - decorator=5.1.1=pyhd3eb1b0_0
76 | - defusedxml=0.7.1=pyhd3eb1b0_0
77 | - docker-pycreds=0.4.0=py_0
78 | - entrypoints=0.4=py39h06a4308_0
79 | - executing=0.8.3=pyhd3eb1b0_0
80 | - expat=2.4.9=h6a678d5_0
81 | - ffmpeg=4.3.2=hca11adc_0
82 | - fftw=3.3.9=h27cfd23_1
83 | - flit-core=3.6.0=pyhd3eb1b0_0
84 | - fontconfig=2.14.1=hef1e5e3_0
85 | - fonttools=4.25.0=pyhd3eb1b0_0
86 | - freetype=2.12.1=h4a9f257_0
87 | - fsspec=2022.11.0=pyhd8ed1ab_0
88 | - future=0.18.2=pyhd8ed1ab_6
89 | - gds-tools=1.4.0.31=0
90 | - giflib=5.2.1=h7b6447c_0
91 | - gitdb=4.0.10=pyhd8ed1ab_0
92 | - gitpython=3.1.29=pyhd8ed1ab_0
93 | - glib=2.69.1=he621ea3_2
94 | - gmp=6.2.1=h295c915_3
95 | - gnutls=3.6.15=he1e5248_0
96 | - google-auth=2.15.0=pyh1a96a4e_0
97 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
98 | - grpcio=1.38.1=py39hff7568b_0
99 | - gst-plugins-base=1.14.0=h8213a91_2
100 | - gstreamer=1.14.0=h28cd5cc_2
101 | - hdf5=1.10.6=h3ffc7dd_1
102 | - hydra-core=1.1.1=pyhd8ed1ab_0
103 | - icu=58.2=he6710b0_3
104 | - idna=3.4=py39h06a4308_0
105 | - importlib-metadata=5.1.0=pyha770c72_0
106 | - importlib_resources=5.10.1=pyhd8ed1ab_0
107 | - intel-openmp=2021.4.0=h06a4308_3561
108 | - ipykernel=6.15.2=py39h06a4308_0
109 | - ipython=8.6.0=py39h06a4308_0
110 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
111 | - ipywidgets=7.6.5=pyhd3eb1b0_1
112 | - jedi=0.18.1=py39h06a4308_1
113 | - jinja2=3.1.2=py39h06a4308_0
114 | - joblib=1.1.1=py39h06a4308_0
115 | - jpeg=9e=h7f8727e_0
116 | - json5=0.9.6=pyhd3eb1b0_0
117 | - jsonschema=4.16.0=py39h06a4308_0
118 | - jupyter=1.0.0=py39h06a4308_8
119 | - jupyter_client=7.4.7=py39h06a4308_0
120 | - jupyter_console=6.4.3=pyhd3eb1b0_0
121 | - jupyter_core=4.11.2=py39h06a4308_0
122 | - jupyter_server=1.18.1=py39h06a4308_0
123 | - jupyterlab=3.5.0=py39h06a4308_0
124 | - jupyterlab_pygments=0.1.2=py_0
125 | - jupyterlab_server=2.16.3=py39h06a4308_0
126 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
127 | - kiwisolver=1.4.2=py39h295c915_0
128 | - krb5=1.19.2=hac12032_0
129 | - lame=3.100=h7b6447c_0
130 | - lcms2=2.12=h3be6417_0
131 | - ld_impl_linux-64=2.38=h1181459_1
132 | - lerc=3.0=h295c915_0
133 | - libblas=3.9.0=12_linux64_mkl
134 | - libbrotlicommon=1.0.9=h5eee18b_7
135 | - libbrotlidec=1.0.9=h5eee18b_7
136 | - libbrotlienc=1.0.9=h5eee18b_7
137 | - libcblas=3.9.0=12_linux64_mkl
138 | - libclang=10.0.1=default_hb85057a_2
139 | - libcublas=11.11.3.6=0
140 | - libcublas-dev=11.11.3.6=0
141 | - libcufft=10.9.0.58=0
142 | - libcufft-dev=10.9.0.58=0
143 | - libcufile=1.4.0.31=0
144 | - libcufile-dev=1.4.0.31=0
145 | - libcurand=10.3.0.86=0
146 | - libcurand-dev=10.3.0.86=0
147 | - libcusolver=11.4.1.48=0
148 | - libcusolver-dev=11.4.1.48=0
149 | - libcusparse=11.7.5.86=0
150 | - libcusparse-dev=11.7.5.86=0
151 | - libdeflate=1.8=h7f8727e_5
152 | - libedit=3.1.20210910=h7f8727e_0
153 | - libevent=2.1.12=h8f2d780_0
154 | - libffi=3.4.2=h6a678d5_6
155 | - libgcc-ng=11.2.0=h1234567_1
156 | - libgfortran-ng=11.2.0=h00389a5_1
157 | - libgfortran5=11.2.0=h1234567_1
158 | - libgomp=11.2.0=h1234567_1
159 | - libiconv=1.16=h7f8727e_2
160 | - libidn2=2.3.2=h7f8727e_0
161 | - liblapack=3.9.0=12_linux64_mkl
162 | - libllvm10=10.0.1=hbcb73fb_5
163 | - libnpp=11.8.0.86=0
164 | - libnpp-dev=11.8.0.86=0
165 | - libnvjpeg=11.9.0.86=0
166 | - libnvjpeg-dev=11.9.0.86=0
167 | - libpng=1.6.37=hbc83047_0
168 | - libpq=12.9=h16c4e8d_3
169 | - libprotobuf=3.20.1=h4ff587b_0
170 | - libsodium=1.0.18=h7b6447c_0
171 | - libstdcxx-ng=11.2.0=h1234567_1
172 | - libtasn1=4.16.0=h27cfd23_0
173 | - libtiff=4.4.0=hecacb30_2
174 | - libunistring=0.9.10=h27cfd23_0
175 | - libwebp=1.2.4=h11a3e52_0
176 | - libwebp-base=1.2.4=h5eee18b_0
177 | - libxcb=1.15=h7f8727e_0
178 | - libxkbcommon=1.0.1=hfa300c1_0
179 | - libxml2=2.9.14=h74e7548_0
180 | - libxslt=1.1.35=h4e12654_0
181 | - lxml=4.9.1=py39h1edc446_0
182 | - lz4-c=1.9.3=h295c915_1
183 | - markdown=3.4.1=pyhd8ed1ab_0
184 | - markupsafe=2.1.1=py39hb9d737c_1
185 | - matplotlib=3.5.3=py39h06a4308_0
186 | - matplotlib-base=3.5.3=py39hf590b9c_0
187 | - matplotlib-inline=0.1.6=py39h06a4308_0
188 | - mistune=0.8.4=py39h27cfd23_1000
189 | - mkl=2021.4.0=h06a4308_640
190 | - mkl-service=2.4.0=py39h7f8727e_0
191 | - mkl_fft=1.3.1=py39hd3c417c_0
192 | - mkl_random=1.2.2=py39h51133e4_0
193 | - multidict=6.0.2=py39hb9d737c_1
194 | - munkres=1.1.4=py_0
195 | - nbclassic=0.4.8=py39h06a4308_0
196 | - nbclient=0.5.13=py39h06a4308_0
197 | - nbconvert=6.5.4=py39h06a4308_0
198 | - nbformat=5.7.0=py39h06a4308_0
199 | - ncurses=6.3=h5eee18b_3
200 | - nest-asyncio=1.5.5=py39h06a4308_0
201 | - nettle=3.7.3=hbbd107a_1
202 | - notebook=6.5.2=py39h06a4308_0
203 | - notebook-shim=0.2.2=py39h06a4308_0
204 | - nsight-compute=2022.3.0.22=0
205 | - nspr=4.33=h295c915_0
206 | - nss=3.74=h0370c37_0
207 | - oauthlib=3.2.2=pyhd8ed1ab_0
208 | - omegaconf=2.1.1=py39hf3d152e_1
209 | - openh264=2.1.1=h4ff587b_0
210 | - openssl=1.1.1s=h7f8727e_0
211 | - packaging=21.3=pyhd8ed1ab_0
212 | - pandocfilters=1.5.0=pyhd3eb1b0_0
213 | - parso=0.8.3=pyhd3eb1b0_0
214 | - pathtools=0.1.2=py_1
215 | - pcre=8.45=h295c915_0
216 | - pexpect=4.8.0=pyhd3eb1b0_3
217 | - pickleshare=0.7.5=pyhd3eb1b0_1003
218 | - pillow=9.2.0=py39hace64e9_1
219 | - pip=22.2.2=py39h06a4308_0
220 | - ply=3.11=py39h06a4308_0
221 | - prometheus_client=0.14.1=py39h06a4308_0
222 | - promise=2.3=py39hf3d152e_7
223 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
224 | - prompt_toolkit=3.0.20=hd3eb1b0_0
225 | - protobuf=3.20.1=py39h295c915_0
226 | - psutil=5.9.0=py39h5eee18b_0
227 | - ptyprocess=0.7.0=pyhd3eb1b0_2
228 | - pure_eval=0.2.2=pyhd3eb1b0_0
229 | - pyasn1=0.4.8=py_0
230 | - pyasn1-modules=0.2.7=py_0
231 | - pycparser=2.21=pyhd3eb1b0_0
232 | - pydeprecate=0.3.2=pyhd8ed1ab_0
233 | - pygments=2.11.2=pyhd3eb1b0_0
234 | - pyjwt=2.6.0=pyhd8ed1ab_0
235 | - pyopenssl=22.0.0=pyhd3eb1b0_0
236 | - pyparsing=3.0.9=pyhd8ed1ab_0
237 | - pyqt=5.15.7=py39h6a678d5_1
238 | - pyqt5-sip=12.11.0=py39h6a678d5_1
239 | - pyrsistent=0.18.0=py39heee7806_0
240 | - pysocks=1.7.1=py39h06a4308_0
241 | - python=3.9.15=h7a1cb2a_2
242 | - python-dateutil=2.8.2=pyhd3eb1b0_0
243 | - python-fastjsonschema=2.16.2=py39h06a4308_0
244 | - python_abi=3.9=2_cp39
245 | - pytorch=1.13.0=py3.9_cuda11.7_cudnn8.5.0_0
246 | - pytorch-cuda=11.7=h67b0de4_0
247 | - pytorch-lightning=1.5.10=pyhd8ed1ab_0
248 | - pytorch-mutex=1.0=cuda
249 | - pytz=2022.1=py39h06a4308_0
250 | - pyu2f=0.1.5=pyhd8ed1ab_0
251 | - pyyaml=6.0=py39hb9d737c_4
252 | - pyzmq=23.2.0=py39h6a678d5_0
253 | - qt-main=5.15.2=h327a75a_7
254 | - qt-webengine=5.15.9=hd2b0992_4
255 | - qtconsole=5.3.2=py39h06a4308_0
256 | - qtpy=2.2.0=py39h06a4308_0
257 | - qtwebkit=5.212=h4eab89a_4
258 | - readline=8.2=h5eee18b_0
259 | - requests=2.28.1=py39h06a4308_0
260 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0
261 | - rsa=4.9=pyhd8ed1ab_0
262 | - scikit-learn=1.1.3=py39h6a678d5_0
263 | - scipy=1.9.3=py39h14f4228_0
264 | - send2trash=1.8.0=pyhd3eb1b0_1
265 | - sentry-sdk=1.11.1=pyhd8ed1ab_0
266 | - setproctitle=1.1.10=py39h3811e60_1004
267 | - setuptools=59.5.0=py39hf3d152e_0
268 | - shortuuid=1.0.11=pyhd8ed1ab_0
269 | - sip=6.6.2=py39h6a678d5_0
270 | - smmap=3.0.5=pyh44b312d_0
271 | - sniffio=1.2.0=py39h06a4308_1
272 | - soupsieve=2.3.2.post1=py39h06a4308_0
273 | - sqlite=3.40.0=h5082296_0
274 | - stack_data=0.2.0=pyhd3eb1b0_0
275 | - tensorboard=2.11.0=pyhd8ed1ab_0
276 | - tensorboard-data-server=0.6.0=py39hd97740a_2
277 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
278 | - terminado=0.13.1=py39h06a4308_0
279 | - threadpoolctl=2.2.0=pyh0d69192_0
280 | - tinycss2=1.2.1=py39h06a4308_0
281 | - tk=8.6.12=h1ccaba5_0
282 | - toml=0.10.2=pyhd3eb1b0_0
283 | - tomli=2.0.1=py39h06a4308_0
284 | - torchaudio=0.13.0=py39_cu117
285 | - torchmetrics=0.11.0=pyhd8ed1ab_0
286 | - torchvision=0.14.0=py39_cu117
287 | - tornado=6.2=py39h5eee18b_0
288 | - tqdm=4.64.1=py39h06a4308_0
289 | - traitlets=5.1.1=pyhd3eb1b0_0
290 | - transforms3d=0.4.1=pyhd8ed1ab_0
291 | - tzdata=2022g=h04d1e81_0
292 | - urllib3=1.26.12=py39h06a4308_0
293 | - wandb=0.13.6=pyhd8ed1ab_0
294 | - wcwidth=0.2.5=pyhd3eb1b0_0
295 | - webencodings=0.5.1=py39h06a4308_1
296 | - websocket-client=0.58.0=py39h06a4308_4
297 | - werkzeug=2.2.2=pyhd8ed1ab_0
298 | - wheel=0.37.1=pyhd3eb1b0_0
299 | - widgetsnbextension=3.5.2=py39h06a4308_0
300 | - x264=1!161.3030=h7f98852_1
301 | - xz=5.2.8=h5eee18b_0
302 | - yaml=0.2.5=h7f98852_2
303 | - yarl=1.7.2=py39hb9d737c_2
304 | - zeromq=4.3.4=h2531618_0
305 | - zipp=3.11.0=pyhd8ed1ab_0
306 | - zlib=1.2.13=h5eee18b_0
307 | - zstd=1.5.2=ha4553b6_0
308 | - pip:
309 | - appdirs==1.4.4
310 | - astunparse==1.6.3
311 | - black==20.8b0
312 | - clang==5.0
313 | - click==7.1.2
314 | - cloudpickle==2.2.0
315 | - flatbuffers==23.5.26
316 | - gast==0.4.0
317 | - google-pasta==0.2.0
318 | - gym==0.21.0
319 | - h5py==3.1.0
320 | - immutabledict==2.2.0
321 | - keras==2.11.0
322 | - keras-preprocessing==1.1.2
323 | - mypy-extensions==0.4.3
324 | - numpy==1.21.5
325 | - opencv-python==4.6.0.66
326 | - opt-einsum==3.3.0
327 | - pandas==1.5.3
328 | - pathspec==0.10.2
329 | - plotly==5.13.1
330 | - pyarrow==10.0.0
331 | - regex==2022.10.31
332 | - six==1.15.0
333 | - tenacity==8.2.2
334 | - tensorflow-cpu==2.11.0
335 | - tensorflow_probability==0.19.0
336 | - termcolor==1.1.0
337 | - typed-ast==1.5.4
338 | - typing-extensions==3.7.4.3
339 | - wrapt==1.12.1
340 |
--------------------------------------------------------------------------------
/src/models/metrics/nll.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import torch
3 | from omegaconf import ListConfig
4 | from typing import Dict, Optional
5 | from torch import Tensor, tensor
6 | from torch.nn import functional as F
7 | from torchmetrics.metric import Metric
8 | from torch.distributions import MultivariateNormal
9 |
10 |
11 | def compute_nll_mtr(dmean: Tensor, cov: Tensor) -> Tensor:
12 | dx = dmean[..., 0]
13 | dy = dmean[..., 1]
14 | sx = cov[..., 0, 0]
15 | sy = cov[..., 1, 1]
16 | rho = torch.tanh(cov[..., 1, 0]) # mtr uses clamp to [-0.5, 0.5]
17 | one_minus_rho2 = 1 - rho ** 2
18 | log_prob = (
19 | torch.log(sx)
20 | + torch.log(sy)
21 | + 0.5 * torch.log(one_minus_rho2)
22 | + 0.5 / one_minus_rho2 * ((dx / sx) ** 2 + (dy / sy) ** 2 - 2 * rho * dx * dy / (sx * sy))
23 | )
24 | return log_prob  # negative log-likelihood of a bivariate Gaussian (constant log(2*pi) dropped)
25 |
26 |
27 | class NllMetrics(Metric):
28 | full_state_update = False
29 |
30 | def __init__(
31 | self,
32 | prefix: str,
33 | winner_takes_all: str,
34 | p_rand_train_agent: float,
35 | n_decoders: int,
36 | n_pred: int,
37 | l_pos: str,
38 | n_step_add_train_agent: ListConfig,
39 | focal_gamma_conf: ListConfig,
40 | w_conf: ListConfig,
41 | w_pos: ListConfig,
42 | w_yaw: ListConfig, # cos
43 | w_spd: ListConfig, # huber
44 | w_vel: ListConfig, # huber
45 | ) -> None:
46 | super().__init__(dist_sync_on_step=False)
47 | self.prefix = prefix
48 | self.winner_takes_all = winner_takes_all
49 | self.p_rand_train_agent = p_rand_train_agent
50 | self.n_decoders = n_decoders
51 | self.n_pred = n_pred
52 | self.l_pos = l_pos
53 | self.n_step_add_train_agent = n_step_add_train_agent
54 | self.focal_gamma_conf = list(focal_gamma_conf)
55 | self.w_conf = list(w_conf)
56 | self.w_pos = list(w_pos)
57 | self.w_yaw = list(w_yaw)
58 | self.w_spd = list(w_spd)
59 | self.w_vel = list(w_vel)
60 |
61 | self.add_state("counter_traj", default=tensor(0.0), dist_reduce_fx="sum")
62 | self.add_state("counter_conf", default=tensor(0.0), dist_reduce_fx="sum")
63 | self.add_state("error_pos", default=tensor(0.0), dist_reduce_fx="sum")
64 | self.add_state("error_conf", default=tensor(0.0), dist_reduce_fx="sum")
65 | self.add_state("error_yaw", default=tensor(0.0), dist_reduce_fx="sum")
66 | self.add_state("error_spd", default=tensor(0.0), dist_reduce_fx="sum")
67 | self.add_state("error_vel", default=tensor(0.0), dist_reduce_fx="sum")
68 |
69 | for i in range(self.n_decoders):
70 | for j in range(self.n_pred):
71 | self.add_state(f"counter_d{i}_p{j}", default=tensor(0.0), dist_reduce_fx="sum")
72 | self.add_state(f"conf_d{i}_p{j}", default=tensor(0.0), dist_reduce_fx="sum")
73 |
74 | def update(
75 | self,
76 | pred_valid: Tensor,
77 | pred_conf: Tensor,
78 | pred_pos: Tensor,
79 | pred_spd: Optional[Tensor],
80 | pred_vel: Optional[Tensor],
81 | pred_yaw_bbox: Optional[Tensor],
82 | pred_cov: Optional[Tensor],
83 | ref_role: Tensor,
84 | ref_type: Tensor,
85 | gt_valid: Tensor,
86 | gt_pos: Tensor,
87 | gt_spd: Tensor,
88 | gt_vel: Tensor,
89 | gt_yaw_bbox: Tensor,
90 | gt_cmd: Tensor,
91 | **kwargs,
92 | ) -> None:
93 | """
94 | Args:
95 | pred_valid: [n_scene, n_agent], bool
96 | pred_conf: [n_decoder, n_scene, n_agent, n_pred], not normalized!
97 | pred_pos: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2]
98 | pred_spd: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 1]
99 | pred_vel: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2]
100 | pred_yaw_bbox: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 1]
101 | pred_cov: [n_decoder, n_scene, n_agent, n_pred, n_step_future, 2, 2]
102 | gt_valid: [n_scene, n_agent, n_step_future], bool
103 | gt_pos: [n_scene, n_agent, n_step_future, 2]
104 | gt_spd: [n_scene, n_agent, n_step_future, 1]
105 | gt_vel: [n_scene, n_agent, n_step_future, 2]
106 | gt_yaw_bbox: [n_scene, n_agent, n_step_future, 1]
107 | ref_role: [n_scene, n_agent, 3], one hot bool [sdc=0, interest=1, predict=2]
108 | ref_type: [n_scene, n_agent, 3], one hot bool [veh=0, ped=1, cyc=2]
109 | gt_cmd: [n_scene, n_agent, 8], one hot bool
110 | """
111 | n_agent_type = ref_type.shape[-1]
112 | n_decoder, n_scene, n_agent, n_pred = pred_conf.shape
113 | assert (ref_role.any(-1) & pred_valid == ref_role.any(-1)).all(), "All relevant agents must be predicted!"
114 |
115 | # ! prepare avails
116 | avails = ref_role.any(-1) # [n_scene, n_agent]
117 | # add rand agents for training
118 | if self.p_rand_train_agent > 0:
119 | avails = avails | (torch.bernoulli(self.p_rand_train_agent * torch.ones_like(avails)).bool())
120 | # add long tracked agents for training
121 | _track_len = gt_valid.sum(-1) # [n_scene, n_agent]
122 | for i in range(n_agent_type):
123 | if self.n_step_add_train_agent[i] > 0:
124 | avails = avails | (ref_type[:, :, i] & (_track_len > self.n_step_add_train_agent[i]))
125 |
126 | avails = gt_valid & avails.unsqueeze(-1) # [n_scene, n_agent, n_step_future]
127 | avails = avails.unsqueeze(0).expand(n_decoder, -1, -1, -1) # [n_decoder, n_scene, n_agent, n_step_future]
128 | if n_decoder > 1:
129 | # [n_decoder], randomly train ensembles with 50% of chance
130 | mask_ensemble = torch.bernoulli(0.5 * torch.ones_like(pred_conf[:, 0, 0, 0])).bool()
131 | # make sure at least one ensemble is trained
132 | if not mask_ensemble.any():
133 | mask_ensemble[torch.randint(0, n_decoder, (1,))] |= True
134 | avails = avails & mask_ensemble[:, None, None, None]
135 | # [n_decoder, n_scene, n_agent, n_pred, n_step_future]
136 | avails = avails.unsqueeze(3).expand(-1, -1, -1, n_pred, -1)
137 |
138 | # ! normalize pred_conf
139 | # [n_decoder, n_scene, n_agent, n_pred], per ensemble
140 | pred_conf = torch.softmax(pred_conf, dim=-1)
141 |
142 | # ! save conf histogram
143 | _prob = pred_conf.masked_fill(~(pred_valid[None, :, :, None]), 0.0)
144 | for i in range(self.n_decoders):
145 | for j in range(self.n_pred):
146 | x = getattr(self, f"conf_d{i}_p{j}")
147 | x += (_prob[i, :, :, j] * (avails[i, :, :, j].any(-1))).sum()
148 |
149 | # ! winner takes all
150 | with torch.no_grad():
151 | decoder_idx = torch.arange(n_decoder)[:, None, None, None] # [n_decoder, 1, 1, 1]
152 | scene_idx = torch.arange(n_scene)[None, :, None, None] # [1, n_scene, 1, 1]
153 | agent_idx = torch.arange(n_agent)[None, None, :, None] # [1, 1, n_agent, 1]
154 |
155 | if "hard" in self.winner_takes_all:
156 | # [n_decoder, n_scene, n_agent, n_pred, n_step_future]
157 | dist = torch.norm(pred_pos - gt_pos[None, :, :, None, :, :], dim=-1)
158 | dist = dist.masked_fill(~avails, 0.0).sum(-1) # [n_decoder, n_scene, n_agent, n_pred]
159 | if "joint" in self.winner_takes_all:
160 | dist = dist.sum(2, keepdim=True) # [n_decoder, n_scene, 1, n_pred]
161 | k_top = int(self.winner_takes_all[-1])
162 | i = torch.randint(high=k_top, size=())  # randomly pick one of the k_top closest modes
163 | # [n_decoder, n_scene, n_agent, 1]
164 | mode_idx = dist.topk(k_top, dim=-1, largest=False, sorted=False)[1][..., [i]]
165 | elif self.winner_takes_all == "cmd":
166 | assert n_pred == gt_cmd.shape[-1]
167 | mode_idx = (gt_cmd + 0.0).argmax(-1, keepdim=True) # [n_scene, n_agent, 1]
168 | mode_idx = mode_idx.unsqueeze(0).expand(n_decoder, -1, -1, -1) # [n_decoder, n_scene, n_agent, 1]
169 |
170 | # ! save hard assignment histogram: [n_decoder, n_scene, n_agent, n_pred]
171 | counter_modes = torch.nn.functional.one_hot(mode_idx.squeeze(-1), self.n_pred)
172 | for i in range(self.n_decoders):
173 | for j in range(self.n_pred):
174 | x = getattr(self, f"counter_d{i}_p{j}")
175 | x += (counter_modes[i, :, :, j] * (avails[i, :, :, j].any(-1))).sum()
176 |
177 | # ! avails and counter
178 | # avails: [n_decoder, n_scene, n_agent, n_pred, n_step_future]
179 | avails = avails[decoder_idx, scene_idx, agent_idx, mode_idx]
180 | self.counter_traj += avails.sum()
181 | self.counter_conf += avails[:, :, :, 0, :].any(-1).sum()
182 |
183 | # ! prepare agent dependent loss weights
184 | focal_gamma_conf, w_conf, w_pos, w_yaw, w_spd, w_vel = 0, 0, 0, 0, 0, 0
185 | for i in range(n_agent_type): # [n_scene, n_agent]
186 | focal_gamma_conf += ref_type[:, :, i] * self.focal_gamma_conf[i]
187 | w_conf += ref_type[:, :, i] * self.w_conf[i]
188 | w_pos += ref_type[:, :, i] * self.w_pos[i]
189 | w_yaw += ref_type[:, :, i] * self.w_yaw[i]
190 | w_spd += ref_type[:, :, i] * self.w_spd[i]
191 | w_vel += ref_type[:, :, i] * self.w_vel[i]
192 |
193 | # ! error_conf
194 | # pred_conf: [n_decoder, n_scene, n_agent, n_pred], not normalized!
195 | pred_conf = pred_conf[decoder_idx, scene_idx, agent_idx, mode_idx]
196 | focal_gamma_conf = torch.pow(1 - pred_conf, focal_gamma_conf[None, :, :, None])
197 | w_conf = w_conf[None, :, :, None]
198 | self.error_conf += (-torch.log(pred_conf) * w_conf * focal_gamma_conf).masked_fill(~(avails.any(-1)), 0.0).sum()
199 |
200 | # ! error_pos
201 | pred_pos = pred_pos[decoder_idx, scene_idx, agent_idx, mode_idx]
202 | if self.l_pos == "huber":
203 | errors_pos = F.huber_loss(pred_pos, gt_pos[None, :, :, None, :, :], reduction="none").sum(-1)
204 | elif self.l_pos == "l2":
205 | errors_pos = torch.norm(pred_pos - gt_pos[None, :, :, None, :, :], p=2, dim=-1)
206 | elif self.l_pos == "nll_mtr":
207 | pred_cov = pred_cov[decoder_idx, scene_idx, agent_idx, mode_idx]
208 | errors_pos = compute_nll_mtr(pred_pos - gt_pos[None, :, :, None, :, :], pred_cov)
209 | elif self.l_pos == "nll_torch":
210 | gmm = MultivariateNormal(pred_pos, scale_tril=pred_cov[decoder_idx, scene_idx, agent_idx, mode_idx])
211 | errors_pos = -gmm.log_prob(gt_pos[None, :, :, None, :, :])
212 | self.error_pos += (errors_pos * w_pos[None, :, :, None, None]).masked_fill(~avails, 0.0).sum()
213 |
214 | # ! error_spd
215 | if sum(self.w_spd) > 0 and pred_spd is not None:
216 | pred_spd = pred_spd[decoder_idx, scene_idx, agent_idx, mode_idx]
217 | errors_spd = F.huber_loss(pred_spd, gt_spd[None, :, :, None, :, :], reduction="none").squeeze(-1)
218 | self.error_spd += (errors_spd * w_spd[None, :, :, None, None]).masked_fill(~avails, 0.0).sum()
219 |
220 | # ! error_vel
221 | if sum(self.w_vel) > 0 and pred_vel is not None:
222 | pred_vel = pred_vel[decoder_idx, scene_idx, agent_idx, mode_idx]
223 | errors_vel = F.huber_loss(pred_vel, gt_vel[None, :, :, None, :, :], reduction="none").sum(-1)
224 | self.error_vel += (errors_vel * w_vel[None, :, :, None, None]).masked_fill(~avails, 0.0).sum()
225 |
226 | # ! error_yaw
227 | if sum(self.w_yaw) > 0 and pred_yaw_bbox is not None:
228 | pred_yaw_bbox = pred_yaw_bbox[decoder_idx, scene_idx, agent_idx, mode_idx]
229 | errors_yaw = -torch.cos(pred_yaw_bbox - gt_yaw_bbox[None, :, :, None, :, :]).squeeze(-1)
230 | self.error_yaw += (errors_yaw * w_yaw[None, :, :, None, None]).masked_fill(~avails, 0.0).sum()
231 |
232 | def compute(self) -> Dict[str, Tensor]:
233 |
234 | out_dict = {
235 | f"{self.prefix}/counter_traj": self.counter_traj,
236 | f"{self.prefix}/counter_conf": self.counter_conf,
237 | f"{self.prefix}/error_pos": self.error_pos,
238 | f"{self.prefix}/error_conf": self.error_conf,
239 | f"{self.prefix}/error_yaw": self.error_yaw,
240 | f"{self.prefix}/error_spd": self.error_spd,
241 | f"{self.prefix}/error_vel": self.error_vel,
242 | }
243 | out_dict[f"{self.prefix}/loss"] = (
244 | self.error_pos + self.error_yaw + self.error_spd + self.error_vel
245 | ) / self.counter_traj + self.error_conf / self.counter_conf
246 |
247 | for i in range(self.n_decoders):
248 | for j in range(self.n_pred):
249 | out_dict[f"{self.prefix}/counter_d{i}_p{j}"] = getattr(self, f"counter_d{i}_p{j}")
250 | out_dict[f"{self.prefix}/conf_d{i}_p{j}"] = getattr(self, f"conf_d{i}_p{j}")
251 |
252 | return out_dict
253 |
--------------------------------------------------------------------------------
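As a quick sanity check of compute_nll_mtr above (not repository code): with rho = tanh(0) = 0 and unit scales, the expression reduces to 0.5 * (dx^2 + dy^2), i.e. the bivariate Gaussian negative log-likelihood with cov[..., 0, 0] and cov[..., 1, 1] read as standard deviations and the constant log(2*pi) dropped.

import torch

dmean = torch.tensor([[1.0, 2.0]])                    # prediction - ground truth offsets
cov = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])        # sx = sy = 1, rho = tanh(0) = 0
nll = compute_nll_mtr(dmean, cov)
print(nll)                                            # tensor([2.5000]) = 0.5 * (1 + 4)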
/src/pack_h5_av2.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | import sys
3 |
4 | sys.path.append(".")
5 |
6 | from argparse import ArgumentParser
7 | from tqdm import tqdm
8 | import h5py
9 | import numpy as np
10 | from pathlib import Path
11 | from av2.datasets.motion_forecasting import scenario_serialization
12 | from av2.map.map_api import ArgoverseStaticMap
13 | from av2.geometry.interpolate import interp_arc
14 | import src.utils.pack_h5 as pack_utils
15 |
16 | # "map/type"
17 | # VEHICLE = 0
18 | # BUS = 1
19 | # BIKE = 2
20 | # UNKNOWN = 3
21 | # DOUBLE_DASH = 4
22 | # DASHED = 5
23 | # SOLID = 6
24 | # DOUBLE_SOLID = 7
25 | # DASH_SOLID = 8
26 | # SOLID_DASH = 9
27 | # CROSSWALK = 10
28 | N_PL_TYPE = 11
29 | PL_TYPE = {
30 | "VEHICLE": 0,
31 | "BUS": 1,
32 | "BIKE": 2,
33 | "NONE": 3,
34 | "UNKNOWN": 3,
35 | "DOUBLE_DASH_WHITE": 4,
36 | "DOUBLE_DASH_YELLOW": 4,
37 | "DASHED_WHITE": 5,
38 | "DASHED_YELLOW": 5,
39 | "SOLID_WHITE": 6,
40 | "SOLID_BLUE": 6,
41 | "SOLID_YELLOW": 6,
42 | "DOUBLE_SOLID_WHITE": 7,
43 | "DOUBLE_SOLID_YELLOW": 7,
44 | "DASH_SOLID_YELLOW": 8,
45 | "DASH_SOLID_WHITE": 8,
46 | "SOLID_DASH_WHITE": 9,
47 | "SOLID_DASH_YELLOW": 9,
48 | "CROSSWALK": 10,
49 | }
50 | DIM_VEH_LANES = [0, 1]
51 | DIM_CYC_LANES = [2]
52 | DIM_PED_LANES = [3, 6, 10]
53 |
54 | # "agent/type"
55 | AGENT_TYPE = {
56 | "vehicle": 0,
57 | "bus": 0,
58 | "pedestrian": 1,
59 | "motorcyclist": 2,
60 | "cyclist": 2,
61 | "riderless_bicycle": 2,
62 | }
63 | AGENT_SIZE = {
64 | "vehicle": [4.7, 2.1, 1.7],
65 | "bus": [11, 3, 3.5],
66 | "pedestrian": [0.85, 0.85, 1.75],
67 | "motorcyclist": [2, 0.8, 1.8],
68 | "cyclist": [2, 0.8, 1.8],
69 | "riderless_bicycle": [2, 0.8, 1.8],
70 | }
71 |
72 |
73 | N_PL_MAX = 1500
74 | N_AGENT_MAX = 256
75 |
76 | N_PL = 1024
77 | N_AGENT = 64
78 | N_AGENT_NO_SIM = N_AGENT_MAX - N_AGENT
79 |
80 | THRESH_MAP = 120
81 | THRESH_AGENT = 120
82 |
83 | STEP_CURRENT = 49
84 | N_STEP = 110
85 |
86 |
87 | def collate_agent_features(scenario_path, n_step):
88 | tracks = scenario_serialization.load_argoverse_scenario_parquet(scenario_path).tracks
89 |
90 | agent_id = []
91 | agent_type = []
92 | agent_states = []
93 | agent_role = []
94 | for _track in tracks:
95 | if _track.object_type not in AGENT_TYPE:
96 | continue
97 | # role: [sdc=0, interest=1, predict=2]
98 | if _track.track_id == "AV":
99 | agent_id.append(0)
100 | agent_role.append([True, False, False])
101 | else:
102 | assert int(_track.track_id) > 0
103 | agent_id.append(int(_track.track_id))
104 | # [sdc=0, interest=1, predict=2]
105 | if _track.category.value == 2: # SCORED_TRACK
106 | agent_role.append([False, True, False])
107 | elif _track.category.value == 3: # FOCAL_TRACK
108 | agent_role.append([False, False, True])
109 | else: # TRACK_FRAGMENT or UNSCORED_TRACK
110 | agent_role.append([False, False, False])
111 |
112 | agent_type.append(AGENT_TYPE[_track.object_type])
113 |
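  |         # per-step state layout: [center_x, center_y, center_z, length, width, height, yaw, vel_x, vel_y, valid];
  |         # steps without an observation keep the all-zero, valid=False default.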
114 | step_states = [[0.0] * 9 + [False]] * n_step
115 | agent_size = AGENT_SIZE[_track.object_type]
116 | for s in _track.object_states:
117 | step_states[int(s.timestep)] = [
118 | s.position[0], # center_x
119 | s.position[1], # center_y
120 | 3, # center_z
121 | agent_size[0], # length
122 | agent_size[1], # width
123 | agent_size[2], # height
124 | s.heading, # heading in radian
125 | s.velocity[0], # velocity_x
126 | s.velocity[1], # velocity_y
127 | True, # valid
128 | ]
129 |
130 | agent_states.append(step_states)
131 |
132 | return agent_id, agent_type, agent_states, agent_role
133 |
134 |
135 | def collate_map_features(map_path):
136 | static_map = ArgoverseStaticMap.from_json(map_path)
137 |
138 | def _interpolate_centerline(left_ln_boundary, right_ln_boundary):
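  |         # resample both lane boundaries to a common number of points (roughly one point per meter,
  |         # the mean arc length of the two boundaries) and average them to obtain the centerline.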
139 | num_interp_pts = (
140 | np.linalg.norm(np.diff(left_ln_boundary, axis=0), axis=-1).sum()
141 | + np.linalg.norm(np.diff(right_ln_boundary, axis=0), axis=-1).sum()
142 | ) / 2.0
143 | num_interp_pts = int(num_interp_pts) + 1
144 | left_even_pts = interp_arc(num_interp_pts, points=left_ln_boundary)
145 | right_even_pts = interp_arc(num_interp_pts, points=right_ln_boundary)
146 | centerline_pts = (left_even_pts + right_even_pts) / 2.0
147 | return centerline_pts, left_even_pts, right_even_pts
148 |
149 | mf_id = []
150 | mf_xyz = []
151 | mf_type = []
152 | mf_edge = []
153 | lane_boundary_set = []
154 |
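  |     # pedestrian crossings: convert each 4-corner crossing polygon into CROSSWALK polylines.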
155 | for _id, ped_xing in static_map.vector_pedestrian_crossings.items():
156 | v0, v1 = ped_xing.edge1.xyz
157 | v2, v3 = ped_xing.edge2.xyz
158 | pl_crosswalk = pack_utils.get_polylines_from_polygon(np.array([v0, v1, v3, v2]))
159 | mf_id.extend([_id] * len(pl_crosswalk))
160 | mf_type.extend([PL_TYPE["CROSSWALK"]] * len(pl_crosswalk))
161 | mf_xyz.extend(pl_crosswalk)
162 |
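  |     # lane segments: add one centerline polyline per lane, plus left/right boundary polylines
  |     # (deduplicated via lane_boundary_set; unmarked or unknown boundaries inside intersections are skipped).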
163 | for _id, lane_segment in static_map.vector_lane_segments.items():
164 | centerline_pts, left_even_pts, right_even_pts = _interpolate_centerline(
165 | lane_segment.left_lane_boundary.xyz, lane_segment.right_lane_boundary.xyz
166 | )
167 |
168 | mf_id.append(_id)
169 | mf_xyz.append(centerline_pts)
170 | mf_type.append(PL_TYPE[lane_segment.lane_type])
171 |
172 | if (lane_segment.left_lane_boundary not in lane_boundary_set) and not (
173 |             lane_segment.is_intersection and lane_segment.left_mark_type in ["NONE", "UNKNOWN"]
174 | ):
175 | lane_boundary_set.append(lane_segment.left_lane_boundary)
176 | mf_xyz.append(left_even_pts)
177 | mf_id.append(-2)
178 | mf_type.append(PL_TYPE[lane_segment.left_mark_type])
179 |
180 | if (lane_segment.right_lane_boundary not in lane_boundary_set) and not (
181 |             lane_segment.is_intersection and lane_segment.right_mark_type in ["NONE", "UNKNOWN"]
182 | ):
183 | lane_boundary_set.append(lane_segment.right_lane_boundary)
184 | mf_xyz.append(right_even_pts)
185 | mf_id.append(-2)
186 | mf_type.append(PL_TYPE[lane_segment.right_mark_type])
187 |
188 |         if len(lane_segment.successors) > 0:
189 |             mf_edge.extend([_id, _id_exit] for _id_exit in lane_segment.successors)
190 |         else:
191 |             mf_edge.append([_id, -1])  # lane without successors
192 |
193 | return mf_id, mf_xyz, mf_type, mf_edge
194 |
195 |
196 | def main():
197 | parser = ArgumentParser(allow_abbrev=True)
198 | parser.add_argument("--data-dir", default="/cluster/scratch/zhejzhan/av2_motion")
199 | parser.add_argument("--dataset", default="training")
200 | parser.add_argument("--out-dir", default="/cluster/scratch/zhejzhan/h5_av2_hptr")
201 | parser.add_argument("--rand-pos", default=50.0, type=float, help="Meter. Set to -1 to disable.")
202 | parser.add_argument("--rand-yaw", default=3.14, type=float, help="Radian. Set to -1 to disable.")
203 | parser.add_argument("--dest-no-pred", action="store_true")
204 | args = parser.parse_args()
205 |
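  |     # training: pack complete trajectories; validation: complete trajectories plus a "history/" view
  |     # of the observed part; testing: only the 50 observed steps, since no future is available.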
206 | if "training" in args.dataset:
207 | pack_all = True # ["agent/valid"]
208 | pack_history = False # ["history/agent/valid"]
209 | n_step = N_STEP
210 | elif "validation" in args.dataset:
211 | pack_all = True
212 | pack_history = True
213 | n_step = N_STEP
214 | elif "testing" in args.dataset:
215 | pack_all = False
216 | pack_history = True
217 | n_step = STEP_CURRENT + 1
218 |
219 | out_path = Path(args.out_dir)
220 | out_path.mkdir(exist_ok=True)
221 | out_h5_path = out_path / (args.dataset + ".h5")
222 |
223 | scenario_list = sorted(list((Path(args.data_dir) / args.dataset).glob("*")))
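  |     # track dataset-wide maxima to check that N_PL_MAX / N_AGENT_MAX and N_AGENT / N_AGENT_NO_SIM
  |     # are large enough; the maxima are printed after packing.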
224 | n_pl_max, n_agent_max, n_agent_sim, n_agent_no_sim = 0, 0, 0, 0
225 | with h5py.File(out_h5_path, "w") as hf:
226 | hf.attrs["data_len"] = len(scenario_list)
227 | for i in tqdm(range(hf.attrs["data_len"])):
228 | scenario_folder = scenario_list[i]
229 | mf_id, mf_xyz, mf_type, mf_edge = collate_map_features(
230 | scenario_folder / f"log_map_archive_{scenario_folder.name}.json"
231 | )
232 | agent_id, agent_type, agent_states, agent_role = collate_agent_features(
233 | scenario_folder / f"scenario_{scenario_folder.name}.parquet", n_step
234 | )
235 |
236 | episode = {}
237 | n_pl = pack_utils.pack_episode_map(
238 | episode=episode, mf_id=mf_id, mf_xyz=mf_xyz, mf_type=mf_type, mf_edge=mf_edge, n_pl_max=N_PL_MAX
239 | )
240 | n_agent = pack_utils.pack_episode_agents(
241 | episode=episode,
242 | agent_id=agent_id,
243 | agent_type=agent_type,
244 | agent_states=agent_states,
245 | agent_role=agent_role,
246 | pack_all=pack_all,
247 | pack_history=pack_history,
248 | n_agent_max=N_AGENT_MAX,
249 | step_current=STEP_CURRENT,
250 | )
251 | scenario_center, scenario_yaw = pack_utils.center_at_sdc(episode, args.rand_pos, args.rand_yaw)
252 | n_pl_max = max(n_pl_max, n_pl)
253 | n_agent_max = max(n_agent_max, n_agent)
254 |
255 | episode_reduced = {}
256 | pack_utils.filter_episode_map(episode, N_PL, THRESH_MAP, thresh_z=-1)
257 | assert episode["map/valid"].any(1).sum() > 0
258 | pack_utils.repack_episode_map(episode, episode_reduced, N_PL, N_PL_TYPE)
259 |
260 | if "training" in args.dataset:
261 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents(
262 | episode=episode,
263 | episode_reduced=episode_reduced,
264 | n_agent=N_AGENT,
265 | prefix="",
266 | dim_veh_lanes=DIM_VEH_LANES,
267 | dist_thresh_agent=THRESH_AGENT,
268 | step_current=STEP_CURRENT,
269 | )
270 | pack_utils.repack_episode_agents(
271 | episode=episode,
272 | episode_reduced=episode_reduced,
273 | mask_sim=mask_sim,
274 | n_agent=N_AGENT,
275 | prefix="",
276 | dim_veh_lanes=DIM_VEH_LANES,
277 | dim_cyc_lanes=DIM_CYC_LANES,
278 | dim_ped_lanes=DIM_PED_LANES,
279 | dest_no_pred=args.dest_no_pred,
280 | )
281 | elif "validation" in args.dataset:
282 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents(
283 | episode=episode,
284 | episode_reduced=episode_reduced,
285 | n_agent=N_AGENT,
286 | prefix="history/",
287 | dim_veh_lanes=DIM_VEH_LANES,
288 | dist_thresh_agent=THRESH_AGENT,
289 | step_current=STEP_CURRENT,
290 | )
291 | pack_utils.repack_episode_agents(
292 | episode=episode,
293 | episode_reduced=episode_reduced,
294 | mask_sim=mask_sim,
295 | n_agent=N_AGENT,
296 | prefix="",
297 | dim_veh_lanes=DIM_VEH_LANES,
298 | dim_cyc_lanes=DIM_CYC_LANES,
299 | dim_ped_lanes=DIM_PED_LANES,
300 | dest_no_pred=args.dest_no_pred,
301 | )
302 | pack_utils.repack_episode_agents(episode, episode_reduced, mask_sim, N_AGENT, "history/")
303 | pack_utils.repack_episode_agents_no_sim(episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "")
304 | pack_utils.repack_episode_agents_no_sim(
305 | episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "history/"
306 | )
307 | elif "testing" in args.dataset:
308 | mask_sim, mask_no_sim = pack_utils.filter_episode_agents(
309 | episode=episode,
310 | episode_reduced=episode_reduced,
311 | n_agent=N_AGENT,
312 | prefix="history/",
313 | dim_veh_lanes=DIM_VEH_LANES,
314 | dist_thresh_agent=THRESH_AGENT,
315 | step_current=STEP_CURRENT,
316 | )
317 | pack_utils.repack_episode_agents(episode, episode_reduced, mask_sim, N_AGENT, "history/")
318 | pack_utils.repack_episode_agents_no_sim(
319 | episode, episode_reduced, mask_no_sim, N_AGENT_NO_SIM, "history/"
320 | )
321 | n_agent_sim = max(n_agent_sim, mask_sim.sum())
322 | n_agent_no_sim = max(n_agent_no_sim, mask_no_sim.sum())
323 |
324 | episode_reduced["map/boundary"] = pack_utils.get_map_boundary(
325 | episode_reduced["map/valid"], episode_reduced["map/pos"]
326 | )
327 |
328 | hf_episode = hf.create_group(str(i))
329 | hf_episode.attrs["scenario_id"] = scenario_folder.name
330 | hf_episode.attrs["scenario_center"] = scenario_center
331 | hf_episode.attrs["scenario_yaw"] = scenario_yaw
332 | hf_episode.attrs["with_map"] = True
333 |
334 | for k, v in episode_reduced.items():
335 | hf_episode.create_dataset(k, data=v, compression="gzip", compression_opts=4, shuffle=True)
336 |
337 | print(f"n_pl_max: {n_pl_max}, n_agent_max: {n_agent_max}")
338 | print(f"n_agent_sim: {n_agent_sim}, n_agent_no_sim: {n_agent_no_sim}")
339 |
340 |
341 | if __name__ == "__main__":
342 | main()
343 |
--------------------------------------------------------------------------------
/src/data_modules/sc_global.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Dict
3 | from omegaconf import DictConfig
4 | import torch
5 | from torch import nn, Tensor
6 | from utils.pose_pe import PosePE
7 |
8 |
9 | class SceneCentricGlobal(nn.Module):
10 | def __init__(
11 | self,
12 | time_step_current: int,
13 | data_size: DictConfig,
14 | dropout_p_history: float,
15 | use_current_tl: bool,
16 | add_ohe: bool,
17 | pl_aggr: bool,
18 | pose_pe: DictConfig,
19 | ) -> None:
20 | super().__init__()
21 | self.dropout_p_history = dropout_p_history # [0, 1], turn off if set to negative
22 | self.step_current = time_step_current
23 | self.n_step_hist = time_step_current + 1
24 | self.use_current_tl = use_current_tl
25 | self.add_ohe = add_ohe
26 | self.pl_aggr = pl_aggr
27 | self.n_pl_node = data_size["map/valid"][-1]
28 |
29 | self.pose_pe_agent = PosePE(pose_pe["agent"])
30 | self.pose_pe_map = PosePE(pose_pe["map"])
31 | self.pose_pe_tl = PosePE(pose_pe["tl"])
32 |
33 | tl_attr_dim = self.pose_pe_tl.out_dim + data_size["tl_stop/state"][-1]
34 | if self.pl_aggr:
35 | agent_attr_dim = (
36 | self.pose_pe_agent.out_dim * self.n_step_hist
37 | + data_size["agent/spd"][-1] * self.n_step_hist # 1
38 | + data_size["agent/vel"][-1] * self.n_step_hist # 2
39 | + data_size["agent/yaw_rate"][-1] * self.n_step_hist # 1
40 | + data_size["agent/acc"][-1] * self.n_step_hist # 1
41 | + data_size["agent/size"][-1] # 3
42 | + data_size["agent/type"][-1] # 3
43 | + self.n_step_hist # valid
44 | )
45 | map_attr_dim = self.pose_pe_map.out_dim * self.n_pl_node + data_size["map/type"][-1] + self.n_pl_node
46 | else:
47 | agent_attr_dim = (
48 | self.pose_pe_agent.out_dim
49 | + data_size["agent/spd"][-1] # 1
50 | + data_size["agent/vel"][-1] # 2
51 | + data_size["agent/yaw_rate"][-1] # 1
52 | + data_size["agent/acc"][-1] # 1
53 | + data_size["agent/size"][-1] # 3
54 | + data_size["agent/type"][-1] # 3
55 | )
56 | map_attr_dim = self.pose_pe_map.out_dim + data_size["map/type"][-1]
57 |
58 | if self.add_ohe:
59 | self.register_buffer("history_step_ohe", torch.eye(self.n_step_hist))
60 | self.register_buffer("pl_node_ohe", torch.eye(self.n_pl_node))
61 | if not self.pl_aggr:
62 | map_attr_dim += self.n_pl_node
63 | agent_attr_dim += self.n_step_hist
64 | if not self.use_current_tl:
65 | tl_attr_dim += self.n_step_hist
66 |
67 | self.model_kwargs = {
68 | "agent_attr_dim": agent_attr_dim,
69 | "map_attr_dim": map_attr_dim,
70 | "tl_attr_dim": tl_attr_dim,
71 | "n_step_hist": self.n_step_hist,
72 | "n_pl_node": self.n_pl_node,
73 | "use_current_tl": self.use_current_tl,
74 | "pl_aggr": self.pl_aggr,
75 | }
76 |
77 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
78 | """
79 | Args: scene-centric Dict
80 | # (ref) reference information for transform back to global coordinate and submission to waymo
81 | "ref/pos": [n_scene, n_agent, 1, 2]
82 | "ref/yaw": [n_scene, n_agent, 1]
83 | "ref/rot": [n_scene, n_agent, 2, 2]
84 | "ref/role": [n_scene, n_agent, 3]
85 | "ref/type": [n_scene, n_agent, 3]
86 | # (gt) ground-truth agent future for training, not available for testing
87 | "gt/valid": [n_scene, n_agent, n_step_future], bool
88 | "gt/pos": [n_scene, n_agent, n_step_future, 2]
89 | "gt/spd": [n_scene, n_agent, n_step_future, 1]
90 | "gt/vel": [n_scene, n_agent, n_step_future, 2]
91 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1]
92 | "gt/cmd": [n_scene, n_agent, 8]
93 | # (sc) scene-centric agents states
94 | "sc/agent_valid": [n_scene, n_agent, n_step_hist]
95 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2]
96 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2]
97 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1]
98 | "sc/agent_acc": [n_scene, n_agent, n_step_hist, 1]
99 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1]
100 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1]
101 | # agent attributes
102 | "sc/agent_type": [n_scene, n_agent, 3]
103 | "sc/agent_role": [n_scene, n_agent, 3]
104 | "sc/agent_size": [n_scene, n_agent, 3]
105 | # map polylines
106 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool
107 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot
108 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32
109 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32
110 | # traffic lights
111 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool
112 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot
113 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y
114 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y
115 |
116 | Returns: add following keys to batch Dict
117 | # agent type: no need to be aggregated.
118 | "input/agent_type": [n_scene, n_agent, 3]
119 | # agent history
120 | if pl_aggr: # processed by mlp encoder, c.f. our method.
121 | "input/agent_valid": [n_scene, n_agent], bool
122 | "input/agent_attr": [n_scene, n_agent, agent_attr_dim]
123 | "input/agent_pos": [n_scene, n_agent, 2], (x,y), pos of the current step.
124 | "input/map_valid": [n_scene, n_pl], bool
125 | "input/map_attr": [n_scene, n_pl, map_attr_dim]
126 | "input/map_pos": [n_scene, n_pl, 2], (x,y), polyline starting node in global coordinate.
127 | else: # processed by pointnet encoder, c.f. vectornet.
128 | "input/agent_valid": [n_scene, n_agent, n_step_hist], bool
129 | "input/agent_attr": [n_scene, n_agent, n_step_hist, agent_attr_dim]
130 | "input/agent_pos": [n_scene, n_agent, 2]
131 | "input/map_valid": [n_scene, n_pl, n_pl_node], bool
132 | "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim]
133 |                 "input/map_pos": [n_scene, n_pl, 2]
134 |             # traffic lights: stop points; cannot be aggregated since detections are not tracked, each is a single-node polyline.
135 | if use_current_tl:
136 | "input/tl_valid": [n_scene, 1, n_tl], bool
137 | "input/tl_attr": [n_scene, 1, n_tl, tl_attr_dim]
138 | "input/tl_pos": [n_scene, 1, n_tl, 2] (x,y)
139 | else:
140 | "input/tl_valid": [n_scene, n_step_hist, n_tl], bool
141 | "input/tl_attr": [n_scene, n_step_hist, n_tl, tl_attr_dim]
142 | "input/tl_pos": [n_scene, n_step_hist, n_tl, 2]
143 | """
144 | batch["input/agent_type"] = batch["sc/agent_type"]
145 | batch["input/agent_valid"] = batch["sc/agent_valid"]
146 | batch["input/tl_valid"] = batch["sc/tl_valid"]
147 | batch["input/map_valid"] = batch["sc/map_valid"]
148 |
149 | # ! randomly mask history agent/tl/map
150 | if self.training and (0 < self.dropout_p_history <= 1.0):
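  |             # keep each history element with probability (1 - dropout_p_history); the current (last)
  |             # agent step is excluded from the mask, so a valid agent never loses its current pose.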
151 | prob_mask = torch.ones_like(batch["input/agent_valid"][..., :-1]) * (1 - self.dropout_p_history)
152 | batch["input/agent_valid"][..., :-1] &= torch.bernoulli(prob_mask).bool()
153 | prob_mask = torch.ones_like(batch["input/tl_valid"]) * (1 - self.dropout_p_history)
154 | batch["input/tl_valid"] &= torch.bernoulli(prob_mask).bool()
155 | prob_mask = torch.ones_like(batch["input/map_valid"]) * (1 - self.dropout_p_history)
156 | batch["input/map_valid"] &= torch.bernoulli(prob_mask).bool()
157 |
158 | # ! prepare "input/agent"
159 | batch["input/agent_pos"] = batch["ref/pos"].squeeze(2)
160 | if self.pl_aggr: # [n_scene, n_agent, agent_attr_dim]
161 | agent_invalid = ~batch["input/agent_valid"].unsqueeze(-1) # [n_scene, n_agent, n_step_hist, 1]
162 | agent_invalid_reduced = agent_invalid.all(-2) # [n_scene, n_agent, 1]
163 | batch["input/agent_attr"] = torch.cat(
164 | [
165 | self.pose_pe_agent(batch["sc/agent_pos"], batch["sc/agent_yaw_bbox"])
166 | .masked_fill(agent_invalid, 0)
167 | .flatten(-2, -1),
168 | batch["sc/agent_vel"].masked_fill(agent_invalid, 0).flatten(-2, -1), # n_step_hist*2
169 | batch["sc/agent_spd"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
170 | batch["sc/agent_yaw_rate"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
171 | batch["sc/agent_acc"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
172 | batch["sc/agent_size"].masked_fill(agent_invalid_reduced, 0), # 3
173 | batch["sc/agent_type"].masked_fill(agent_invalid_reduced, 0), # 3
174 | batch["input/agent_valid"], # n_step_hist
175 | ],
176 | dim=-1,
177 | )
178 | batch["input/agent_valid"] = batch["input/agent_valid"].any(-1) # [n_scene, n_agent]
179 | else: # [n_scene, n_agent, n_step_hist, agent_attr_dim]
180 | batch["input/agent_attr"] = torch.cat(
181 | [
182 | self.pose_pe_agent(batch["sc/agent_pos"], batch["sc/agent_yaw_bbox"]),
183 | batch["sc/agent_vel"], # vel xy, 2
184 | batch["sc/agent_spd"], # speed, 1
185 | batch["sc/agent_yaw_rate"], # yaw rate, 1
186 | batch["sc/agent_acc"], # acc, 1
187 | batch["sc/agent_size"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3
188 | batch["sc/agent_type"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3
189 | ],
190 | dim=-1,
191 | )
192 |
193 | # ! prepare "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim]
194 | batch["input/map_pos"] = batch["sc/map_pos"][:, :, 0]
195 | if self.pl_aggr: # [n_scene, n_pl, map_attr_dim]
196 | map_invalid = ~batch["input/map_valid"].unsqueeze(-1) # [n_scene, n_pl, n_pl_node, 1]
197 | map_invalid_reduced = map_invalid.all(-2) # [n_scene, n_pl, 1]
198 | batch["input/map_attr"] = torch.cat(
199 | [
200 | self.pose_pe_map(batch["sc/map_pos"], batch["sc/map_dir"])
201 | .masked_fill(map_invalid, 0)
202 | .flatten(-2, -1),
203 | batch["sc/map_type"].masked_fill(map_invalid_reduced, 0), # n_pl_type
204 | batch["input/map_valid"], # n_pl_node
205 | ],
206 | dim=-1,
207 | )
208 | batch["input/map_valid"] = batch["input/map_valid"].any(-1) # [n_scene, n_pl]
209 | else: # [n_scene, n_pl, n_pl_node, map_attr_dim]
210 | batch["input/map_attr"] = torch.cat(
211 | [
212 | self.pose_pe_map(batch["sc/map_pos"], batch["sc/map_dir"]), # pl_dim
213 | batch["sc/map_type"].unsqueeze(-2).expand(-1, -1, self.n_pl_node, -1), # n_pl_type
214 | ],
215 | dim=-1,
216 | )
217 |
218 | # ! prepare "input/tl_attr": [n_scene, n_step_hist/1, n_tl, tl_attr_dim]
219 | # [n_scene, n_step_hist, n_tl, 2]
220 | tl_pos = batch["sc/tl_pos"]
221 | tl_dir = batch["sc/tl_dir"]
222 | tl_state = batch["sc/tl_state"]
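  |         # traffic lights are single-node polylines whose detections are not tracked over time;
  |         # with use_current_tl only the current (last) step is kept.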
223 | if self.use_current_tl:
224 | tl_pos = tl_pos[:, [-1]] # [n_scene, 1, n_tl, 2]
225 | tl_dir = tl_dir[:, [-1]] # [n_scene, 1, n_tl, 2]
226 | tl_state = tl_state[:, [-1]] # [n_scene, 1, n_tl, 5]
227 | batch["input/tl_valid"] = batch["input/tl_valid"][:, [-1]] # [n_scene, 1, n_tl]
228 | batch["input/tl_attr"] = torch.cat([self.pose_pe_tl(tl_pos, tl_dir), tl_state], dim=-1)
229 | batch["input/tl_pos"] = tl_pos
230 | # ! add one-hot encoding for sequence (temporal, order of polyline nodes)
231 | if self.add_ohe:
232 | n_scene, n_agent, _ = batch["sc/agent_valid"].shape
233 | n_pl = batch["sc/map_valid"].shape[1]
234 | if not self.pl_aggr: # there is no need to add ohe if self.pl_aggr
235 | batch["input/agent_attr"] = torch.cat(
236 | [
237 | batch["input/agent_attr"],
238 | self.history_step_ohe[None, None, :, :].expand(n_scene, n_agent, -1, -1),
239 | ],
240 | dim=-1,
241 | )
242 | batch["input/map_attr"] = torch.cat(
243 |                 [batch["input/map_attr"], self.pl_node_ohe[None, None, :, :].expand(n_scene, n_pl, -1, -1)],
244 | dim=-1,
245 | )
246 |
247 | if not self.use_current_tl: # there is no need to add ohe if use_current_tl
248 | n_tl = batch["input/tl_valid"].shape[-1]
249 | batch["input/tl_attr"] = torch.cat(
250 | [batch["input/tl_attr"], self.history_step_ohe[None, :, None, :].expand(n_scene, -1, n_tl, -1)],
251 | dim=-1,
252 | )
253 |
254 | return batch
255 |
--------------------------------------------------------------------------------
/src/data_modules/sc_relative.py:
--------------------------------------------------------------------------------
1 | # Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
2 | from typing import Dict
3 | from omegaconf import DictConfig
4 | import torch
5 | from torch import nn, Tensor
6 | from utils.transform_utils import torch_rad2rot, torch_pos2local, torch_dir2local, torch_rad2local
7 | from utils.pose_pe import PosePE
8 |
9 |
10 | class SceneCentricRelative(nn.Module):
11 | def __init__(
12 | self,
13 | time_step_current: int,
14 | data_size: DictConfig,
15 | dropout_p_history: float,
16 | use_current_tl: bool,
17 | add_ohe: bool,
18 | pl_aggr: bool,
19 | pose_pe: DictConfig,
20 | ) -> None:
21 | super().__init__()
22 | self.dropout_p_history = dropout_p_history # [0, 1], turn off if set to negative
23 | self.step_current = time_step_current
24 | self.n_step_hist = time_step_current + 1
25 | self.use_current_tl = use_current_tl
26 | self.add_ohe = add_ohe
27 | self.pl_aggr = pl_aggr
28 | self.n_pl_node = data_size["map/valid"][-1]
29 |
30 | self.pose_pe_agent = PosePE(pose_pe["agent"])
31 | self.pose_pe_map = PosePE(pose_pe["map"])
32 |
33 | tl_attr_dim = data_size["tl_stop/state"][-1]
34 | if self.pl_aggr:
35 | agent_attr_dim = (
36 | self.pose_pe_agent.out_dim * self.n_step_hist
37 | + data_size["agent/spd"][-1] * self.n_step_hist # 1
38 | + data_size["agent/vel"][-1] * self.n_step_hist # 2
39 | + data_size["agent/yaw_rate"][-1] * self.n_step_hist # 1
40 | + data_size["agent/acc"][-1] * self.n_step_hist # 1
41 | + data_size["agent/size"][-1] # 3
42 | + data_size["agent/type"][-1] # 3
43 | + self.n_step_hist # valid
44 | )
45 | map_attr_dim = self.pose_pe_map.out_dim * self.n_pl_node + data_size["map/type"][-1] + self.n_pl_node
46 | else:
47 | agent_attr_dim = (
48 | self.pose_pe_agent.out_dim
49 | + data_size["agent/spd"][-1] # 1
50 | + data_size["agent/vel"][-1] # 2
51 | + data_size["agent/yaw_rate"][-1] # 1
52 | + data_size["agent/acc"][-1] # 1
53 | + data_size["agent/size"][-1] # 3
54 | + data_size["agent/type"][-1] # 3
55 | )
56 | map_attr_dim = self.pose_pe_map.out_dim + data_size["map/type"][-1]
57 |
58 | if self.add_ohe:
59 | self.register_buffer("history_step_ohe", torch.eye(self.n_step_hist))
60 | self.register_buffer("pl_node_ohe", torch.eye(self.n_pl_node))
61 | if not self.pl_aggr:
62 | map_attr_dim += self.n_pl_node
63 | agent_attr_dim += self.n_step_hist
64 | if not self.use_current_tl:
65 | tl_attr_dim += self.n_step_hist
66 |
67 | self.model_kwargs = {
68 | "agent_attr_dim": agent_attr_dim,
69 | "map_attr_dim": map_attr_dim,
70 | "tl_attr_dim": tl_attr_dim,
71 | "n_step_hist": self.n_step_hist,
72 | "n_pl_node": self.n_pl_node,
73 | "use_current_tl": self.use_current_tl,
74 | "pl_aggr": self.pl_aggr,
75 | }
76 |
77 | def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
78 | """
79 | Args: scene-centric Dict, masked according to valid
80 | # (ref) reference information for transform back to global coordinate and submission to waymo
81 | "ref/pos": [n_scene, n_agent, 1, 2]
82 | "ref/yaw": [n_scene, n_agent, 1]
83 | "ref/rot": [n_scene, n_agent, 2, 2]
84 | "ref/role": [n_scene, n_agent, 3]
85 | "ref/type": [n_scene, n_agent, 3]
86 | # (gt) ground-truth agent future for training, not available for testing
87 | "gt/valid": [n_scene, n_agent, n_step_future], bool
88 | "gt/pos": [n_scene, n_agent, n_step_future, 2]
89 | "gt/spd": [n_scene, n_agent, n_step_future, 1]
90 | "gt/vel": [n_scene, n_agent, n_step_future, 2]
91 | "gt/yaw_bbox": [n_scene, n_agent, n_step_future, 1]
92 | "gt/cmd": [n_scene, n_agent, 8]
93 | # (sc) scene-centric agents states
94 | "sc/agent_valid": [n_scene, n_agent, n_step_hist]
95 | "sc/agent_pos": [n_scene, n_agent, n_step_hist, 2]
96 | "sc/agent_vel": [n_scene, n_agent, n_step_hist, 2]
97 | "sc/agent_spd": [n_scene, n_agent, n_step_hist, 1]
98 | "sc/agent_acc": [n_scene, n_agent, n_step_hist, 1]
99 | "sc/agent_yaw_bbox": [n_scene, n_agent, n_step_hist, 1]
100 | "sc/agent_yaw_rate": [n_scene, n_agent, n_step_hist, 1]
101 | # agent attributes
102 | "sc/agent_type": [n_scene, n_agent, 3]
103 | "sc/agent_role": [n_scene, n_agent, 3]
104 | "sc/agent_size": [n_scene, n_agent, 3]
105 | # map polylines
106 | "sc/map_valid": [n_scene, n_pl, n_pl_node], bool
107 | "sc/map_type": [n_scene, n_pl, 11], bool one_hot
108 | "sc/map_pos": [n_scene, n_pl, n_pl_node, 2], float32
109 | "sc/map_dir": [n_scene, n_pl, n_pl_node, 2], float32
110 | # traffic lights
111 | "sc/tl_valid": [n_scene, n_step_hist, n_tl], bool
112 | "sc/tl_state": [n_scene, n_step_hist, n_tl, 5], bool one_hot
113 | "sc/tl_pos": [n_scene, n_step_hist, n_tl, 2], x,y
114 | "sc/tl_dir": [n_scene, n_step_hist, n_tl, 2], x,y
115 |
116 | Returns: add following keys to batch Dict
117 | # agent type: no need to be aggregated.
118 | "input/agent_type": [n_scene, n_agent, 3]
119 | # agent history
120 | if self.pl_aggr: # processed by mlp encoder, c.f. our method.
121 | "input/agent_valid": [n_scene, n_agent], bool
122 | "input/agent_attr": [n_scene, n_agent, agent_attr_dim] in local coordinate wrt. the current step.
123 |                 "input/agent_pose": [n_scene, n_agent, 3], (x,y,yaw), pose of the current step.
124 | "input/map_valid": [n_scene, n_pl], bool
125 | "input/map_attr": [n_scene, n_pl, map_attr_dim], in local coordinate wrt. the first node.
126 | "input/map_pose": [n_scene, n_pl, 3], (x,y,yaw), polyline starting node in global coordinate.
127 | else: # processed by pointnet encoder, c.f. vectornet.
128 | "input/agent_valid": [n_scene, n_agent, n_step_hist], bool
129 | "input/agent_attr": [n_scene, n_agent, n_step_hist, agent_attr_dim]
130 | "input/agent_pose": [n_scene, n_agent, 3]
131 | "input/map_valid": [n_scene, n_pl, n_pl_node], bool
132 | "input/map_attr": [n_scene, n_pl, n_pl_node, map_attr_dim]
133 |                 "input/map_pose": [n_scene, n_pl, 3]
134 |             # traffic lights: stop points; cannot be aggregated since detections are not tracked, each is a single-node polyline.
135 | if use_current_tl:
136 | "input/tl_valid": [n_scene, 1, n_tl], bool
137 | "input/tl_attr": [n_scene, 1, n_tl, tl_attr_dim] no positional info
138 | "input/tl_pose": [n_scene, 1, n_tl, 3] (x,y,yaw)
139 | else:
140 | "input/tl_valid": [n_scene, n_step_hist, n_tl], bool
141 | "input/tl_attr": [n_scene, n_step_hist, n_tl, tl_attr_dim]
142 | "input/tl_pose": [n_scene, n_step_hist, n_tl, 3]
143 | """
144 | batch["input/agent_type"] = batch["sc/agent_type"]
145 | batch["input/agent_valid"] = batch["sc/agent_valid"]
146 | batch["input/tl_valid"] = batch["sc/tl_valid"]
147 | batch["input/map_valid"] = batch["sc/map_valid"]
148 |
149 | # ! randomly mask history agent/tl/map
150 | if self.training and (0 < self.dropout_p_history <= 1.0):
151 | prob_mask = torch.ones_like(batch["input/agent_valid"][..., :-1]) * (1 - self.dropout_p_history)
152 | batch["input/agent_valid"][..., :-1] &= torch.bernoulli(prob_mask).bool()
153 | prob_mask = torch.ones_like(batch["input/tl_valid"]) * (1 - self.dropout_p_history)
154 | batch["input/tl_valid"] &= torch.bernoulli(prob_mask).bool()
155 | prob_mask = torch.ones_like(batch["input/map_valid"]) * (1 - self.dropout_p_history)
156 | batch["input/map_valid"] &= torch.bernoulli(prob_mask).bool()
157 |
158 | # ! prepare "input/agent"
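  |         # agent history is expressed in each agent's local frame, anchored at its current-step pose
  |         # ("ref/pos", "ref/yaw"); the global anchor pose is kept in "input/agent_pose".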
159 | # [n_scene, n_agent, 3]
160 | batch["input/agent_pose"] = torch.cat([batch["ref/pos"].squeeze(2), batch["ref/yaw"]], dim=-1)
161 | agent_pos_local = torch_pos2local(batch["sc/agent_pos"], batch["ref/pos"], batch["ref/rot"])
162 | agent_vel_local = torch_dir2local(batch["sc/agent_vel"], batch["ref/rot"])
163 | agent_yaw_local = torch_rad2local(batch["sc/agent_yaw_bbox"], batch["ref/yaw"], cast=False)
164 | if self.pl_aggr: # [n_scene, n_agent, agent_attr_dim]
165 | agent_invalid = ~batch["input/agent_valid"].unsqueeze(-1)
166 | agent_invalid_reduced = agent_invalid.all(-2) # [n_scene, n_agent, 1]
167 | batch["input/agent_attr"] = torch.cat(
168 | [
169 | self.pose_pe_agent(agent_pos_local, agent_yaw_local).masked_fill(agent_invalid, 0).flatten(-2, -1),
170 | agent_vel_local.masked_fill(agent_invalid, 0).flatten(-2, -1), # n_step_hist*2
171 | batch["sc/agent_spd"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
172 | batch["sc/agent_yaw_rate"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
173 | batch["sc/agent_acc"].masked_fill(agent_invalid, 0).squeeze(-1), # n_step_hist
174 | batch["sc/agent_size"].masked_fill(agent_invalid_reduced, 0), # 3
175 | batch["sc/agent_type"].masked_fill(agent_invalid_reduced, 0), # 3
176 | batch["input/agent_valid"], # n_step_hist
177 | ],
178 | dim=-1,
179 | )
180 | batch["input/agent_valid"] = batch["input/agent_valid"].any(-1) # [n_scene, n_agent]
181 | else: # [n_scene, n_agent, n_step_hist, agent_attr_dim]
182 | batch["input/agent_attr"] = torch.cat(
183 | [
184 | self.pose_pe_agent(agent_pos_local, agent_yaw_local),
185 | agent_vel_local, # vel xy, 2
186 | batch["sc/agent_spd"], # speed, 1
187 | batch["sc/agent_yaw_rate"], # yaw rate, 1
188 | batch["sc/agent_acc"], # acc, 1
189 | batch["sc/agent_size"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3
190 | batch["sc/agent_type"].unsqueeze(-2).expand(-1, -1, self.n_step_hist, -1), # 3
191 | ],
192 | dim=-1,
193 | )
194 |
195 | # ! prepare "input/map"
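  |         # each polyline is expressed in a local frame anchored at its first node and aligned with its
  |         # first direction vector; the global anchor pose is kept in "input/map_pose".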
196 | # [n_scene, n_pl]
197 | pl_yaw = torch.atan2(batch["sc/map_dir"][:, :, 0, 1], batch["sc/map_dir"][:, :, 0, 0])
198 | # [n_scene, n_pl, 3]
199 | batch["input/map_pose"] = torch.cat([batch["sc/map_pos"][:, :, 0], pl_yaw.unsqueeze(-1)], dim=-1)
200 | # batch["sc/map_pos"], batch["sc/map_dir"]: [n_scene, n_pl, n_pl_node, 2]
201 | pl_rot = torch_rad2rot(pl_yaw) # [n_scene, n_pl, 2, 2]
202 | map_pos_local = torch_pos2local(batch["sc/map_pos"], batch["sc/map_pos"][:, :, [0]], pl_rot)
203 | map_dir_local = torch_dir2local(batch["sc/map_dir"], pl_rot)
204 | if self.pl_aggr: # [n_scene, n_pl, map_attr_dim]
205 | map_invalid = ~batch["input/map_valid"].unsqueeze(-1) # [n_scene, n_pl, n_pl_node, 1]
206 | map_invalid_reduced = map_invalid.all(-2) # [n_scene, n_pl, 1]
207 | batch["input/map_attr"] = torch.cat(
208 | [
209 | self.pose_pe_map(map_pos_local, map_dir_local).masked_fill(map_invalid, 0).flatten(-2, -1),
210 | batch["sc/map_type"].masked_fill(map_invalid_reduced, 0), # n_pl_type
211 | batch["input/map_valid"], # n_pl_node
212 | ],
213 | dim=-1,
214 | )
215 | batch["input/map_valid"] = batch["input/map_valid"].any(-1) # [n_scene, n_pl]
216 | else: # [n_scene, n_pl, n_pl_node, map_attr_dim]
217 | batch["input/map_attr"] = torch.cat(
218 | [
219 | self.pose_pe_map(map_pos_local, map_dir_local), # pl_dim
220 | batch["sc/map_type"].unsqueeze(-2).expand(-1, -1, self.n_pl_node, -1), # n_pl_type
221 | ],
222 | dim=-1,
223 | )
224 |
225 | # ! prepare "input/tl_attr": [n_scene, n_step_hist/1, n_tl, tl_attr_dim]
226 | # [n_scene, n_step_hist, n_tl, 2]
227 | tl_pos = batch["sc/tl_pos"]
228 | tl_dir = batch["sc/tl_dir"]
229 | tl_state = batch["sc/tl_state"]
230 | if self.use_current_tl:
231 | tl_pos = tl_pos[:, [-1]] # [n_scene, 1, n_tl, 2]
232 | tl_dir = tl_dir[:, [-1]] # [n_scene, 1, n_tl, 2]
233 | tl_state = tl_state[:, [-1]] # [n_scene, 1, n_tl, 5]
234 | batch["input/tl_valid"] = batch["input/tl_valid"][:, [-1]] # [n_scene, 1, n_tl]
235 | tl_yaw = torch.atan2(tl_dir[..., 1], tl_dir[..., 0]) # [n_scene, n_step_hist/1, n_tl]
236 | batch["input/tl_pose"] = torch.cat([tl_pos, tl_yaw.unsqueeze(-1)], dim=-1)
237 | batch["input/tl_attr"] = tl_state.type_as(batch["sc/map_pos"]) # to float
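  |         # unlike the scene-centric global variant, tl_attr carries only the light state; the stop-point
  |         # pose goes into "input/tl_pose" and is presumably consumed by the relative pose encoding downstream.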
238 |
239 | # ! add one-hot encoding for sequence (temporal, order of polyline nodes)
240 | if self.add_ohe:
241 | n_scene, n_agent, _ = batch["sc/agent_valid"].shape
242 | n_pl = batch["sc/map_valid"].shape[1]
243 | if not self.pl_aggr: # there is no need to add ohe if self.pl_aggr
244 | batch["input/agent_attr"] = torch.cat(
245 | [
246 | batch["input/agent_attr"],
247 | self.history_step_ohe[None, None, :, :].expand(n_scene, n_agent, -1, -1),
248 | ],
249 | dim=-1,
250 | )
251 | batch["input/map_attr"] = torch.cat(
252 |                 [batch["input/map_attr"], self.pl_node_ohe[None, None, :, :].expand(n_scene, n_pl, -1, -1)],
253 | dim=-1,
254 | )
255 |
256 | if not self.use_current_tl: # there is no need to add ohe if use_current_tl
257 | n_tl = batch["input/tl_valid"].shape[-1]
258 | batch["input/tl_attr"] = torch.cat(
259 | [batch["input/tl_attr"], self.history_step_ohe[None, :, None, :].expand(n_scene, -1, n_tl, -1)],
260 | dim=-1,
261 | )
262 |
263 | return batch
264 |
--------------------------------------------------------------------------------