├── .gitignore
├── README.md
├── config
│   ├── custom_trainer
│   │   └── pluto_trainer.yaml
│   ├── data_augmentation
│   │   └── contrastive_scenario_generator.yaml
│   ├── default_simulation.yaml
│   ├── default_training.yaml
│   ├── lightning
│   │   └── custom_lightning.yaml
│   ├── model
│   │   └── pluto_model.yaml
│   ├── planner
│   │   └── pluto_planner.yaml
│   ├── scenario_filter
│   │   ├── mini_demo_scenario.yaml
│   │   ├── random14_benchmark.yaml
│   │   ├── training_scenarios_1M.yaml
│   │   ├── training_scenarios_tiny.yaml
│   │   ├── val14_benchmark.yaml
│   │   └── val_demo_scenario.yaml
│   └── training
│       └── train_pluto.yaml
├── requirements.txt
├── run_simulation.py
├── run_training.py
├── script
│   ├── run_pluto_planner.sh
│   └── setup_env.sh
└── src
    ├── custom_training
    │   ├── custom_datamodule.py
    │   └── custom_training_builder.py
    ├── data_augmentation
    │   └── contrastive_scenario_generator.py
    ├── feature_builders
    │   ├── common.py
    │   ├── nuplan_scenario_render.py
    │   └── pluto_feature_builder.py
    ├── features
    │   └── pluto_feature.py
    ├── metrics
    │   ├── __init__.py
    │   ├── min_ade.py
    │   ├── min_fde.py
    │   ├── mr.py
    │   ├── prediction_avg_ade.py
    │   ├── prediction_avg_fde.py
    │   └── utils.py
    ├── models
    │   └── pluto
    │       ├── layers
    │       │   ├── common_layers.py
    │       │   ├── embedding.py
    │       │   ├── fourier_embedding.py
    │       │   ├── mlp_layer.py
    │       │   └── transformer.py
    │       ├── loss
    │       │   └── esdf_collision_loss.py
    │       ├── modules
    │       │   ├── agent_encoder.py
    │       │   ├── agent_predictor.py
    │       │   ├── map_encoder.py
    │       │   ├── planning_decoder.py
    │       │   └── static_objects_encoder.py
    │       ├── pluto_model.py
    │       └── pluto_trainer.py
    ├── optim
    │   └── warmup_cos_lr.py
    ├── planners
    │   ├── ml_planner_utils.py
    │   └── pluto_planner.py
    ├── post_processing
    │   ├── common
    │   │   ├── enum.py
    │   │   └── geometry.py
    │   ├── emergency_brake.py
    │   ├── evaluation
    │   │   └── comfort_metrics.py
    │   ├── forward_simulation
    │   │   ├── batch_kinematic_bicycle.py
    │   │   ├── batch_lqr.py
    │   │   ├── batch_lqr_utils.py
    │   │   └── forward_simulator.py
    │   ├── observation
    │   │   └── world_from_prediction.py
    │   └── trajectory_evaluator.py
    ├── scenario_manager
    │   ├── cost_map_manager.py
    │   ├── occupancy_map.py
    │   ├── route_manager.py
    │   ├── scenario_manager.py
    │   └── utils
    │       ├── bfs_roadblock.py
    │       ├── dijkstra.py
    │       └── route_utils.py
    └── utils
        ├── collision_checker.py
        ├── utils.py
        └── vis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | parts/
18 | sdist/
19 | var/
20 | wheels/
21 | share/python-wheels/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 | MANIFEST
26 | .DS_Store
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | *.py,cover
49 | .hypothesis/
50 | .pytest_cache/
51 | cover/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | .pybuilder/
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | # For a library or package, you might want to ignore these files since the code is
86 | # intended to run in multiple environments; otherwise, check them in:
87 | # .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # poetry
97 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
98 | # This is especially recommended for binary packages to ensure reproducibility, and is more
99 | # commonly ignored for libraries.
100 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
101 | #poetry.lock
102 |
103 | # pdm
104 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
105 | #pdm.lock
106 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
107 | # in version control.
108 | # https://pdm.fming.dev/#use-with-ide
109 | .pdm.toml
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .venv
123 | # env/
124 | venv/
125 | ENV/
126 | env.bak/
127 | venv.bak/
128 |
129 | # Spyder project settings
130 | .spyderproject
131 | .spyproject
132 |
133 | # Rope project settings
134 | .ropeproject
135 |
136 | # mkdocs documentation
137 | /site
138 |
139 | # mypy
140 | .mypy_cache/
141 | .dmypy.json
142 | dmypy.json
143 |
144 | # Pyre type checker
145 | .pyre/
146 |
147 | # pytype static type analyzer
148 | .pytype/
149 |
150 | # Cython debug symbols
151 | cython_debug/
152 |
153 | # PyCharm
154 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
156 | # and can be added to the global gitignore or merged into this file. For a more nuclear
157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
158 | #.idea/
159 |
160 | .vscode
161 | wandb/
162 | outputs/
163 | *.ckpt
164 | checkpoints/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PLUTO
2 |
3 | This is the official repository of
4 |
5 | **PLUTO: Pushing the Limit of Imitation Learning-based Planning for Autonomous Driving**,
6 |
7 | [Jie Cheng](https://jchengai.github.io/), [Yingbing Chen](https://sites.google.com/view/chenyingbing-homepage), and [Qifeng Chen](https://cqf.io/)
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Setup Environment
20 |
21 | ### Setup dataset
22 |
23 | Set up the nuPlan dataset following the [official documentation](https://nuplan-devkit.readthedocs.io/en/latest/dataset_setup.html).
24 |
25 | ### Setup conda environment
26 |
27 | ```
28 | conda create -n pluto python=3.9
29 | conda activate pluto
30 |
31 | # install nuplan-devkit
32 | git clone https://github.com/motional/nuplan-devkit.git && cd nuplan-devkit
33 | pip install -e .
34 | pip install -r ./requirements.txt
35 |
36 | # setup pluto
37 | cd ..
38 | git clone https://github.com/jchengai/pluto.git && cd pluto
39 | sh ./script/setup_env.sh
40 | ```
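
To quickly verify the environment, a minimal check (assuming the steps above completed without errors):

```
python -c "import torch, nuplan; print(torch.__version__, torch.cuda.is_available())"
```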
41 |
42 | ## Feature Cache
43 |
44 | Preprocess the dataset to accelerate training. It is recommended to run a small sanity check first to make sure everything is set up correctly.
45 |
46 | ```
47 | python run_training.py \
48 | py_func=cache +training=train_pluto \
49 | scenario_builder=nuplan_mini \
50 | cache.cache_path=/nuplan/exp/sanity_check \
51 | cache.cleanup_cache=true \
52 | scenario_filter=training_scenarios_tiny \
53 | worker=sequential
54 | ```
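
If the sanity check succeeds, the cache path should now contain the precomputed features. A quick way to inspect it (assuming the `cache.cache_path` used above):

```
ls /nuplan/exp/sanity_check
du -sh /nuplan/exp/sanity_check
```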
55 |
56 | Then preprocess the whole nuPlan training set (this will take some time). You may need to change `cache.cache_path` to suit your setup:
57 |
58 | ```
59 | export PYTHONPATH=$PYTHONPATH:$(pwd)
60 |
61 | python run_training.py \
62 | py_func=cache +training=train_pluto \
63 | scenario_builder=nuplan \
64 | cache.cache_path=/nuplan/exp/cache_pluto_1M \
65 | cache.cleanup_cache=true \
66 | scenario_filter=training_scenarios_1M \
67 | worker.threads_per_node=40
68 | ```
69 |
70 | ## Training
71 |
72 | (The training part is not fully tested.)
73 |
74 | As before, it is recommended to run a sanity check first:
75 |
76 | ```
77 | CUDA_VISIBLE_DEVICES=0 python run_training.py \
78 | py_func=train +training=train_pluto \
79 | worker=single_machine_thread_pool worker.max_workers=4 \
80 | scenario_builder=nuplan cache.cache_path=/nuplan/exp/sanity_check cache.use_cache_without_dataset=true \
81 | data_loader.params.batch_size=4 data_loader.params.num_workers=1
82 | ```
83 |
84 | Training on the full dataset (without CIL):
85 |
86 | ```
87 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_training.py \
88 | py_func=train +training=train_pluto \
89 | worker=single_machine_thread_pool worker.max_workers=32 \
90 | scenario_builder=nuplan cache.cache_path=/nuplan/exp/cache_pluto_1M cache.use_cache_without_dataset=true \
91 | data_loader.params.batch_size=32 data_loader.params.num_workers=16 \
92 | lr=1e-3 epochs=25 warmup_epochs=3 weight_decay=0.0001 \
93 | wandb.mode=online wandb.project=nuplan wandb.name=pluto
94 | ```
95 |
96 | - add the option `model.use_hidden_proj=true +custom_trainer.use_contrast_loss=true` to enable CIL (see the example below).
97 |
98 | - you can remove the wandb-related configurations if you prefer TensorBoard.
99 |
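For reference, a sketch of the full-dataset run with CIL enabled. It is identical to the command above, with the two CIL options appended (`pluto-cil` is just an example run name):

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python run_training.py \
  py_func=train +training=train_pluto \
  worker=single_machine_thread_pool worker.max_workers=32 \
  scenario_builder=nuplan cache.cache_path=/nuplan/exp/cache_pluto_1M cache.use_cache_without_dataset=true \
  data_loader.params.batch_size=32 data_loader.params.num_workers=16 \
  lr=1e-3 epochs=25 warmup_epochs=3 weight_decay=0.0001 \
  model.use_hidden_proj=true +custom_trainer.use_contrast_loss=true \
  wandb.mode=online wandb.project=nuplan wandb.name=pluto-cil
```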
100 |
101 | ## Checkpoint
102 |
103 | Download and place the checkpoint in the `pluto/checkpoints` folder.
104 |
105 | | Model | Download |
106 | | ---------------- | -------- |
107 | | Pluto-1M-aux-cil | [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchengai_connect_ust_hk/EaFpLwwHFYVKsPVLH2nW5nEBNbPS7gqqu_Rv2V1dzODO-Q?e=LAZQcI) |
108 |
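For example, assuming the checkpoint was downloaded to the current directory (adjust the source path as needed):

```
mkdir -p checkpoints
mv ./pluto_1M_aux_cil.ckpt ./checkpoints/
```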
109 |
110 | ## Run Pluto-planner simulation
111 |
112 | Run a simulation for a random scenario in the nuPlan-mini split:
113 |
114 | ```
115 | sh ./script/run_pluto_planner.sh pluto_planner nuplan_mini mini_demo_scenario pluto_1M_aux_cil.ckpt /dir_to_save_the_simulation_result_video
116 | ```
117 |
118 | The rendered simulation video will be saved to the specified directory (replace `/dir_to_save_the_simulation_result_video` with your desired output path).
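
For example, a concrete invocation (`$HOME/pluto_videos` is just an arbitrary output directory):

```
sh ./script/run_pluto_planner.sh pluto_planner nuplan_mini mini_demo_scenario pluto_1M_aux_cil.ckpt $HOME/pluto_videos
```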
119 |
120 | ## To Do
121 |
122 | The code is still being cleaned up and will be released gradually.
123 |
124 | - [ ] improve docs
125 | - [x] training code
126 | - [x] visualization
127 | - [x] pluto-planner & checkpoint
128 | - [x] feature builder & model
129 | - [x] initial repo & paper
130 |
131 | ## Citation
132 |
133 | If you find this repo useful, please consider giving us a star 🌟 and citing our related paper.
134 |
135 | ```bibtex
136 | @article{cheng2024pluto,
137 | title={PLUTO: Pushing the Limit of Imitation Learning-based Planning for Autonomous Driving},
138 | author={Cheng, Jie and Chen, Yingbing and Chen, Qifeng},
139 | journal={arXiv preprint arXiv:2404.14327},
140 | year={2024}
141 | }
142 | ```
--------------------------------------------------------------------------------
/config/custom_trainer/pluto_trainer.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.pluto.pluto_trainer.LightningTrainer
--------------------------------------------------------------------------------
/config/data_augmentation/contrastive_scenario_generator.yaml:
--------------------------------------------------------------------------------
1 | contrastive_augmentation:
2 | _target_: src.data_augmentation.contrastive_scenario_generator.ContrastiveScenarioGenerator
3 | _convert_: "all"
4 |
5 | history_steps: 21
6 | low: [-1.0, -0.75, -0.35, -1, -0.5, -0.2, -0.1]
7 | high: [1.0, 0.75, 0.35, 1, 0.5, 0.2, 0.1]
8 | max_interaction_horizon: 40
9 |
--------------------------------------------------------------------------------
/config/default_simulation.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${output_dir}
4 | output_subdir: ${output_dir}/code/hydra # Store hydra's config breakdown here for debugging
5 | searchpath: # Only configs in these paths are discoverable
6 | - pkg://nuplan.planning.script.config.common
7 | - pkg://nuplan.planning.script.config.simulation
8 | - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/
9 | - config/simulation
10 | - config/scenario_filter
11 |
12 | defaults:
13 | # Add ungrouped items
14 | - default_experiment
15 | - default_common
16 | - default_submission
17 |
18 | - simulation_metric:
19 | - default_metrics
20 | - callback:
21 | - simulation_log_callback
22 | - main_callback:
23 | - time_callback
24 | - metric_file_callback
25 | - metric_aggregator_callback
26 | - metric_summary_callback
27 | - splitter: nuplan
28 |
29 | # Hyperparameters need to be specified
30 | - observation: null
31 | - ego_controller: null
32 | - planner: null
33 | - simulation_time_controller: step_simulation_time_controller
34 | - metric_aggregator:
35 | - default_weighted_average
36 |
37 | - override hydra/job_logging: none # Disable hydra's logging
38 | - override hydra/hydra_logging: none # Disable hydra's logging
39 |
40 | experiment_name: 'simulation'
41 | aggregated_metric_folder_name: 'aggregator_metric' # Aggregated metric folder name
42 | aggregator_save_path: ${output_dir}/${aggregated_metric_folder_name}
43 |
44 |
45 | # Progress Visualization
46 | enable_simulation_progress_bar: true # Show for every simulation its progress
47 |
48 | # Simulation Setup
49 | simulation_history_buffer_duration: 2.0 # [s] The look back duration to initialize the simulation history buffer with
50 |
51 | # Number (or fractional, e.g., 0.25) of GPUs available for single simulation (per scenario and planner).
52 | # This number can also be < 1 because we allow multiple models to be loaded into a single GPU.
53 | # In case this number is 0 or null, no GPU is used for simulation and all cpu cores are leveraged
54 | # Note that the user has to make sure that, if a number < 1 is chosen, the model fits into 1 / num_gpus of GPU memory
55 | number_of_gpus_allocated_per_simulation: 1
56 |
57 | # This number specifies number of CPU threads that are used for simulation
58 | # In case this is null, then each simulation will use unlimited resources.
59 | # That will typically swamp the host computer, leading to slowdowns and failure.
60 | number_of_cpus_allocated_per_simulation: 1
61 |
62 | # Set false to disable metric computation
63 | run_metric: true
64 |
65 | # Set this to the path of existing simulation logs to recompute metrics without rerunning the simulation.
66 | simulation_log_main_path: null
67 |
68 | # If false, continue running the simulation even if a scenario has failed
69 | exit_on_failure: false
70 |
71 | # Maximum number of workers to be used for running simulation callbacks outside the main process
72 | max_callback_workers: 4
73 |
74 | # Disable callback parallelization when using the Sequential worker. By default, when running with the sequential worker,
75 | # on_simulation_end callbacks are not submitted to a parallel worker.
76 | disable_callback_parallelization: true
77 |
78 | # Distributed processing mode. If multi-node simulation is enabled, this parameter selects how the scenarios are distributed
79 | # to each node. The modes are:
80 | #  - SCENARIO_BASED: Works in two stages, first getting a list of all scenarios to process, then breaking up that
81 | # list and distributing across the workers
82 | # - LOG_FILE_BASED: Works in a single stage, breaking up the scenarios based on what log file they are in and
83 | # distributing the number of log files evenly across all workers
84 | # - SINGLE_NODE: Does no distribution, processes all scenarios in config
85 | distributed_mode: 'SINGLE_NODE'
--------------------------------------------------------------------------------
/config/default_training.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${output_dir}
4 | output_subdir: ${output_dir}/code/hydra # Store hydra's config breakdown here for debugging
5 | searchpath: # Only configs in these paths are discoverable
6 | - pkg://nuplan.planning.script.config.common
7 | - pkg://nuplan.planning.script.config.training
8 | - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/
9 | - config/training
10 |
11 |
12 | defaults:
13 | - default_experiment
14 | - default_common
15 |
16 | # Trainer and callbacks
17 | - lightning: custom_lightning
18 | - callbacks: default_callbacks
19 |
20 | # Optimizer settings
21 | - optimizer: adam # [adam, adamw] supported optimizers
22 | - lr_scheduler: null # [one_cycle_lr] supported lr_schedulers
23 | - warm_up_lr_scheduler: null # [linear_warm_up, constant_warm_up] supported warm up lr schedulers
24 |
25 | # Data Loading
26 | - data_loader: default_data_loader
27 | - splitter: ???
28 |
29 | # Objectives and metrics
30 | - objective:
31 | - training_metric:
32 | - data_augmentation: null
33 | - data_augmentation_scheduler: null # [default_augmentation_schedulers, stepwise_augmentation_probability_scheduler, stepwise_noise_parameter_scheduler] supported data augmentation schedulers
34 | - scenario_type_weights: default_scenario_type_weights
35 | - custom_trainer: null
36 |
37 | nuplan_trainer: false
38 | experiment_name: 'training'
39 | objective_aggregate_mode: ??? # How to aggregate multiple objectives, can be 'mean', 'max', 'sum'
40 |
41 | # Cache parameters
42 | cache:
43 | cache_path: # Local/remote path to store all preprocessed artifacts from the data pipeline
44 | use_cache_without_dataset: false # Load all existing features from a local/remote cache without loading the dataset
45 | force_feature_computation: false # Recompute features even if a cache exists
46 | cleanup_cache: false # Cleanup cached data in the cache_path, this ensures that new data are generated if the same cache_path is passed
47 |
48 | # Mandatory parameters
49 | py_func: ??? # Function to be run inside main (can be "train", "test", "cache")
50 | epochs: 25
51 | warmup_epochs: 3
52 | lr: 1e-3
53 | weight_decay: 0.0001
54 | checkpoint:
55 |
56 | # wandb settings
57 | wandb:
58 | mode: disable
59 | project: nuplan-pluto
60 | name: ${experiment_name}
61 | log_model: all
62 | artifact:
63 | run_id:
64 |
--------------------------------------------------------------------------------
/config/lightning/custom_lightning.yaml:
--------------------------------------------------------------------------------
1 | distributed_training:
2 | equal_variance_scaling_strategy: true # scales lr and betas either linearly if false (multiply by num GPUs) or with equal variance if true (multiply by square root of num GPUs)
3 |
4 | trainer:
5 | checkpoint:
6 | resume_training: false # load the model from the last epoch and resume training
7 | save_top_k: 5 # save the top K models in terms of performance
8 | monitor: loss/val_loss # metric to monitor for performance
9 | mode: min # minimize/maximize metric
10 |
11 | params:
12 | # max_time: 00:16:00:00 # training time before the process is terminated
13 |
14 | max_epochs: ${epochs} # maximum number of training epochs
15 | # check_val_every_n_epoch: 1 # run validation set every n training epochs
16 | val_check_interval: 1.0 # [%] run validation set every X% of training set
17 |
18 | limit_train_batches: # how much of training dataset to check (float = fraction, int = num_batches)
19 | limit_val_batches: # how much of validation dataset to check (float = fraction, int = num_batches)
20 | limit_test_batches: # how much of test dataset to check (float = fraction, int = num_batches)
21 |
22 | devices: -1 # number of GPUs to utilize (-1 means all available GPUs)
23 | accelerator: gpu # distribution method
24 | precision: 32 # floating point precision
25 | # amp_level: O2 # AMP optimization level
26 | # num_nodes: 1 # Number of nodes used for training
27 |
28 | # auto_scale_batch_size: false
29 | # auto_lr_find: false # tunes LR before beginning training
30 | # terminate_on_nan: true # terminates training if a nan is encountered in loss/weights
31 |
32 | num_sanity_val_steps: 1 # number of validation steps to run before training begins
33 | fast_dev_run: false # runs 1 batch of train/val/test for sanity
34 |
35 | # accumulate_grad_batches: 1 # accumulates gradients every n batches
36 | # track_grad_norm: -1 # logs the p-norm for inspection
37 | gradient_clip_val: 5.0 # value to clip gradients
38 | gradient_clip_algorithm: norm # [value, norm] method to clip gradients
39 | sync_batchnorm: true
40 | strategy: ddp_find_unused_parameters_false
41 |
42 | # checkpoint_callback: true # enable default checkpoint
43 |
44 | overfitting:
45 | enable: false # run an overfitting test instead of training
46 |
47 | params:
48 | max_epochs: 150 # number of epochs to overfit the same batches
49 | overfit_batches: 1 # number of batches to overfit
50 |
--------------------------------------------------------------------------------
/config/model/pluto_model.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.pluto.pluto_model.PlanningModel
2 | _convert_: "all"
3 |
4 | dim: 128
5 | state_channel: 6
6 | polygon_channel: 6
7 | history_channel: 9
8 | history_steps: 21
9 | future_steps: 80
10 | encoder_depth: 4
11 | decoder_depth: 4
12 | drop_path: 0.2
13 | dropout: 0.1
14 | num_heads: 4
15 | num_modes: 12
16 | state_dropout: 0.75
17 | use_ego_history: false
18 | state_attn_encoder: true
19 | use_hidden_proj: false
20 |
21 | feature_builder:
22 | _target_: src.feature_builders.pluto_feature_builder.PlutoFeatureBuilder
23 | _convert_: "all"
24 | radius: 120
25 | history_horizon: 2
26 | future_horizon: 8
27 | sample_interval: 0.1
28 | max_agents: 48
29 | build_reference_line: true
30 |
--------------------------------------------------------------------------------
/config/planner/pluto_planner.yaml:
--------------------------------------------------------------------------------
1 | pluto_planner:
2 | _target_: src.planners.pluto_planner.PlutoPlanner
3 | _convert_: "all"
4 |
5 | render: false
6 | eval_dt: 0.1
7 | eval_num_frames: 40
8 | candidate_subsample_ratio: 1.0
9 | candidate_min_num: 10
10 | learning_based_score_weight: 0.3
11 |
12 | planner:
13 | _target_: src.models.pluto.pluto_model.PlanningModel
14 | _convert_: "all"
15 |
16 | dim: 128
17 | state_channel: 6
18 | polygon_channel: 6
19 | history_channel: 9
20 | history_steps: 21
21 | future_steps: 80
22 | encoder_depth: 4
23 | decoder_depth: 4
24 | drop_path: 0.2
25 | dropout: 0.1
26 | num_heads: 4
27 | num_modes: 12
28 | state_dropout: 0.75
29 | use_ego_history: false
30 | state_attn_encoder: true
31 | use_hidden_proj: true
32 | cat_x: true
33 | ref_free_traj: true
34 |
35 | feature_builder:
36 | _target_: src.feature_builders.pluto_feature_builder.PlutoFeatureBuilder
37 | _convert_: "all"
38 | radius: 120
39 | history_horizon: 2
40 | future_horizon: 8
41 | sample_interval: 0.1
42 | max_agents: 48
43 | build_reference_line: true
44 |
45 | planner_ckpt:
46 |
--------------------------------------------------------------------------------
/config/scenario_filter/mini_demo_scenario.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: 'all'
3 |
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens: null # List of scenario tokens to include
6 |
7 | log_names: null # Filter scenarios by log names
8 | map_names: null # Filter scenarios by map names
9 |
10 | num_scenarios_per_type: 1 # Number of scenarios per type
11 | limit_total_scenarios: 1 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
12 | timestamp_threshold_s: 15 # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
17 |
18 |
19 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios
20 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
21 | shuffle: true # Whether to shuffle the scenarios
22 |
--------------------------------------------------------------------------------
/config/scenario_filter/random14_benchmark.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: all
3 | scenario_types:
4 | - starting_left_turn
5 | - starting_right_turn
6 | - starting_straight_traffic_light_intersection_traversal
7 | - stopping_with_lead
8 | - high_lateral_acceleration
9 | - high_magnitude_speed
10 | - low_magnitude_speed
11 | - traversing_pickup_dropoff
12 | - waiting_for_pedestrian_to_cross
13 | - behind_long_vehicle
14 | - stationary_in_traffic
15 | - near_multiple_vehicles
16 | - changing_lane
17 | - following_lane_with_lead
18 |
19 | scenario_tokens:
20 | - "09ca86cc9ae65428"
21 | - d77598de12fb5f46
22 | - "011118ec4f9952bc"
23 | - e9faae87fb83540d
24 | - "7b28cbcdcae35e7e"
25 | - "80c56ad735545de3"
26 | - "9031cf5175f253cf"
27 | - "22022119f3bd53f3"
28 | - ac67e458c88d5d78
29 | - "9cab2c90252c549b"
30 | - a639bd510b0d5b0a
31 | - "9e507708119c5596"
32 | - "4f9321ffbcb95b55"
33 | - a0f50ac13caa51ac
34 | - "843eb1eee80b529e"
35 | - "5051ccd75314515c"
36 | - "1939feda817551d5"
37 | - "281b1fe38bdc5467"
38 | - c9d8621d7c2255e6
39 | - "8b8cdc059d585494"
40 | - "6c6ddf6740d355ef"
41 | - "74e550e274f959a4"
42 | - "0393fbf49cd55295"
43 | - f00be015e344557c
44 | - c00c5e8e89e3571c
45 | - "33ef46dbce1c577f"
46 | - "69b83d7c62d65006"
47 | - "7cc26b3af3b0557f"
48 | - dd117e2e2dfa560d
49 | - "7280988648d05358"
50 | - "9a81702c38d757b8"
51 | - a88d3779eb535274
52 | - "27fb96d75d4759a2"
53 | - "5dec4c8bfd49502f"
54 | - "942f6862e6055f62"
55 | - "72038bf480a55042"
56 | - e3342e2e8b535405
57 | - "07874767a6d55ad7"
58 | - f1cd02f2371c5c67
59 | - "2e56997063c057a0"
60 | - "382951755d8e5e77"
61 | - "2f69576b55305cc8"
62 | - f9e6e7fc604b5ff9
63 | - "24cc8621c8595a92"
64 | - "0c88901b6cfb53d0"
65 | - "93c583b46398560e"
66 | - "7835a781f6bd5688"
67 | - "4aeaa8bebf7a57f1"
68 | - "33624e1585945551"
69 | - e5d69d5a0b135831
70 | - "711c54430cae590f"
71 | - "29b77c53a46956af"
72 | - "7d094d765a8f5ffa"
73 | - a1d2f54d8ec1564c
74 | - af3681f002005504
75 | - a26d34c327365e71
76 | - "15a3339778d058db"
77 | - "81ecd0ddcc4e5c1f"
78 | - "770a7cb9f1585e38"
79 | - "7f50df6eea5c54e9"
80 | - b7485487d49853ca
81 | - cb153e6cd19851ba
82 | - "4e0c8ec081805052"
83 | - b1f9cfc8ac885d91
84 | - "4eaa9a0e22c95491"
85 | - "812679cd23f45d5d"
86 | - "7be9d726f1be5003"
87 | - dc98be51b7155407
88 | - aff9a658a9e450cb
89 | - "845645e7c5665590"
90 | - d3dfcbb4c7be5e8b
91 | - "33a3eb36641857e6"
92 | - a489a4faafbd5c67
93 | - "1ce0f41f0d9754f5"
94 | - "2df19b9bf2fe5c88"
95 | - "0908906ecbed58df"
96 | - d7a3ae0854575652
97 | - "29a854ae1aee58ff"
98 | - "64986687478f5338"
99 | - cada319ebf2b505a
100 | - a874cefaca8e5b69
101 | - "5e33493023b75f33"
102 | - "356d80474d1252c0"
103 | - "0eb1eb9046925cf8"
104 | - bc2485198080598a
105 | - b342971fb07451f9
106 | - "098144ba356652a5"
107 | - "33b5947ae3d75a80"
108 | - b086f30875e65add
109 | - b0f8ae874f2b5e78
110 | - "38e301b3cbad5707"
111 | - e5ef6199dc3b5909
112 | - "4236a1f3c9e35f9d"
113 | - ccdca0ce28565318
114 | - "2664bb314a155875"
115 | - d8633b984d75530e
116 | - d21b22d6c0405ad7
117 | - f3895453b6c35e51
118 | - dca16f92a1015a7f
119 | - "9fd5ec2b453d556e"
120 | - "619533b026ce59f1"
121 | - d5a09eb525e8592a
122 | - "48b0081cd2385220"
123 | - "35dd043cd64e5537"
124 | - cc4bc91994df5424
125 | - "522cda9d26b1526a"
126 | - "8160c5b3a5c9555e"
127 | - "704c05ebeb5959e6"
128 | - "78c5b673402b5861"
129 | - "34a52b2f68f357df"
130 | - a614ad720a76576c
131 | - b6f238f681bf5a09
132 | - b3b422a58ff35545
133 | - f0be7cf0e03e58db
134 | - c1027e45d3b956e4
135 | - "99c06a975b465903"
136 | - e51649556f215dd9
137 | - b526edea27bc5f93
138 | - e9d360ea046554d6
139 | - e7fc2f835ea95ece
140 | - "2fd784c6bf4f572c"
141 | - "28d5868f9d6e5035"
142 | - "6c37bc8ac424562b"
143 | - "2f8c75dfca3f50b9"
144 | - ead9590e094d5a8f
145 | - "5af50221d55658e4"
146 | - "65af9ce73917587a"
147 | - "0cc3d9bb137b5a2b"
148 | - "7e4a2d822a2e55e0"
149 | - bd5783916d995801
150 | - fa10175cddfa52fa
151 | - e095899901a75a56
152 | - "936b015ebca15109"
153 | - "2f149371d38e5806"
154 | - "40bcab2def635436"
155 | - a438e0b5c26351dd
156 | - ee479ca0620c5077
157 | - c3bdac749eeb5206
158 | - "3c0e98e1b77b5d44"
159 | - c28d662b14b05c83
160 | - "2be476b6b8db5693"
161 | - "2574c2dbef5850f1"
162 | - "144473d8f03a57f4"
163 | - bff3663d750c5ee1
164 | - "541899555b645329"
165 | - "0d05468e475657b8"
166 | - "9eb227aec6dc5e23"
167 | - "8fd7ccd629615f24"
168 | - ef202cd018ce55e1
169 | - "7dea428154ad5d8a"
170 | - "6582d5ea9e2e5fa8"
171 | - "1f5dc80046035f11"
172 | - "1a278d6ef4435e69"
173 | - "52e1da615adf57c4"
174 | - "0612d791f0825acd"
175 | - "88c66c9df39e5193"
176 | - "9bb232eead8e525d"
177 | - db0204e9603f59e9
178 | - da64279ff31552d1
179 | - "462120b01b425182"
180 | - e257ace5273058ae
181 | - "660d375c109f5eed"
182 | - "9bdc15b26c455633"
183 | - c7be4ae92b455fa3
184 | - "84008abb23955152"
185 | - "957e6a7dec135b62"
186 | - "29686b1e8e6859ca"
187 | - "56726e1f05ae5164"
188 | - b85af6bc62cf5e03
189 | - "341ca59d11c15742"
190 | - eaef31406a205542
191 | - "63e48fa44d7c5aa4"
192 | - "713972db4cf35d09"
193 | - "3452d2d6fb655c2b"
194 | - "8efa8475e103598e"
195 | - e5fd3465d00b57ae
196 | - "8de10fd86b825304"
197 | - "477820688c4f5683"
198 | - "8ec9c713f4fc5d52"
199 | - "9fb968b3cea6507a"
200 | - a762a67013d4550e
201 | - "714dffb0115b5071"
202 | - "17e261bca6435550"
203 | - "561e8b2703f651f5"
204 | - "96074214b9645952"
205 | - "094db8d50cf95250"
206 | - "61d9e27ce6ee5a38"
207 | - daaaa1360bcf5451
208 | - c167ed0e96e455c4
209 | - "724cc15d7fc856df"
210 | - c700656703ec57bc
211 | - d6dde3e83ebe5b53
212 | - fa75094937835cf0
213 | - adc1c85043525786
214 | - e4bac3a9053955e2
215 | - "2fa51997ae9554cb"
216 | - "3ee6e16b65f956a8"
217 | - "1039c939079950f4"
218 | - "2e82153ebbed5d67"
219 | - bb0e6c0c719a5e7f
220 | - "7dc66aff7ac85f5f"
221 | - fd31b918c10b56e2
222 | - "0fa049c07b5b57c0"
223 | - bac6f1145eef5bf2
224 | - "6a1ffac82183523e"
225 | - "9f006885686852e7"
226 | - "3acc225a64965bd8"
227 | - a2ec5056da3c5c67
228 | - f69d0b59d08d512b
229 | - "3474e3d6c7485935"
230 | - c139820a97455f83
231 | - "93b3c95f3fea5812"
232 | - a09bce6cb8eb591b
233 | - c4f2b1f170eb5649
234 | - "9ba0cad0c3e7580f"
235 | - "1772467d58bb5dec"
236 | - "07db0457ad745d8e"
237 | - "94ce7eff6722572c"
238 | - b83ee0e949e45520
239 | - b497992774675304
240 | - "255e819738da54df"
241 | - "1b1aafb5916e5534"
242 | - "93e79e172de3521b"
243 | - c757a4faf12f52be
244 | - "155505762e1d5cc0"
245 | - "51933e2f44775d6a"
246 | - "071b206ea836546e"
247 | - f5d7ec43b9415862
248 | - "0ef918a26914564e"
249 | - "1c8be883a97a575e"
250 | - bd345bdaa3715b71
251 | - c998657ba1535677
252 | - c1c5dec6bab3598a
253 | - "7721f05072a85928"
254 | - "95495f60a1b659f3"
255 | - "1173af87b1e551ce"
256 | - "399986448051533a"
257 | - c299dae1f0745da8
258 | - beb18cddda575028
259 | - fcedb0da569c518b
260 | - a88f9e60e43c5ebc
261 | - f5a6ac1890c75f87
262 | - e7aa83d247dd5cdc
263 | - cfbbb77238bd541f
264 | - "62db31b428315ae8"
265 | - "0409c3925f245965"
266 | - "70086024fed658cf"
267 | - ceec28e1943f5d76
268 | - "3094b2c29265536a"
269 | log_names: null
270 | map_names: null
271 | num_scenarios_per_type: null
272 | limit_total_scenarios: null
273 | timestamp_threshold_s: 15
274 | ego_displacement_minimum_m: null
275 | ego_start_speed_threshold: null
276 | ego_stop_speed_threshold: null
277 | speed_noise_tolerance: null
278 | expand_scenarios: null
279 | remove_invalid_goals: true
280 | shuffle: false
281 |
--------------------------------------------------------------------------------
/config/scenario_filter/training_scenarios_1M.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: 'all'
3 |
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens: null # List of scenario tokens to include
6 |
7 | log_names: null # Filter scenarios by log names
8 | map_names: null # Filter scenarios by map names
9 |
10 | num_scenarios_per_type: null # Number of scenarios per type
11 | limit_total_scenarios: 1000000 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
17 |
18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios
19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
20 | shuffle: false # Whether to shuffle the scenarios
21 |
--------------------------------------------------------------------------------
/config/scenario_filter/training_scenarios_tiny.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: 'all'
3 |
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens: null # List of scenario tokens to include
6 |
7 | log_names: null # Filter scenarios by log names
8 | map_names: null # Filter scenarios by map names
9 |
10 | num_scenarios_per_type: null # Number of scenarios per type
11 | limit_total_scenarios: 50 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
17 |
18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios
19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
20 | shuffle: true # Whether to shuffle the scenarios
--------------------------------------------------------------------------------
/config/scenario_filter/val_demo_scenario.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: "all"
3 |
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens:
6 | - c556f4bc6f165a76
7 |
8 | log_names: # Filter scenarios by log names
9 | - 2021.07.24.00.36.59_veh-47_06810_07310
10 |
11 | map_names: null # Filter scenarios by map names
12 |
13 | num_scenarios_per_type: null # Number of scenarios per type
14 | limit_total_scenarios: null # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
15 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
16 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
17 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
18 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
19 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
20 |
21 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios
22 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
23 | shuffle: false # Whether to shuffle the scenarios
24 |
--------------------------------------------------------------------------------
/config/training/train_pluto.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | job_name: pluto
3 | py_func: train
4 | objective_aggregate_mode: mean
5 |
6 | defaults:
7 | - override /data_augmentation:
8 | - contrastive_scenario_generator
9 | - override /splitter: nuplan
10 | - override /model: pluto_model
11 | - override /scenario_filter: training_scenarios_tiny
12 | - override /custom_trainer: pluto_trainer
13 | - override /lightning: custom_lightning
14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm
2 | pytorch-lightning==2.0.1
3 | torchmetrics==0.10.2
4 | tensorboard
5 | wandb==0.14.2
6 | numba
7 | rich==13.3.4
--------------------------------------------------------------------------------
/run_simulation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import pprint
4 | from pathlib import Path
5 | from shutil import rmtree
6 | from typing import List, Optional, Union
7 |
8 | import hydra
9 | import pandas as pd
10 | import pytorch_lightning as pl
11 | from nuplan.common.utils.s3_utils import is_s3_path
12 | from nuplan.planning.script.builders.simulation_builder import build_simulations
13 | from nuplan.planning.script.builders.simulation_callback_builder import (
14 | build_callbacks_worker,
15 | build_simulation_callbacks,
16 | )
17 | from nuplan.planning.script.utils import (
18 | run_runners,
19 | set_default_path,
20 | set_up_common_builder,
21 | )
22 | from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner
23 | from omegaconf import DictConfig, OmegaConf
24 |
25 | logging.basicConfig(level=logging.INFO)
26 | logger = logging.getLogger(__name__)
27 |
28 | # If set, use the env. variable to overwrite the default dataset and experiment paths
29 | set_default_path()
30 |
31 | # If set, use the env. variable to overwrite the Hydra config
32 | CONFIG_PATH = os.getenv("NUPLAN_HYDRA_CONFIG_PATH", "config/simulation")
33 |
34 |
35 | def print_simulation_results(file=None):
36 | if file is not None:
37 | df = pd.read_parquet(file)
38 | else:
39 | root = Path(os.getcwd()) / "aggregator_metric"
40 | result = list(root.glob("*.parquet"))
41 | result = max(result, key=lambda item: item.stat().st_ctime)  # most recently created file
42 | df = pd.read_parquet(result)
43 | final_score = df[df["scenario"] == "final_score"]
44 | final_score = final_score.to_dict(orient="records")[0]
45 | pprint.PrettyPrinter(indent=4).pprint(final_score)
46 |
47 |
48 | def run_simulation(
49 | cfg: DictConfig,
50 | planners: Optional[Union[AbstractPlanner, List[AbstractPlanner]]] = None,
51 | ) -> None:
52 | """
53 | Execute all available challenges simultaneously on the same scenario. Helper function for main to allow planner to
54 | be specified via config or directly passed as argument.
55 | :param cfg: Configuration that is used to run the experiment.
56 | Already contains the changes merged from the experiment's config to default config.
57 | :param planners: Pre-built planner(s) to run in simulation. Can either be a single planner or list of planners.
58 | """
59 | # Fix random seed
60 | pl.seed_everything(cfg.seed, workers=True)
61 |
62 | profiler_name = "building_simulation"
63 | common_builder = set_up_common_builder(cfg=cfg, profiler_name=profiler_name)
64 |
65 | # Build simulation callbacks
66 | callbacks_worker_pool = build_callbacks_worker(cfg)
67 | callbacks = build_simulation_callbacks(
68 | cfg=cfg, output_dir=common_builder.output_dir, worker=callbacks_worker_pool
69 | )
70 |
71 | # Remove planner from config to make sure run_simulation does not receive multiple planner specifications.
72 | if planners and "planner" in cfg.keys():
73 | logger.info("Using pre-instantiated planner. Ignoring planner in config")
74 | OmegaConf.set_struct(cfg, False)
75 | cfg.pop("planner")
76 | OmegaConf.set_struct(cfg, True)
77 |
78 | # Construct simulations
79 | if isinstance(planners, AbstractPlanner):
80 | planners = [planners]
81 |
82 | runners = build_simulations(
83 | cfg=cfg,
84 | callbacks=callbacks,
85 | worker=common_builder.worker,
86 | pre_built_planners=planners,
87 | callbacks_worker=callbacks_worker_pool,
88 | )
89 |
90 | if common_builder.profiler:
91 | # Stop simulation construction profiling
92 | common_builder.profiler.save_profiler(profiler_name)
93 |
94 | logger.info("Running simulation...")
95 | run_runners(
96 | runners=runners,
97 | common_builder=common_builder,
98 | cfg=cfg,
99 | profiler_name="running_simulation",
100 | )
101 | logger.info("Finished running simulation!")
102 |
103 |
104 | def clean_up_s3_artifacts() -> None:
105 | """
106 | Cleanup lingering s3 artifacts that are written locally.
107 | This happens because some minor write-to-s3 functionality isn't yet implemented.
108 | """
109 | # Lingering artifacts get written locally to a 's3:' directory. Hydra changes
110 | # the working directory to a subdirectory of this, so we search the working
111 | # path for it.
112 | working_path = os.getcwd()
113 | s3_dirname = "s3:"
114 | s3_ind = working_path.find(s3_dirname)
115 | if s3_ind != -1:
116 | local_s3_path = working_path[: working_path.find(s3_dirname) + len(s3_dirname)]
117 | rmtree(local_s3_path)
118 |
119 |
120 | @hydra.main(config_path="./config", config_name="default_simulation")
121 | def main(cfg: DictConfig) -> None:
122 | """
123 | Execute all available challenges simultaneously on the same scenario. Calls run_simulation to allow planner to
124 | be specified via config or directly passed as argument.
125 | :param cfg: Configuration that is used to run the experiment.
126 | Already contains the changes merged from the experiment's config to default config.
127 | """
128 | assert (
129 | cfg.simulation_log_main_path is None
130 | ), "Simulation_log_main_path must not be set when running simulation."
131 |
132 | run_simulation(cfg=cfg)
133 |
134 | if is_s3_path(Path(cfg.output_dir)):
135 | clean_up_s3_artifacts()
136 |
137 | print_simulation_results()
138 |
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 |
4 | import hydra
5 | import numpy
6 | import pytorch_lightning as pl
7 | from nuplan.planning.script.builders.folder_builder import (
8 | build_training_experiment_folder,
9 | )
10 | from nuplan.planning.script.builders.logging_builder import build_logger
11 | from nuplan.planning.script.builders.worker_pool_builder import build_worker
12 | from nuplan.planning.script.profiler_context_manager import ProfilerContextManager
13 | from nuplan.planning.script.utils import set_default_path
14 | from nuplan.planning.training.experiments.caching import cache_data
15 | from omegaconf import DictConfig
16 |
17 | from src.custom_training import (
18 | TrainingEngine,
19 | build_training_engine,
20 | update_config_for_training,
21 | )
22 |
23 | logging.getLogger("numba").setLevel(logging.WARNING)
24 | logger = logging.getLogger(__name__)
25 |
26 | # If set, use the env. variable to overwrite the default dataset and experiment paths
27 | set_default_path()
28 |
29 | # Location of the Hydra config used by this script
30 | CONFIG_PATH = "./config"
31 | CONFIG_NAME = "default_training"
32 |
33 |
34 | @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
35 | def main(cfg: DictConfig) -> Optional[TrainingEngine]:
36 | """
37 | Main entrypoint for training/validation experiments.
38 | :param cfg: omegaconf dictionary
39 | """
40 | pl.seed_everything(cfg.seed, workers=True)
41 |
42 | # Configure logger
43 | build_logger(cfg)
44 |
45 | # Override configs based on setup, and print config
46 | update_config_for_training(cfg)
47 |
48 | # Create output storage folder
49 | build_training_experiment_folder(cfg=cfg)
50 |
51 | # Build worker
52 | worker = build_worker(cfg)
53 |
54 | if cfg.py_func == "train":
55 | # Build training engine
56 | with ProfilerContextManager(
57 | cfg.output_dir, cfg.enable_profiling, "build_training_engine"
58 | ):
59 | engine = build_training_engine(cfg, worker)
60 |
61 | # Run training
62 | logger.info("Starting training...")
63 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "training"):
64 | engine.trainer.fit(
65 | model=engine.model,
66 | datamodule=engine.datamodule,
67 | ckpt_path=cfg.checkpoint,
68 | )
69 | return engine
70 | elif cfg.py_func == "validate":
71 | # Build training engine
72 | with ProfilerContextManager(
73 | cfg.output_dir, cfg.enable_profiling, "build_training_engine"
74 | ):
75 | engine = build_training_engine(cfg, worker)
76 |
77 | # Run validation
78 | logger.info("Starting validation...")
79 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "validate"):
80 | engine.trainer.validate(
81 | model=engine.model,
82 | datamodule=engine.datamodule,
83 | ckpt_path=cfg.checkpoint,
84 | )
85 | return engine
86 | elif cfg.py_func == "test":
87 | # Build training engine
88 | with ProfilerContextManager(
89 | cfg.output_dir, cfg.enable_profiling, "build_training_engine"
90 | ):
91 | engine = build_training_engine(cfg, worker)
92 |
93 | # Test model
94 | logger.info("Starting testing...")
95 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "testing"):
96 | engine.trainer.test(model=engine.model, datamodule=engine.datamodule)
97 | return engine
98 | elif cfg.py_func == "cache":
99 | # Precompute and cache all features
100 | logger.info("Starting caching...")
101 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "caching"):
102 | cache_data(cfg=cfg, worker=worker)
103 | return None
104 | else:
105 | raise NameError(f"Function {cfg.py_func} does not exist")
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
--------------------------------------------------------------------------------
/script/run_pluto_planner.sh:
--------------------------------------------------------------------------------
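# Usage:   sh ./script/run_pluto_planner.sh <planner> <scenario_builder> <scenario_filter> <ckpt_filename> <video_save_dir>
# Example: sh ./script/run_pluto_planner.sh pluto_planner nuplan_mini mini_demo_scenario pluto_1M_aux_cil.ckpt /tmp/pluto_videos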
1 | cwd=$(pwd)
2 | CKPT_ROOT="$cwd/checkpoints"
3 |
4 | PLANNER=$1
5 | BUILDER=$2
6 | FILTER=$3
7 | CKPT=$4
8 | VIDEO_SAVE_DIR=$5
9 |
10 | CHALLENGE="closed_loop_nonreactive_agents"
11 | # CHALLENGE="closed_loop_reactive_agents"
12 | # CHALLENGE="open_loop_boxes"
13 |
14 | python run_simulation.py \
15 | +simulation=$CHALLENGE \
16 | planner=$PLANNER \
17 | scenario_builder=$BUILDER \
18 | scenario_filter=$FILTER \
19 | worker=sequential \
20 | verbose=true \
21 | experiment_uid="pluto_planner/$FILTER" \
22 | planner.pluto_planner.render=true \
23 | planner.pluto_planner.planner_ckpt="$CKPT_ROOT/$CKPT" \
24 | +planner.pluto_planner.save_dir=$VIDEO_SAVE_DIR
25 |
26 |
--------------------------------------------------------------------------------
/script/setup_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118
2 | pip3 install natten==0.14.6 -f https://shi-labs.com/natten/wheels/cu118/torch2.0.0/index.html
3 | pip install -r ./requirements.txt
4 |
--------------------------------------------------------------------------------
/src/custom_training/custom_training_builder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from shutil import rmtree
6 | from typing import cast
7 |
8 | import pytorch_lightning as pl
9 | from hydra.utils import instantiate
10 | from nuplan.planning.script.builders.data_augmentation_builder import (
11 | build_agent_augmentor,
12 | )
13 | from nuplan.planning.script.builders.model_builder import build_torch_module_wrapper
14 | from nuplan.planning.script.builders.objectives_builder import build_objectives
15 | from nuplan.planning.script.builders.scenario_builder import build_scenarios
16 | from nuplan.planning.script.builders.splitter_builder import build_splitter
17 | from nuplan.planning.script.builders.training_metrics_builder import (
18 | build_training_metrics,
19 | )
20 | from nuplan.planning.training.modeling.lightning_module_wrapper import (
21 | LightningModuleWrapper,
22 | )
23 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper
24 | from nuplan.planning.training.preprocessing.feature_preprocessor import (
25 | FeaturePreprocessor,
26 | )
27 | from nuplan.planning.utils.multithreading.worker_pool import WorkerPool
28 | from omegaconf import DictConfig, OmegaConf
29 | from pytorch_lightning.callbacks import (
30 | LearningRateMonitor,
31 | ModelCheckpoint,
32 | RichModelSummary,
33 | RichProgressBar,
34 | )
35 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
36 | from pytorch_lightning.loggers.wandb import WandbLogger
37 |
38 | from .custom_datamodule import CustomDataModule
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 |
43 | def update_config_for_training(cfg: DictConfig) -> None:
44 | """
45 | Updates the config based on some conditions.
46 | :param cfg: omegaconf dictionary that is used to run the experiment.
47 | """
48 | # Make the configuration editable.
49 | OmegaConf.set_struct(cfg, False)
50 |
51 | if cfg.cache.cache_path is None:
52 | logger.warning("Parameter cache_path is not set, caching is disabled")
53 | else:
54 | if not str(cfg.cache.cache_path).startswith("s3://"):
55 | if cfg.cache.cleanup_cache and Path(cfg.cache.cache_path).exists():
56 | rmtree(cfg.cache.cache_path)
57 |
58 | Path(cfg.cache.cache_path).mkdir(parents=True, exist_ok=True)
59 |
60 | if cfg.lightning.trainer.overfitting.enable:
61 | cfg.data_loader.params.num_workers = 0
62 |
63 | OmegaConf.resolve(cfg)
64 |
65 | # Finalize the configuration and make it non-editable.
66 | OmegaConf.set_struct(cfg, True)
67 |
68 | # Log the final configuration after all overrides, interpolations and updates.
69 | if cfg.log_config:
70 | logger.info(
71 | f"Creating experiment name [{cfg.experiment}] in group [{cfg.group}] with config..."
72 | )
73 | logger.info("\n" + OmegaConf.to_yaml(cfg))
74 |
75 |
76 | @dataclass(frozen=True)
77 | class TrainingEngine:
78 | """Lightning training engine dataclass wrapping the lightning trainer, model and datamodule."""
79 |
80 | trainer: pl.Trainer # Trainer for models
81 | model: pl.LightningModule # Module describing NN model, loss, metrics, visualization
82 | datamodule: pl.LightningDataModule # Loading data
83 |
84 | def __repr__(self) -> str:
85 | """
86 | :return: String representation of class without expanding the fields.
87 | """
88 | return f"<{type(self).__module__}.{type(self).__qualname__} object at {hex(id(self))}>"
89 |
90 |
91 | def build_lightning_datamodule(
92 | cfg: DictConfig, worker: WorkerPool, model: TorchModuleWrapper
93 | ) -> pl.LightningDataModule:
94 | """
95 | Build the lightning datamodule from the config.
96 | :param cfg: Omegaconf dictionary.
97 | :param worker: Worker to submit tasks which can be executed in parallel.
98 | :param model: NN model used for training.
99 | :return: Instantiated datamodule object.
100 | """
101 | # Build features and targets
102 | feature_builders = model.get_list_of_required_feature()
103 | target_builders = model.get_list_of_computed_target()
104 |
105 | # Build splitter
106 | splitter = build_splitter(cfg.splitter)
107 |
108 | # Create feature preprocessor
109 | feature_preprocessor = FeaturePreprocessor(
110 | cache_path=cfg.cache.cache_path,
111 | force_feature_computation=cfg.cache.force_feature_computation,
112 | feature_builders=feature_builders,
113 | target_builders=target_builders,
114 | )
115 |
116 | # Create data augmentation
117 | augmentors = (
118 | build_agent_augmentor(cfg.data_augmentation)
119 | if "data_augmentation" in cfg
120 | else None
121 | )
122 |
123 | # Build dataset scenarios
124 | scenarios = build_scenarios(cfg, worker, model)
125 |
126 | # Create datamodule
127 | datamodule: pl.LightningDataModule = CustomDataModule(
128 | feature_preprocessor=feature_preprocessor,
129 | splitter=splitter,
130 | all_scenarios=scenarios,
131 | dataloader_params=cfg.data_loader.params,
132 | augmentors=augmentors,
133 | worker=worker,
134 | scenario_type_sampling_weights=cfg.scenario_type_weights.scenario_type_sampling_weights,
135 | **cfg.data_loader.datamodule,
136 | )
137 |
138 | return datamodule
139 |
140 |
141 | def build_lightning_module(
142 | cfg: DictConfig, torch_module_wrapper: TorchModuleWrapper
143 | ) -> pl.LightningModule:
144 | """
145 | Builds the lightning module from the config.
146 | :param cfg: omegaconf dictionary
147 | :param torch_module_wrapper: NN model used for training
148 | :return: built object.
149 | """
150 | # Create the complete Module
151 | if "custom_trainer" in cfg:
152 | model = instantiate(
153 | cfg.custom_trainer,
154 | model=torch_module_wrapper,
155 | lr=cfg.lr,
156 | weight_decay=cfg.weight_decay,
157 | epochs=cfg.epochs,
158 | warmup_epochs=cfg.warmup_epochs,
159 | )
160 | else:
161 | objectives = build_objectives(cfg)
162 | metrics = build_training_metrics(cfg)
163 | model = LightningModuleWrapper(
164 | model=torch_module_wrapper,
165 | objectives=objectives,
166 | metrics=metrics,
167 | batch_size=cfg.data_loader.params.batch_size,
168 | optimizer=cfg.optimizer,
169 | lr_scheduler=cfg.lr_scheduler if "lr_scheduler" in cfg else None,
170 | warm_up_lr_scheduler=cfg.warm_up_lr_scheduler
171 | if "warm_up_lr_scheduler" in cfg
172 | else None,
173 | objective_aggregate_mode=cfg.objective_aggregate_mode,
174 | )
175 |
176 | return cast(pl.LightningModule, model)
177 |
178 |
179 | def build_custom_trainer(cfg: DictConfig) -> pl.Trainer:
180 | """
181 | Builds the lightning trainer from the config.
182 | :param cfg: omegaconf dictionary
183 | :return: built object.
184 | """
185 | params = cfg.lightning.trainer.params
186 |
187 | # callbacks = build_callbacks(cfg)
188 | callbacks = [
189 | ModelCheckpoint(
190 | dirpath=os.path.join(os.getcwd(), "checkpoints"),
191 | filename="{epoch}-{val_minFDE:.3f}",
192 | monitor=cfg.lightning.trainer.checkpoint.monitor,
193 | mode=cfg.lightning.trainer.checkpoint.mode,
194 | save_top_k=cfg.lightning.trainer.checkpoint.save_top_k,
195 | save_last=True,
196 | ),
197 | RichModelSummary(max_depth=1),
198 | RichProgressBar(),
199 | LearningRateMonitor(logging_interval="epoch"),
200 | ]
201 |
202 | if cfg.wandb.mode == "disable":
203 | training_logger = TensorBoardLogger(
204 | save_dir=cfg.group,
205 | name=cfg.experiment,
206 | log_graph=False,
207 | version="",
208 | prefix="",
209 | )
210 | else:
211 | if cfg.wandb.artifact is not None:
212 | os.system(f"wandb artifact get {cfg.wandb.artifact}")
213 | _, _, artifact = cfg.wandb.artifact.split("/")
214 | checkpoint = os.path.join(os.getcwd(), f"artifacts/{artifact}/model.ckpt")
215 | run_id = artifact.split(":")[0][-8:]
216 | cfg.checkpoint = checkpoint
217 | cfg.wandb.run_id = run_id
218 |
219 | training_logger = WandbLogger(
220 | save_dir=cfg.group,
221 | project=cfg.wandb.project,
222 | name=cfg.wandb.name,
223 | mode=cfg.wandb.mode,
224 | log_model=cfg.wandb.log_model,
225 | resume=cfg.checkpoint is not None,
226 | id=cfg.wandb.run_id,
227 | )
228 |
229 | trainer = pl.Trainer(
230 | callbacks=callbacks,
231 | logger=training_logger,
232 | **params,
233 | )
234 |
235 | return trainer
236 |
237 |
238 | def build_training_engine(cfg: DictConfig, worker: WorkerPool) -> TrainingEngine:
239 | """
240 | Build the three core Lightning components: LightningDataModule, LightningModule and Trainer
241 | :param cfg: omegaconf dictionary
242 | :param worker: Worker to submit tasks which can be executed in parallel
243 | :return: TrainingEngine
244 | """
245 | logger.info("Building training engine...")
246 |
247 | trainer = build_custom_trainer(cfg)
248 |
249 | # Create model
250 | torch_module_wrapper = build_torch_module_wrapper(cfg.model)
251 |
252 | # Build the datamodule
253 | datamodule = build_lightning_datamodule(cfg, worker, torch_module_wrapper)
254 |
255 | # Build lightning module
256 | model = build_lightning_module(cfg, torch_module_wrapper)
257 |
258 | engine = TrainingEngine(trainer=trainer, datamodule=datamodule, model=model)
259 |
260 | return engine
--------------------------------------------------------------------------------
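Usage note: a minimal sketch of how `build_training_engine` is driven from a Hydra entry
point such as run_training.py (the config name matches config/default_training.yaml; the
nuPlan `build_worker` import is an assumption, not part of this file):

    import hydra
    from nuplan.planning.script.builders.worker_pool_builder import build_worker
    from omegaconf import DictConfig

    @hydra.main(config_path="./config", config_name="default_training")
    def main(cfg: DictConfig) -> None:
        worker = build_worker(cfg)  # parallel worker pool used to build scenarios
        engine = build_training_engine(cfg, worker)  # defined above
        engine.trainer.fit(model=engine.model, datamodule=engine.datamodule)

    if __name__ == "__main__":
        main()
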
/src/feature_builders/common.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from enum import IntEnum
4 | from typing import Any, Dict, List, Set, Tuple, Type, Union
5 |
6 | import cv2
7 | import numba
8 | import numpy as np
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.actor_state.state_representation import Point2D
11 | from nuplan.common.actor_state.tracked_objects import TrackedObjects
12 | from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType
13 | from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters
14 | from nuplan.common.maps.abstract_map import AbstractMap, MapObject
15 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer, TrafficLightStatusType
16 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
17 |
18 | EGO_PARAMS = get_pacifica_parameters()
19 | LANE_LAYERS = [SemanticMapLayer.LANE, SemanticMapLayer.LANE_CONNECTOR]
20 |
21 | HALF_WIDTH = EGO_PARAMS.half_width
22 | FRONT_LENGTH = EGO_PARAMS.front_length
23 | REAR_LENGTH = EGO_PARAMS.rear_length
24 |
25 |
26 | class PolylineElements(IntEnum):
27 | """
28 | Enum for PolylineElements.
29 | """
30 |
31 | LANE = 0
32 | BOUNDARY = 1
33 | STOP_LINE = 2
34 | CROSSWALK = 3
35 |
36 | @classmethod
37 | def deserialize(cls, layer: str) -> PolylineElements:
38 | """Deserialize the type when loading from a string."""
39 | return PolylineElements.__members__[layer]
40 |
41 |
42 | GlobalTypeMapping = {
43 | "AV": 0,
44 | TrackedObjectType.VEHICLE: 1,
45 | TrackedObjectType.PEDESTRIAN: 2,
46 | TrackedObjectType.BICYCLE: 3,
47 | PolylineElements.LANE: 4,
48 | PolylineElements.BOUNDARY: 5,
49 | }
50 |
51 |
52 | def interpolate_polyline(points: np.ndarray, t: int) -> np.ndarray:
53 | """copy from av2-api"""
54 |
55 | if points.ndim != 2:
56 | raise ValueError("Input array must be (N,2) or (N,3) in shape.")
57 |
58 | # the number of points on the curve itself
59 | n, _ = points.shape
60 |
61 | # equally spaced in arclength -- the number of points that will be uniformly interpolated
62 | eq_spaced_points = np.linspace(0, 1, t)
63 |
64 | # Compute the chordal arclength of each segment.
65 | # Compute differences between each x coord, to get the dx's
66 | # Do the same to get dy's. Then the hypotenuse length is computed as a norm.
67 | chordlen: np.ndarray = np.linalg.norm(np.diff(points, axis=0), axis=1) # type: ignore
68 | # Normalize the arclengths to a unit total
69 | chordlen = chordlen / np.sum(chordlen)
70 | # cumulative arclength
71 |
72 | cumarc: np.ndarray = np.zeros(len(chordlen) + 1)
73 | cumarc[1:] = np.cumsum(chordlen)
74 |
75 | # which interval did each point fall in, in terms of eq_spaced_points? (bin index)
76 | tbins: np.ndarray = np.digitize(eq_spaced_points, bins=cumarc).astype(int) # type: ignore
77 |
78 | # catch any problems at the ends
79 | tbins[np.where((tbins <= 0) | (eq_spaced_points <= 0))] = 1 # type: ignore
80 | tbins[np.where((tbins >= n) | (eq_spaced_points >= 1))] = n - 1
81 |
82 | chordlen[tbins - 1] = np.where(
83 | chordlen[tbins - 1] == 0, chordlen[tbins - 1] + 1e-6, chordlen[tbins - 1]
84 | )
85 |
86 | s = np.divide((eq_spaced_points - cumarc[tbins - 1]), chordlen[tbins - 1])
87 | anchors = points[tbins - 1, :]
88 | # broadcast to scale each row of `points` by a different row of s
89 | offsets = (points[tbins, :] - points[tbins - 1, :]) * s.reshape(-1, 1)
90 | points_interp: np.ndarray = anchors + offsets
91 |
92 | return points_interp
93 |
94 |
95 | def get_ego_corners(rear_axle_xy: np.ndarray, heading: np.ndarray):
96 | """
97 | rear_axle_xy: [T, 2] rear-axle xy positions; heading: [T] yaw angles
98 | """
99 | ego_corners_offset = np.array(
100 | [
101 | [-REAR_LENGTH, -HALF_WIDTH],
102 | [-REAR_LENGTH, HALF_WIDTH],
103 | [FRONT_LENGTH, HALF_WIDTH],
104 | [FRONT_LENGTH, -HALF_WIDTH],
105 | ],
106 | dtype=np.float64,
107 | )
108 | ego_corners = rear_axle_xy[..., None, :] + ego_corners_offset[None, ...]
109 | rotate_mat = np.zeros((len(heading), 2, 2), dtype=np.float64)
110 | rotate_mat[:, 0, 0] = np.cos(heading)
111 | rotate_mat[:, 0, 1] = np.sin(heading)
112 | rotate_mat[:, 1, 0] = -np.sin(heading)
113 | rotate_mat[:, 1, 1] = np.cos(heading)
114 |
115 | ego_corners = ego_corners @ rotate_mat
116 | return ego_corners
117 |
118 |
119 | @numba.njit
120 | def rotate_round_z_axis(points: np.ndarray, angle: float):
121 |     """For row-vector points this applies a rotation of -angle (global -> local frame)."""
122 |     rotate_mat = np.array(
123 |         [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
124 |         # dtype kwarg left out, presumably for numba nopython-mode compatibility
125 |     )
126 |     return points @ rotate_mat
127 |
--------------------------------------------------------------------------------
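Worked example for `interpolate_polyline` (values computed by hand): resampling is
uniform in arclength, so uneven input spacing is evened out.

    import numpy as np

    polyline = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]])  # uneven spacing
    resampled = interpolate_polyline(polyline, t=5)
    # total arclength is 3.0, so points land every 0.75 along the curve:
    # [[0.0, 0.0], [0.75, 0.0], [1.5, 0.0], [2.25, 0.0], [3.0, 0.0]]
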
/src/features/pluto_feature.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import Any, Dict, List, Optional
5 |
6 | import numpy as np
7 | import torch
8 | from nuplan.planning.training.preprocessing.features.abstract_model_feature import (
9 | AbstractModelFeature,
10 | )
11 | from torch.nn.utils.rnn import pad_sequence
12 |
13 | from src.utils.utils import to_device, to_numpy, to_tensor
14 |
15 |
16 | @dataclass
17 | class PlutoFeature(AbstractModelFeature):
18 | data: Dict[str, Any] # anchor sample
19 |     data_p: Optional[Dict[str, Any]] = None  # positive sample
20 |     data_n: Optional[Dict[str, Any]] = None  # negative sample
21 |     data_n_info: Optional[Dict[str, Any]] = None  # negative sample info
22 |
23 | @classmethod
24 | def collate(cls, feature_list: List[PlutoFeature]) -> PlutoFeature:
25 | batch_data = {}
26 |
27 | pad_keys = ["agent", "map"]
28 | stack_keys = ["current_state", "origin", "angle"]
29 |
30 | if "reference_line" in feature_list[0].data:
31 | pad_keys.append("reference_line")
32 | if "static_objects" in feature_list[0].data:
33 | pad_keys.append("static_objects")
34 | if "cost_maps" in feature_list[0].data:
35 | stack_keys.append("cost_maps")
36 |
37 | if feature_list[0].data_n is not None:
38 | for key in pad_keys:
39 | batch_data[key] = {
40 | k: pad_sequence(
41 | [f.data[key][k] for f in feature_list]
42 | + [f.data_p[key][k] for f in feature_list]
43 | + [f.data_n[key][k] for f in feature_list],
44 | batch_first=True,
45 | )
46 | for k in feature_list[0].data[key].keys()
47 | }
48 |
49 | batch_data["data_n_valid_mask"] = torch.Tensor(
50 | [f.data_n_info["valid_mask"] for f in feature_list]
51 | ).bool()
52 | batch_data["data_n_type"] = torch.Tensor(
53 | [f.data_n_info["type"] for f in feature_list]
54 | ).long()
55 |
56 | for key in stack_keys:
57 | batch_data[key] = torch.stack(
58 | [f.data[key] for f in feature_list]
59 | + [f.data_p[key] for f in feature_list]
60 | + [f.data_n[key] for f in feature_list],
61 | dim=0,
62 | )
63 | elif feature_list[0].data_p is not None:
64 | for key in pad_keys:
65 | batch_data[key] = {
66 | k: pad_sequence(
67 | [f.data[key][k] for f in feature_list]
68 | + [f.data_p[key][k] for f in feature_list],
69 | batch_first=True,
70 | )
71 | for k in feature_list[0].data[key].keys()
72 | }
73 |
74 | for key in stack_keys:
75 | batch_data[key] = torch.stack(
76 | [f.data[key] for f in feature_list]
77 | + [f.data_p[key] for f in feature_list],
78 | dim=0,
79 | )
80 | else:
81 | for key in pad_keys:
82 | batch_data[key] = {
83 | k: pad_sequence(
84 | [f.data[key][k] for f in feature_list], batch_first=True
85 | )
86 | for k in feature_list[0].data[key].keys()
87 | }
88 |
89 | for key in stack_keys:
90 | batch_data[key] = torch.stack(
91 | [f.data[key] for f in feature_list], dim=0
92 | )
93 |
94 | return PlutoFeature(data=batch_data)
95 |
96 | def to_feature_tensor(self) -> PlutoFeature:
97 | new_data = {}
98 | for k, v in self.data.items():
99 | new_data[k] = to_tensor(v)
100 |
101 | if self.data_p is not None:
102 | new_data_p = {}
103 | for k, v in self.data_p.items():
104 | new_data_p[k] = to_tensor(v)
105 | else:
106 | new_data_p = None
107 |
108 | if self.data_n is not None:
109 | new_data_n = {}
110 | new_data_n_info = {}
111 | for k, v in self.data_n.items():
112 | new_data_n[k] = to_tensor(v)
113 | for k, v in self.data_n_info.items():
114 | new_data_n_info[k] = to_tensor(v)
115 | else:
116 | new_data_n = None
117 | new_data_n_info = None
118 |
119 | return PlutoFeature(
120 | data=new_data,
121 | data_p=new_data_p,
122 | data_n=new_data_n,
123 | data_n_info=new_data_n_info,
124 | )
125 |
126 | def to_numpy(self) -> PlutoFeature:
127 | new_data = {}
128 | for k, v in self.data.items():
129 | new_data[k] = to_numpy(v)
130 | if self.data_p is not None:
131 | new_data_p = {}
132 | for k, v in self.data_p.items():
133 | new_data_p[k] = to_numpy(v)
134 | else:
135 | new_data_p = None
136 | if self.data_n is not None:
137 | new_data_n = {}
138 | for k, v in self.data_n.items():
139 | new_data_n[k] = to_numpy(v)
140 | else:
141 | new_data_n = None
142 | return PlutoFeature(data=new_data, data_p=new_data_p, data_n=new_data_n)
143 |
144 | def to_device(self, device: torch.device) -> PlutoFeature:
145 | new_data = {}
146 | for k, v in self.data.items():
147 | new_data[k] = to_device(v, device)
148 | return PlutoFeature(data=new_data)
149 |
150 | def serialize(self) -> Dict[str, Any]:
151 | return {"data": self.data}
152 |
153 | @classmethod
154 | def deserialize(cls, data: Dict[str, Any]) -> PlutoFeature:
155 | return PlutoFeature(data=data["data"])
156 |
157 | def unpack(self) -> List[AbstractModelFeature]:
158 | raise NotImplementedError
159 |
160 | @property
161 | def is_valid(self) -> bool:
162 | if "reference_line" in self.data:
163 | return self.data["reference_line"]["valid_mask"].any()
164 | else:
165 | return self.data["map"]["point_position"].shape[0] > 0
166 |
167 | @classmethod
168 | def normalize(
169 |         cls, data, first_time=False, radius=None, hist_steps=21
170 | ) -> PlutoFeature:
171 | cur_state = data["current_state"]
172 | center_xy, center_angle = cur_state[:2].copy(), cur_state[2].copy()
173 |
174 | rotate_mat = np.array(
175 | [
176 | [np.cos(center_angle), -np.sin(center_angle)],
177 | [np.sin(center_angle), np.cos(center_angle)],
178 | ],
179 | dtype=np.float64,
180 | )
181 |
182 | data["current_state"][:3] = 0
183 | data["agent"]["position"] = np.matmul(
184 | data["agent"]["position"] - center_xy, rotate_mat
185 | )
186 | data["agent"]["velocity"] = np.matmul(data["agent"]["velocity"], rotate_mat)
187 | data["agent"]["heading"] -= center_angle
188 |
189 | data["map"]["point_position"] = np.matmul(
190 | data["map"]["point_position"] - center_xy, rotate_mat
191 | )
192 | data["map"]["point_vector"] = np.matmul(data["map"]["point_vector"], rotate_mat)
193 | data["map"]["point_orientation"] -= center_angle
194 |
195 | data["map"]["polygon_center"][..., :2] = np.matmul(
196 | data["map"]["polygon_center"][..., :2] - center_xy, rotate_mat
197 | )
198 | data["map"]["polygon_center"][..., 2] -= center_angle
199 | data["map"]["polygon_position"] = np.matmul(
200 | data["map"]["polygon_position"] - center_xy, rotate_mat
201 | )
202 | data["map"]["polygon_orientation"] -= center_angle
203 |
204 | if "causal" in data:
205 | if len(data["causal"]["free_path_points"]) > 0:
206 | data["causal"]["free_path_points"][..., :2] = np.matmul(
207 | data["causal"]["free_path_points"][..., :2] - center_xy, rotate_mat
208 | )
209 | data["causal"]["free_path_points"][..., 2] -= center_angle
210 | if "static_objects" in data:
211 | data["static_objects"]["position"] = np.matmul(
212 | data["static_objects"]["position"] - center_xy, rotate_mat
213 | )
214 | data["static_objects"]["heading"] -= center_angle
215 | if "route" in data:
216 | data["route"]["position"] = np.matmul(
217 | data["route"]["position"] - center_xy, rotate_mat
218 | )
219 | if "reference_line" in data:
220 | data["reference_line"]["position"] = np.matmul(
221 | data["reference_line"]["position"] - center_xy, rotate_mat
222 | )
223 | data["reference_line"]["vector"] = np.matmul(
224 | data["reference_line"]["vector"], rotate_mat
225 | )
226 | data["reference_line"]["orientation"] -= center_angle
227 |
228 | target_position = (
229 | data["agent"]["position"][:, hist_steps:]
230 | - data["agent"]["position"][:, hist_steps - 1][:, None]
231 | )
232 | target_heading = (
233 | data["agent"]["heading"][:, hist_steps:]
234 | - data["agent"]["heading"][:, hist_steps - 1][:, None]
235 | )
236 | target = np.concatenate([target_position, target_heading[..., None]], -1)
237 | target[~data["agent"]["valid_mask"][:, hist_steps:]] = 0
238 | data["agent"]["target"] = target
239 |
240 | if first_time:
241 | point_position = data["map"]["point_position"]
242 | x_max, x_min = radius, -radius
243 | y_max, y_min = radius, -radius
244 | valid_mask = (
245 | (point_position[:, 0, :, 0] < x_max)
246 | & (point_position[:, 0, :, 0] > x_min)
247 | & (point_position[:, 0, :, 1] < y_max)
248 | & (point_position[:, 0, :, 1] > y_min)
249 | )
250 | valid_polygon = valid_mask.any(-1)
251 | data["map"]["valid_mask"] = valid_mask
252 |
253 | for k, v in data["map"].items():
254 | data["map"][k] = v[valid_polygon]
255 |
256 | if "causal" in data:
257 | data["causal"]["ego_care_red_light_mask"] = data["causal"][
258 | "ego_care_red_light_mask"
259 | ][valid_polygon]
260 |
261 | data["origin"] = center_xy
262 | data["angle"] = center_angle
263 |
264 | return PlutoFeature(data=data)
265 |
--------------------------------------------------------------------------------
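A minimal sketch of the `collate` contract above (shapes are illustrative; only the
required "agent"/"map" sub-dicts and the stacked keys are populated):

    import torch

    def toy_feature(num_agents: int) -> PlutoFeature:
        return PlutoFeature(
            data={
                "agent": {"position": torch.zeros(num_agents, 101, 2)},
                "map": {"point_position": torch.zeros(4, 3, 20, 2)},
                "current_state": torch.zeros(7),
                "origin": torch.zeros(2),
                "angle": torch.zeros(()),
            }
        )

    batch = PlutoFeature.collate([toy_feature(3), toy_feature(5)])
    # ragged agent counts are padded to the max along dim 1:
    # batch.data["agent"]["position"].shape == (2, 5, 101, 2)
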
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .min_ade import minADE
2 | from .min_fde import minFDE
3 | from .mr import MR
--------------------------------------------------------------------------------
/src/metrics/min_ade.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 | from .utils import sort_predictions
7 |
8 |
9 | class minADE(Metric):
10 | """Minimum Average Displacement Error
11 | minADE: The average L2 distance between the best forecasted trajectory and the ground truth.
12 |     The best here refers to the trajectory with the minimum average displacement error among the top-k most probable modes.
13 | """
14 |
15 | full_state_update: Optional[bool] = False
16 | higher_is_better: Optional[bool] = False
17 |
18 | def __init__(
19 | self,
20 | k=6,
21 | compute_on_step: bool = True,
22 | dist_sync_on_step: bool = False,
23 | process_group: Optional[Any] = None,
24 |         dist_sync_fn: Optional[Callable] = None,
25 | ) -> None:
26 | super(minADE, self).__init__(
27 | compute_on_step=compute_on_step,
28 | dist_sync_on_step=dist_sync_on_step,
29 | process_group=process_group,
30 | dist_sync_fn=dist_sync_fn,
31 | )
32 | self.k = k
33 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
34 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
35 |
36 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
37 | with torch.no_grad():
38 | pred, _ = sort_predictions(
39 | outputs["trajectory"], outputs["probability"], k=self.k
40 | )
41 | ade = torch.norm(
42 | pred[..., :2] - target.unsqueeze(1)[..., :2], p=2, dim=-1
43 | ).mean(-1)
44 | min_ade = ade.min(-1)[0]
45 | self.sum += min_ade.sum()
46 | self.count += pred.size(0)
47 |
48 | def compute(self) -> torch.Tensor:
49 | return self.sum / self.count
50 |
--------------------------------------------------------------------------------
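Quick sanity check for `minADE` (shapes follow the [batch, modes, time, xy] convention
of `sort_predictions`; values are random):

    import torch

    metric = minADE(k=6)
    outputs = {
        "trajectory": torch.randn(8, 12, 80, 2),  # 12 candidate modes
        "probability": torch.rand(8, 12),
    }
    target = torch.randn(8, 80, 2)
    metric.update(outputs, target)
    print(metric.compute())  # batch mean of the best-of-6 average displacement
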
/src/metrics/min_fde.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 | from .utils import sort_predictions
7 |
8 |
9 | class minFDE(Metric):
10 | full_state_update: Optional[bool] = False
11 | higher_is_better: Optional[bool] = False
12 |
13 | def __init__(
14 | self,
15 | k=6,
16 | compute_on_step: bool = True,
17 | dist_sync_on_step: bool = False,
18 | process_group: Optional[Any] = None,
19 |         dist_sync_fn: Optional[Callable] = None,
20 | ) -> None:
21 | super(minFDE, self).__init__(
22 | compute_on_step=compute_on_step,
23 | dist_sync_on_step=dist_sync_on_step,
24 | process_group=process_group,
25 | dist_sync_fn=dist_sync_fn,
26 | )
27 | self.k = k
28 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
29 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
30 |
31 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
32 | with torch.no_grad():
33 | pred, _ = sort_predictions(
34 | outputs["trajectory"], outputs["probability"], k=self.k
35 | )
36 | fde = torch.norm(
37 | pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1
38 | )
39 | min_fde = fde.min(-1)[0]
40 | self.sum += min_fde.sum()
41 | self.count += pred.shape[0]
42 |
43 | def compute(self) -> torch.Tensor:
44 | return self.sum / self.count
45 |
--------------------------------------------------------------------------------
/src/metrics/mr.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class MR(Metric):
8 | full_state_update: Optional[bool] = False
9 | higher_is_better: Optional[bool] = False
10 |
11 | def __init__(
12 | self,
13 | miss_threshold: float = 2.0,
14 | compute_on_step: bool = True,
15 | dist_sync_on_step: bool = False,
16 | process_group: Optional[Any] = None,
17 |         dist_sync_fn: Optional[Callable] = None,
18 | ) -> None:
19 | super(MR, self).__init__(
20 | compute_on_step=compute_on_step,
21 | dist_sync_on_step=dist_sync_on_step,
22 | process_group=process_group,
23 | dist_sync_fn=dist_sync_fn,
24 | )
25 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
27 | self.miss_threshold = miss_threshold
28 |
29 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
30 | with torch.no_grad():
31 | pred = outputs["trajectory"]
32 | missed_pred = (
33 | torch.norm(
34 | pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1
35 | )
36 | > self.miss_threshold
37 | )
38 | self.sum += missed_pred.all(-1).sum()
39 | self.count += pred.shape[0]
40 |
41 | def compute(self) -> torch.Tensor:
42 | return self.sum / self.count
43 |
--------------------------------------------------------------------------------
/src/metrics/prediction_avg_ade.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Dict
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 |
8 | class PredAvgADE(Metric):
9 | full_state_update: Optional[bool] = False
10 | higher_is_better: Optional[bool] = False
11 |
12 | def __init__(
13 | self,
14 | compute_on_step: bool = True,
15 | dist_sync_on_step: bool = False,
16 | process_group: Optional[Any] = None,
17 |         dist_sync_fn: Optional[Callable] = None,
18 | ) -> None:
19 | super(PredAvgADE, self).__init__(
20 | compute_on_step=compute_on_step,
21 | dist_sync_on_step=dist_sync_on_step,
22 | process_group=process_group,
23 | dist_sync_fn=dist_sync_fn,
24 | )
25 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
27 |
28 | def update(
29 | self,
30 | outputs: Dict[str, torch.Tensor],
31 | target: torch.Tensor,
32 | ) -> None:
33 | """
34 |         outputs["prediction"]: [..., A, T, 2]; outputs["prediction_target"]: [..., A, T, >=2] (only xy is used)
35 |         outputs["valid_mask"]: [..., A, T]; the `target` argument is ignored
36 | """
37 | with torch.no_grad():
38 | prediction, valid_mask = outputs["prediction"], outputs["valid_mask"]
39 | target = outputs["prediction_target"]
40 | ade = (
41 | torch.norm(prediction - target[..., :2], p=2, dim=-1) * valid_mask
42 | ).sum(-1) / (valid_mask.sum(-1) + 1e-6)
43 |
44 | self.sum += ade.sum()
45 | self.count += valid_mask.any(-1).sum().item()
46 |
47 | def compute(self) -> torch.Tensor:
48 | return self.sum / self.count
49 |
--------------------------------------------------------------------------------
/src/metrics/prediction_avg_fde.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Dict
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 |
8 | class PredAvgFDE(Metric):
9 | full_state_update: Optional[bool] = False
10 | higher_is_better: Optional[bool] = False
11 |
12 | def __init__(
13 | self,
14 | compute_on_step: bool = True,
15 | dist_sync_on_step: bool = False,
16 | process_group: Optional[Any] = None,
17 | dist_sync_fn: Callable = None,
18 | ) -> None:
19 | super(PredAvgFDE, self).__init__(
20 | compute_on_step=compute_on_step,
21 | dist_sync_on_step=dist_sync_on_step,
22 | process_group=process_group,
23 | dist_sync_fn=dist_sync_fn,
24 | )
25 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
27 |
28 | def update(
29 | self,
30 | outputs: Dict[str, torch.Tensor],
31 | target: torch.Tensor,
32 | ) -> None:
33 | """
34 |         outputs["prediction"]: [..., A, T, 2]; outputs["prediction_target"]: [..., A, T, >=2] (only xy is used)
35 |         outputs["valid_mask"]: [..., A, T]; the `target` argument is ignored
36 | """
37 | with torch.no_grad():
38 | prediction, valid_mask = outputs["prediction"], outputs["valid_mask"]
39 | target = outputs["prediction_target"]
40 | endpoint_mask = valid_mask[..., -1].float()
41 | fde = (
42 | torch.norm(prediction[..., -1, :2] - target[..., -1, :2], p=2, dim=-1)
43 | * endpoint_mask
44 | ).sum(-1)
45 |
46 | self.sum += fde.sum()
47 | self.count += endpoint_mask.sum().long()
48 |
49 | def compute(self) -> torch.Tensor:
50 | return self.sum / self.count
51 |
--------------------------------------------------------------------------------
/src/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def sort_predictions(predictions, probability, k=6):
5 | """Sort the predictions based on the probability of each mode.
6 | Args:
7 | predictions (torch.Tensor): The predicted trajectories [b, k, t, 2].
8 | probability (torch.Tensor): The probability of each mode [b, k].
9 | Returns:
10 |         Tuple[torch.Tensor, torch.Tensor]: the top-k predictions [b, k, t, 2] and their probabilities [b, k], sorted by descending probability.
11 | """
12 | indices = torch.argsort(probability, dim=-1, descending=True)
13 | sorted_prob = probability[torch.arange(probability.size(0))[:, None], indices]
14 | sorted_predictions = predictions[
15 | torch.arange(predictions.size(0))[:, None], indices
16 | ]
17 | return sorted_predictions[:, :k], sorted_prob[:, :k]
18 |
--------------------------------------------------------------------------------
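Example for `sort_predictions`:

    import torch

    trajectories = torch.randn(2, 4, 10, 2)  # [batch, modes, time, xy]
    probability = torch.tensor([[0.1, 0.4, 0.3, 0.2],
                                [0.7, 0.1, 0.1, 0.1]])
    top_traj, top_prob = sort_predictions(trajectories, probability, k=2)
    # top_prob == [[0.4, 0.3], [0.7, 0.1]]; top_traj.shape == (2, 2, 10, 2)
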
/src/models/pluto/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def build_mlp(c_in, channels, norm=None, activation="relu"):
5 | layers = []
6 | num_layers = len(channels)
7 |
8 | if norm is not None:
9 | norm = get_norm(norm)
10 |
11 | activation = get_activation(activation)
12 |
13 | for k in range(num_layers):
14 | if k == num_layers - 1:
15 | layers.append(nn.Linear(c_in, channels[k], bias=True))
16 | else:
17 | if norm is None:
18 | layers.extend([nn.Linear(c_in, channels[k], bias=True), activation()])
19 | else:
20 | layers.extend(
21 | [
22 | nn.Linear(c_in, channels[k], bias=False),
23 | norm(channels[k]),
24 | activation(),
25 | ]
26 | )
27 | c_in = channels[k]
28 |
29 | return nn.Sequential(*layers)
30 |
31 |
32 | def get_norm(norm: str):
33 | if norm == "bn":
34 | return nn.BatchNorm1d
35 | elif norm == "ln":
36 | return nn.LayerNorm
37 | else:
38 | raise NotImplementedError
39 |
40 |
41 | def get_activation(activation: str):
42 | if activation == "relu":
43 | return nn.ReLU
44 | elif activation == "gelu":
45 | return nn.GELU
46 | else:
47 | raise NotImplementedError
48 |
--------------------------------------------------------------------------------
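Example: `build_mlp(6, [128, 128], norm="bn")` expands to
Linear(6->128, bias=False) -> BatchNorm1d -> ReLU -> Linear(128->128, bias=True);
the last entry of `channels` is always a plain biased Linear with no norm/activation.

    import torch

    mlp = build_mlp(6, [128, 128], norm="bn")
    out = mlp(torch.randn(32, 6))  # (32, 128)
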
/src/models/pluto/layers/embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from natten import NeighborhoodAttention1D
5 | from timm.models.layers import DropPath
6 |
7 |
8 | class NATSequenceEncoder(nn.Module):
9 | def __init__(
10 | self,
11 | in_chans=3,
12 | embed_dim=32,
13 | mlp_ratio=3,
14 | kernel_size=[3, 3, 5],
15 | depths=[2, 2, 2],
16 | num_heads=[2, 4, 8],
17 | out_indices=[0, 1, 2],
18 | drop_rate=0.0,
19 | attn_drop_rate=0.0,
20 | drop_path_rate=0.2,
21 | norm_layer=nn.LayerNorm,
22 | ) -> None:
23 | super().__init__()
24 |
25 | self.embed = ConvTokenizer(in_chans, embed_dim)
26 | self.num_levels = len(depths)
27 | self.num_features = [int(embed_dim * 2**i) for i in range(self.num_levels)]
28 | self.out_indices = out_indices
29 |
30 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
31 | self.levels = nn.ModuleList()
32 | for i in range(self.num_levels):
33 | level = NATBlock(
34 | dim=int(embed_dim * 2**i),
35 | depth=depths[i],
36 | num_heads=num_heads[i],
37 | kernel_size=kernel_size[i],
38 | dilations=None,
39 | mlp_ratio=mlp_ratio,
40 | drop=drop_rate,
41 | attn_drop=attn_drop_rate,
42 | drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
43 | norm_layer=norm_layer,
44 | downsample=(i < self.num_levels - 1),
45 | )
46 | self.levels.append(level)
47 |
48 | for i_layer in self.out_indices:
49 | layer = norm_layer(self.num_features[i_layer])
50 | layer_name = f"norm{i_layer}"
51 | self.add_module(layer_name, layer)
52 |
53 | n = self.num_features[-1]
54 | self.lateral_convs = nn.ModuleList()
55 | for i_layer in self.out_indices:
56 | self.lateral_convs.append(
57 | nn.Conv1d(self.num_features[i_layer], n, 3, padding=1)
58 | )
59 |
60 | self.fpn_conv = nn.Conv1d(n, n, 3, padding=1)
61 |
62 | def forward(self, x):
63 | """x: [B, C, T]"""
64 | x = self.embed(x)
65 |
66 | out = []
67 | for idx, level in enumerate(self.levels):
68 | x, xo = level(x)
69 | if idx in self.out_indices:
70 | norm_layer = getattr(self, f"norm{idx}")
71 | x_out = norm_layer(xo)
72 | out.append(x_out.permute(0, 2, 1).contiguous())
73 |
74 | laterals = [
75 | lateral_conv(out[i]) for i, lateral_conv in enumerate(self.lateral_convs)
76 | ]
77 | for i in range(len(out) - 1, 0, -1):
78 | laterals[i - 1] = laterals[i - 1] + F.interpolate(
79 | laterals[i],
80 | scale_factor=(laterals[i - 1].shape[-1] / laterals[i].shape[-1]),
81 | mode="linear",
82 | align_corners=False,
83 | )
84 |
85 | out = self.fpn_conv(laterals[0])
86 |
87 | return out[:, :, -1]
88 |
89 |
90 | class ConvTokenizer(nn.Module):
91 | def __init__(self, in_chans=3, embed_dim=32, norm_layer=None):
92 | super().__init__()
93 | self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=3, stride=1, padding=1)
94 |
95 | if norm_layer is not None:
96 | self.norm = norm_layer(embed_dim)
97 | else:
98 | self.norm = None
99 |
100 | def forward(self, x):
101 | x = self.proj(x).permute(0, 2, 1) # B, C, L -> B, L, C
102 | if self.norm is not None:
103 | x = self.norm(x)
104 | return x
105 |
106 |
107 | class ConvDownsampler(nn.Module):
108 | def __init__(self, dim, norm_layer=nn.LayerNorm):
109 | super().__init__()
110 | self.reduction = nn.Conv1d(
111 | dim, 2 * dim, kernel_size=3, stride=2, padding=1, bias=False
112 | )
113 | self.norm = norm_layer(2 * dim)
114 |
115 | def forward(self, x):
116 | x = self.reduction(x.permute(0, 2, 1)).permute(0, 2, 1)
117 | x = self.norm(x)
118 | return x
119 |
120 |
121 | class Mlp(nn.Module):
122 | def __init__(
123 | self,
124 | in_features,
125 | hidden_features=None,
126 | out_features=None,
127 | act_layer=nn.GELU,
128 | drop=0.0,
129 | ):
130 | super().__init__()
131 | out_features = out_features or in_features
132 | hidden_features = hidden_features or in_features
133 | self.fc1 = nn.Linear(in_features, hidden_features)
134 | self.act = act_layer()
135 | self.fc2 = nn.Linear(hidden_features, out_features)
136 | self.drop = nn.Dropout(drop)
137 |
138 | def forward(self, x):
139 | x = self.fc1(x)
140 | x = self.act(x)
141 | x = self.drop(x)
142 | x = self.fc2(x)
143 | x = self.drop(x)
144 | return x
145 |
146 |
147 | class NATLayer(nn.Module):
148 | def __init__(
149 | self,
150 | dim,
151 | num_heads,
152 | kernel_size=7,
153 | dilation=None,
154 | mlp_ratio=4.0,
155 | qkv_bias=True,
156 | qk_scale=None,
157 | drop=0.0,
158 | attn_drop=0.0,
159 | drop_path=0.0,
160 | act_layer=nn.GELU,
161 | norm_layer=nn.LayerNorm,
162 | ):
163 | super().__init__()
164 | self.dim = dim
165 | self.num_heads = num_heads
166 | self.mlp_ratio = mlp_ratio
167 |
168 | self.norm1 = norm_layer(dim)
169 | self.attn = NeighborhoodAttention1D(
170 | dim,
171 | kernel_size=kernel_size,
172 | dilation=dilation,
173 | num_heads=num_heads,
174 | qkv_bias=qkv_bias,
175 | qk_scale=qk_scale,
176 | attn_drop=attn_drop,
177 | proj_drop=drop,
178 | )
179 |
180 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
181 | self.norm2 = norm_layer(dim)
182 | self.mlp = Mlp(
183 | in_features=dim,
184 | hidden_features=int(dim * mlp_ratio),
185 | act_layer=act_layer,
186 | drop=drop,
187 | )
188 |
189 | def forward(self, x):
190 | shortcut = x
191 | x = self.norm1(x)
192 | x = self.attn(x)
193 | x = shortcut + self.drop_path(x)
194 | x = x + self.drop_path(self.mlp(self.norm2(x)))
195 | return x
196 |
197 |
198 | class NATBlock(nn.Module):
199 | def __init__(
200 | self,
201 | dim,
202 | depth,
203 | num_heads,
204 | kernel_size,
205 | dilations=None,
206 | downsample=True,
207 | mlp_ratio=4.0,
208 | qkv_bias=True,
209 | qk_scale=None,
210 | drop=0.0,
211 | attn_drop=0.0,
212 | drop_path=0.0,
213 | norm_layer=nn.LayerNorm,
214 | act_layer=nn.GELU,
215 | ):
216 | super().__init__()
217 | self.dim = dim
218 | self.depth = depth
219 |
220 | self.blocks = nn.ModuleList(
221 | [
222 | NATLayer(
223 | dim=dim,
224 | num_heads=num_heads,
225 | kernel_size=kernel_size,
226 | dilation=None if dilations is None else dilations[i],
227 | mlp_ratio=mlp_ratio,
228 | qkv_bias=qkv_bias,
229 | qk_scale=qk_scale,
230 | drop=drop,
231 | attn_drop=attn_drop,
232 | drop_path=drop_path[i]
233 | if isinstance(drop_path, list)
234 | else drop_path,
235 | norm_layer=norm_layer,
236 | act_layer=act_layer,
237 | )
238 | for i in range(depth)
239 | ]
240 | )
241 |
242 | self.downsample = (
243 | None if not downsample else ConvDownsampler(dim=dim, norm_layer=norm_layer)
244 | )
245 |
246 | def forward(self, x):
247 | for blk in self.blocks:
248 | x = blk(x)
249 | if self.downsample is None:
250 | return x, x
251 | return self.downsample(x), x
252 |
253 |
254 | class PointsEncoder(nn.Module):
255 | def __init__(self, feat_channel, encoder_channel):
256 | super().__init__()
257 | self.encoder_channel = encoder_channel
258 | self.first_mlp = nn.Sequential(
259 | nn.Linear(feat_channel, 128),
260 | nn.BatchNorm1d(128),
261 | nn.ReLU(inplace=True),
262 | nn.Linear(128, 256),
263 | )
264 | self.second_mlp = nn.Sequential(
265 | nn.Linear(512, 256),
266 | nn.BatchNorm1d(256),
267 | nn.ReLU(inplace=True),
268 | nn.Linear(256, self.encoder_channel),
269 | )
270 |
271 | def forward(self, x, mask=None):
272 | """
273 |         x: (B, M, C) point features, C = feat_channel
274 |         mask: (B, M) boolean validity mask (required despite the None default)
275 |         -----------------
276 |         feature_global: (B, encoder_channel)
277 | """
278 |
279 | bs, n, _ = x.shape
280 | device = x.device
281 |
282 |         x_valid = self.first_mlp(x[mask])  # (num_valid, 256)
283 | x_features = torch.zeros(bs, n, 256, device=device)
284 | x_features[mask] = x_valid
285 |
286 | pooled_feature = x_features.max(dim=1)[0]
287 | x_features = torch.cat(
288 | [x_features, pooled_feature.unsqueeze(1).repeat(1, n, 1)], dim=-1
289 | )
290 |
291 | x_features_valid = self.second_mlp(x_features[mask])
292 | res = torch.zeros(bs, n, self.encoder_channel, device=device)
293 | res[mask] = x_features_valid
294 |
295 | res = res.max(dim=1)[0]
296 | return res
297 |
--------------------------------------------------------------------------------
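Shape contract of `NATSequenceEncoder` (requires the `natten` package; 9 input channels
and embed_dim = dim // 4 = 32 match how `AgentEncoder` below instantiates it):

    import torch

    encoder = NATSequenceEncoder(in_chans=9, embed_dim=32)
    history = torch.randn(16, 9, 20)  # [batch, channels, time]
    feature = encoder(history)
    # FPN-fused feature at the last time step: feature.shape == (16, 128)
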
/src/models/pluto/layers/fourier_embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 | from typing import List, Optional
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class FourierEmbedding(nn.Module):
22 | def __init__(self, input_dim: int, hidden_dim: int, num_freq_bands: int) -> None:
23 | super(FourierEmbedding, self).__init__()
24 | self.input_dim = input_dim
25 | self.hidden_dim = hidden_dim
26 |
27 | self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
28 | self.mlps = nn.ModuleList(
29 | [
30 | nn.Sequential(
31 | nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
32 | nn.LayerNorm(hidden_dim),
33 | nn.ReLU(inplace=True),
34 | nn.Linear(hidden_dim, hidden_dim),
35 | )
36 | for _ in range(input_dim)
37 | ]
38 | )
39 | self.to_out = nn.Sequential(
40 | nn.LayerNorm(hidden_dim),
41 | nn.ReLU(inplace=True),
42 | nn.Linear(hidden_dim, hidden_dim),
43 | )
44 |
45 | def forward(
46 | self,
47 |         continuous_inputs: torch.Tensor,
48 | ) -> torch.Tensor:
49 | x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
50 | x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
51 | continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
52 | for i in range(self.input_dim):
53 | continuous_embs[i] = self.mlps[i](x[..., i, :])
54 | x = torch.stack(continuous_embs).sum(dim=0)
55 | return self.to_out(x)
56 |
--------------------------------------------------------------------------------
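Example for `FourierEmbedding`, mirroring how `MapEncoder` embeds scalar speed limits
(one continuous channel, 64 frequency bands):

    import torch

    emb = FourierEmbedding(input_dim=1, hidden_dim=128, num_freq_bands=64)
    speed_limit = torch.rand(32, 1)  # (num_polygons, 1)
    out = emb(speed_limit)           # (32, 128)
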
/src/models/pluto/layers/mlp_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class MLPLayer(nn.Module):
5 | def __init__(self, channel_in, hidden, channel_out) -> None:
6 | super().__init__()
7 |
8 | self.mlp = nn.Sequential(
9 | nn.Linear(channel_in, hidden),
10 | nn.LayerNorm(hidden),
11 | nn.ReLU(inplace=True),
12 | nn.Linear(hidden, channel_out),
13 | )
14 |
15 | def forward(self, x):
16 | return self.mlp(x)
17 |
--------------------------------------------------------------------------------
/src/models/pluto/layers/transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from timm.models.layers import DropPath
7 | from torch import Tensor
8 |
9 |
10 | class Mlp(nn.Module):
11 | """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
12 |
13 | def __init__(
14 | self,
15 | in_features,
16 | hidden_features=None,
17 | out_features=None,
18 | act_layer=nn.GELU,
19 | drop=0.0,
20 | ):
21 | super().__init__()
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 |
25 | self.fc1 = nn.Linear(in_features, hidden_features)
26 | self.act = act_layer()
27 | self.drop1 = nn.Dropout(drop)
28 | self.fc2 = nn.Linear(hidden_features, out_features)
29 | self.drop2 = nn.Dropout(drop)
30 |
31 | def forward(self, x):
32 | x = self.fc1(x)
33 | x = self.act(x)
34 | x = self.drop1(x)
35 | x = self.fc2(x)
36 | x = self.drop2(x)
37 | return x
38 |
39 |
40 | class TransformerEncoderLayer(nn.Module):
41 | def __init__(
42 | self,
43 | dim,
44 | num_heads,
45 | mlp_ratio=4.0,
46 | qkv_bias=False,
47 | drop=0.0,
48 | attn_drop=0.0,
49 | drop_path=0.0,
50 | act_layer=nn.GELU,
51 | norm_layer=nn.LayerNorm,
52 | ):
53 | super().__init__()
54 | self.norm1 = norm_layer(dim)
55 | self.attn = torch.nn.MultiheadAttention(
56 | dim,
57 | num_heads=num_heads,
58 | add_bias_kv=qkv_bias,
59 | dropout=attn_drop,
60 | batch_first=True,
61 | )
62 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
63 |
64 | self.norm2 = norm_layer(dim)
65 | self.mlp = Mlp(
66 | in_features=dim,
67 | hidden_features=int(dim * mlp_ratio),
68 | act_layer=act_layer,
69 | drop=drop,
70 | )
71 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
72 |
73 | def forward(
74 | self,
75 | src,
76 | mask: Optional[Tensor] = None,
77 | key_padding_mask: Optional[Tensor] = None,
78 | return_attn_weights=False,
79 | ):
80 | src2 = self.norm1(src)
81 | src2, attn = self.attn(
82 | query=src2,
83 | key=src2,
84 | value=src2,
85 | attn_mask=mask,
86 | key_padding_mask=key_padding_mask,
87 | )
88 | src = src + self.drop_path1(src2)
89 | src = src + self.drop_path2(self.mlp(self.norm2(src)))
90 |
91 | if return_attn_weights:
92 | return src, attn
93 |
94 | return src
95 |
96 |
97 | class CrossAttentionLayer(nn.Module):
98 | def __init__(
99 | self,
100 | dim,
101 | num_heads,
102 | mlp_ratio=4,
103 | qkv_bias=False,
104 | dropout=0.1,
105 | attn_drop=0.0,
106 | act_layer=nn.GELU,
107 | norm_layer=nn.LayerNorm,
108 | norm_first=True,
109 | ):
110 | super().__init__()
111 |
112 | self.norm_first = norm_first
113 | self.attn = torch.nn.MultiheadAttention(
114 | dim,
115 | num_heads=num_heads,
116 | add_bias_kv=qkv_bias,
117 | dropout=attn_drop,
118 | batch_first=True,
119 | )
120 |
121 | self.linear1 = nn.Linear(dim, int(mlp_ratio * dim))
122 | self.activation = act_layer()
123 | self.linear2 = nn.Linear(int(mlp_ratio * dim), dim)
124 |
125 | self.norm1 = norm_layer(dim)
126 | self.norm2 = norm_layer(dim)
127 | self.dropout1 = nn.Dropout(dropout)
128 | self.dropout2 = nn.Dropout(dropout)
129 | self.dropout3 = nn.Dropout(dropout)
130 |
131 | def forward(
132 | self,
133 | x,
134 | memory,
135 | mask: Optional[Tensor] = None,
136 | key_padding_mask: Optional[Tensor] = None,
137 | ):
138 | if self.norm_first:
139 | x = x + self._mha_block(self.norm1(x), memory, mask, key_padding_mask)
140 | x = x + self._ff_block(self.norm2(x))
141 | else:
142 | x = self.norm1(x + self._mha_block(x, memory, mask, key_padding_mask))
143 | x = self.norm2(x + self._ff_block(x))
144 |
145 | return x
146 |
147 | def _mha_block(
148 | self,
149 | x: Tensor,
150 | mem: Tensor,
151 | attn_mask: Optional[Tensor],
152 | key_padding_mask: Optional[Tensor],
153 | ) -> Tensor:
154 | x = self.attn(
155 | x,
156 | mem,
157 | mem,
158 | attn_mask=attn_mask,
159 | key_padding_mask=key_padding_mask,
160 | need_weights=False,
161 | )[0]
162 | return self.dropout1(x)
163 |
164 | def _ff_block(self, x: Tensor) -> Tensor:
165 | x = self.linear2(self.dropout2(self.activation(self.linear1(x))))
166 | return self.dropout3(x)
167 |
168 |
169 | class TransformerDecoderLayer(nn.Module):
170 | def __init__(
171 | self,
172 | d_model,
173 | nhead,
174 | dim_feedforward=2048,
175 | dropout=0.1,
176 | activation="relu",
177 | normalize_before=False,
178 | ):
179 | super().__init__()
180 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
181 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
182 | # Implementation of Feedforward model
183 | self.linear1 = nn.Linear(d_model, dim_feedforward)
184 | self.dropout = nn.Dropout(dropout)
185 | self.linear2 = nn.Linear(dim_feedforward, d_model)
186 |
187 | self.norm1 = nn.LayerNorm(d_model)
188 | self.norm2 = nn.LayerNorm(d_model)
189 | self.norm3 = nn.LayerNorm(d_model)
190 | self.dropout1 = nn.Dropout(dropout)
191 | self.dropout2 = nn.Dropout(dropout)
192 | self.dropout3 = nn.Dropout(dropout)
193 |
194 | self.activation = _get_activation_fn(activation)
195 | self.normalize_before = normalize_before
196 |
197 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
198 | return tensor if pos is None else tensor + pos
199 |
200 | def forward_post(
201 | self,
202 | tgt,
203 | memory,
204 | tgt_mask: Optional[Tensor] = None,
205 | memory_mask: Optional[Tensor] = None,
206 | tgt_key_padding_mask: Optional[Tensor] = None,
207 | memory_key_padding_mask: Optional[Tensor] = None,
208 | pos: Optional[Tensor] = None,
209 | query_pos: Optional[Tensor] = None,
210 | ):
211 | q = k = self.with_pos_embed(tgt, query_pos)
212 | tgt2 = self.self_attn(
213 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
214 | )[0]
215 | tgt = tgt + self.dropout1(tgt2)
216 | tgt = self.norm1(tgt)
217 | tgt2 = self.multihead_attn(
218 | query=self.with_pos_embed(tgt, query_pos),
219 | key=self.with_pos_embed(memory, pos),
220 | value=memory,
221 | attn_mask=memory_mask,
222 | key_padding_mask=memory_key_padding_mask,
223 | )[0]
224 | tgt = tgt + self.dropout2(tgt2)
225 | tgt = self.norm2(tgt)
226 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
227 | tgt = tgt + self.dropout3(tgt2)
228 | tgt = self.norm3(tgt)
229 | return tgt
230 |
231 | def forward_pre(
232 | self,
233 | tgt,
234 | memory,
235 | tgt_mask: Optional[Tensor] = None,
236 | memory_mask: Optional[Tensor] = None,
237 | tgt_key_padding_mask: Optional[Tensor] = None,
238 | memory_key_padding_mask: Optional[Tensor] = None,
239 | pos: Optional[Tensor] = None,
240 | query_pos: Optional[Tensor] = None,
241 | ):
242 | tgt2 = self.norm1(tgt)
243 | q = k = self.with_pos_embed(tgt2, query_pos)
244 | tgt2 = self.self_attn(
245 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
246 | )[0]
247 | tgt = tgt + self.dropout1(tgt2)
248 | tgt2 = self.norm2(tgt)
249 | tgt2 = self.multihead_attn(
250 | query=self.with_pos_embed(tgt2, query_pos),
251 | key=self.with_pos_embed(memory, pos),
252 | value=memory,
253 | attn_mask=memory_mask,
254 | key_padding_mask=memory_key_padding_mask,
255 | )[0]
256 | tgt = tgt + self.dropout2(tgt2)
257 | tgt2 = self.norm3(tgt)
258 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
259 | tgt = tgt + self.dropout3(tgt2)
260 | return tgt
261 |
262 | def forward(
263 | self,
264 | tgt,
265 | memory,
266 | tgt_mask: Optional[Tensor] = None,
267 | memory_mask: Optional[Tensor] = None,
268 | tgt_key_padding_mask: Optional[Tensor] = None,
269 | memory_key_padding_mask: Optional[Tensor] = None,
270 | pos: Optional[Tensor] = None,
271 | query_pos: Optional[Tensor] = None,
272 | ):
273 | if self.normalize_before:
274 | return self.forward_pre(
275 | tgt,
276 | memory,
277 | tgt_mask,
278 | memory_mask,
279 | tgt_key_padding_mask,
280 | memory_key_padding_mask,
281 | pos,
282 | query_pos,
283 | )
284 | return self.forward_post(
285 | tgt,
286 | memory,
287 | tgt_mask,
288 | memory_mask,
289 | tgt_key_padding_mask,
290 | memory_key_padding_mask,
291 | pos,
292 | query_pos,
293 | )
294 |
295 |
296 | def _get_activation_fn(activation):
297 | """Return an activation function given a string"""
298 | if activation == "relu":
299 | return F.relu
300 | if activation == "gelu":
301 | return F.gelu
302 | if activation == "glu":
303 | return F.glu
304 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
305 |
--------------------------------------------------------------------------------
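Example for `CrossAttentionLayer` (pre-norm by default), e.g. letting decoder queries
attend to encoded scene tokens:

    import torch

    layer = CrossAttentionLayer(dim=128, num_heads=8)
    queries = torch.randn(4, 10, 128)  # (batch, num_queries, dim)
    memory = torch.randn(4, 50, 128)   # (batch, num_tokens, dim)
    out = layer(queries, memory)       # (4, 10, 128)
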
/src/models/pluto/loss/esdf_collision_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 |
9 | class ESDFCollisionLoss(nn.Module):
10 | def __init__(
11 | self,
12 | num_circles=3,
13 | ego_width=2.297,
14 | ego_front_length=4.049,
15 | ego_rear_length=1.127,
16 | resolution=0.2,
17 | ) -> None:
18 | super().__init__()
19 |
20 | ego_length = ego_front_length + ego_rear_length
21 | interval = ego_length / num_circles
22 |
23 | self.N = num_circles
24 | self.width = ego_width
25 | self.length = ego_length
26 | self.rear_length = ego_rear_length
27 | self.resolution = resolution
28 |
29 | self.radius = math.sqrt(ego_width**2 + interval**2) / 2 - resolution
30 | self.offset = torch.Tensor(
31 | [-ego_rear_length + interval / 2 * (2 * i + 1) for i in range(num_circles)]
32 | )
33 |
34 | def forward(self, trajectory: Tensor, sdf: Tensor):
35 | """
36 |         trajectory: (bs, T, 4) - [x, y, cosθ, sinθ], ego-frame pose per step
37 |         sdf: (bs, H, W) ego-centered signed distance field (positive inside the drivable area)
38 | """
39 | bs, H, W = sdf.shape
40 |
41 | origin_offset = torch.tensor([W // 2, H // 2], device=sdf.device)
42 | offset = self.offset.to(sdf.device).view(1, 1, self.N, 1)
43 | # (bs, T, N, 2)
44 | centers = trajectory[..., None, :2] + offset * trajectory[..., None, 2:4]
45 |
46 | pixel_coord = torch.stack(
47 | [centers[..., 0] / self.resolution, -centers[..., 1] / self.resolution],
48 | dim=-1,
49 | )
50 | grid_xy = pixel_coord / origin_offset
51 | valid_mask = (grid_xy < 0.95).all(-1) & (grid_xy > -0.95).all(-1)
52 | on_road_mask = sdf[:, H // 2, W // 2] > 0
53 |
54 | # (bs, T, N)
55 | distance = F.grid_sample(
56 | sdf.unsqueeze(1), grid_xy, mode="bilinear", padding_mode="zeros"
57 | ).squeeze(1)
58 |
59 | cost = self.radius - distance
60 | valid_mask = valid_mask & (cost > 0) & on_road_mask[:, None, None]
61 | cost.masked_fill_(~valid_mask, 0)
62 |
63 | loss = F.l1_loss(cost, torch.zeros_like(cost), reduction="none").sum(-1)
64 | loss = loss.sum() / (valid_mask.sum() + 1e-6)
65 |
66 | return loss
67 |
--------------------------------------------------------------------------------
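Sanity check for `ESDFCollisionLoss`: if the SDF reports every cell 5 m from the nearest
obstacle, no covering circle violates its radius and the loss is zero:

    import torch

    loss_fn = ESDFCollisionLoss()
    trajectory = torch.zeros(2, 80, 4)
    trajectory[..., 2] = 1.0  # heading 0 -> (cos, sin) = (1, 0)
    sdf = torch.full((2, 100, 100), 5.0)
    print(loss_fn(trajectory, sdf))  # tensor(0.)
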
/src/models/pluto/modules/agent_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..layers.common_layers import build_mlp
5 | from ..layers.embedding import NATSequenceEncoder
6 |
7 |
8 | class AgentEncoder(nn.Module):
9 | def __init__(
10 | self,
11 | state_channel=6,
12 | history_channel=9,
13 | dim=128,
14 | hist_steps=21,
15 | use_ego_history=False,
16 | drop_path=0.2,
17 | state_attn_encoder=True,
18 | state_dropout=0.75,
19 | ) -> None:
20 | super().__init__()
21 | self.dim = dim
22 | self.state_channel = state_channel
23 | self.use_ego_history = use_ego_history
24 | self.hist_steps = hist_steps
25 | self.state_attn_encoder = state_attn_encoder
26 |
27 | self.history_encoder = NATSequenceEncoder(
28 | in_chans=history_channel, embed_dim=dim // 4, drop_path_rate=drop_path
29 | )
30 |
31 | if not use_ego_history:
32 | if not self.state_attn_encoder:
33 | self.ego_state_emb = build_mlp(state_channel, [dim] * 2, norm="bn")
34 | else:
35 | self.ego_state_emb = StateAttentionEncoder(
36 | state_channel, dim, state_dropout
37 | )
38 |
39 | self.type_emb = nn.Embedding(4, dim)
40 |
41 | @staticmethod
42 | def to_vector(feat, valid_mask):
43 | vec_mask = valid_mask[..., :-1] & valid_mask[..., 1:]
44 |
45 | while len(vec_mask.shape) < len(feat.shape):
46 | vec_mask = vec_mask.unsqueeze(-1)
47 |
48 | return torch.where(
49 | vec_mask,
50 | feat[:, :, 1:, ...] - feat[:, :, :-1, ...],
51 | torch.zeros_like(feat[:, :, 1:, ...]),
52 | )
53 |
54 | def forward(self, data):
55 | T = self.hist_steps
56 |
57 | position = data["agent"]["position"][:, :, :T]
58 | heading = data["agent"]["heading"][:, :, :T]
59 | velocity = data["agent"]["velocity"][:, :, :T]
60 | shape = data["agent"]["shape"][:, :, :T]
61 | category = data["agent"]["category"].long()
62 | valid_mask = data["agent"]["valid_mask"][:, :, :T]
63 |
64 | heading_vec = self.to_vector(heading, valid_mask)
65 | valid_mask_vec = valid_mask[..., 1:] & valid_mask[..., :-1]
66 | agent_feature = torch.cat(
67 | [
68 | self.to_vector(position, valid_mask),
69 | self.to_vector(velocity, valid_mask),
70 | torch.stack([heading_vec.cos(), heading_vec.sin()], dim=-1),
71 | shape[:, :, 1:],
72 | valid_mask_vec.float().unsqueeze(-1),
73 | ],
74 | dim=-1,
75 | )
76 | bs, A, T, _ = agent_feature.shape
77 | agent_feature = agent_feature.view(bs * A, T, -1)
78 | valid_agent_mask = valid_mask.any(-1).flatten()
79 |
80 | x_agent_tmp = self.history_encoder(
81 | agent_feature[valid_agent_mask].permute(0, 2, 1).contiguous()
82 | )
83 | x_agent = torch.zeros(bs * A, self.dim, device=position.device)
84 | x_agent[valid_agent_mask] = x_agent_tmp
85 | x_agent = x_agent.view(bs, A, self.dim)
86 |
87 | if not self.use_ego_history:
88 | ego_feature = data["current_state"][:, : self.state_channel]
89 | x_ego = self.ego_state_emb(ego_feature)
90 | x_agent[:, 0] = x_ego
91 |
92 | x_type = self.type_emb(category)
93 |
94 | return x_agent + x_type
95 |
96 |
97 | class StateAttentionEncoder(nn.Module):
98 | def __init__(self, state_channel, dim, state_dropout=0.5) -> None:
99 | super().__init__()
100 |
101 | self.state_channel = state_channel
102 | self.state_dropout = state_dropout
103 | self.linears = nn.ModuleList([nn.Linear(1, dim) for _ in range(state_channel)])
104 | self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)
105 | self.pos_embed = nn.Parameter(torch.Tensor(1, state_channel, dim))
106 | self.query = nn.Parameter(torch.Tensor(1, 1, dim))
107 |
108 | nn.init.normal_(self.pos_embed, std=0.02)
109 | nn.init.normal_(self.query, std=0.02)
110 |
111 | def forward(self, x):
112 | x_embed = []
113 | for i, linear in enumerate(self.linears):
114 | x_embed.append(linear(x[:, i, None]))
115 | x_embed = torch.stack(x_embed, dim=1)
116 | pos_embed = self.pos_embed.repeat(x_embed.shape[0], 1, 1)
117 | x_embed += pos_embed
118 |
119 | if self.training and self.state_dropout > 0:
120 | visible_tokens = torch.zeros(
121 | (x_embed.shape[0], 3), device=x.device, dtype=torch.bool
122 | )
123 | dropout_tokens = (
124 | torch.rand((x_embed.shape[0], self.state_channel - 3), device=x.device)
125 | < self.state_dropout
126 | )
127 | key_padding_mask = torch.concat([visible_tokens, dropout_tokens], dim=1)
128 | else:
129 | key_padding_mask = None
130 |
131 | query = self.query.repeat(x_embed.shape[0], 1, 1)
132 |
133 | x_state = self.attn(
134 | query=query,
135 | key=x_embed,
136 | value=x_embed,
137 | key_padding_mask=key_padding_mask,
138 | )[0]
139 |
140 | return x_state[:, 0]
141 |
--------------------------------------------------------------------------------
/src/models/pluto/modules/agent_predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..layers.mlp_layer import MLPLayer
5 |
6 |
7 | class AgentPredictor(nn.Module):
8 | def __init__(self, dim, future_steps) -> None:
9 | super().__init__()
10 |
11 | self.future_steps = future_steps
12 |
13 | self.loc_predictor = MLPLayer(dim, 2 * dim, future_steps * 2)
14 | self.yaw_predictor = MLPLayer(dim, 2 * dim, future_steps * 2)
15 | self.vel_predictor = MLPLayer(dim, 2 * dim, future_steps * 2)
16 |
17 | def forward(self, x):
18 | """
19 | x: (bs, N, dim)
20 | """
21 |
22 | bs, N, _ = x.shape
23 |
24 | loc = self.loc_predictor(x).view(bs, N, self.future_steps, 2)
25 | yaw = self.yaw_predictor(x).view(bs, N, self.future_steps, 2)
26 | vel = self.vel_predictor(x).view(bs, N, self.future_steps, 2)
27 |
28 | prediction = torch.cat([loc, yaw, vel], dim=-1)
29 | return prediction
30 |
--------------------------------------------------------------------------------
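Shape contract of `AgentPredictor` (the 6 output channels concatenate xy location,
yaw cos/sin, and xy velocity, 2 each):

    import torch

    predictor = AgentPredictor(dim=128, future_steps=80)
    x = torch.randn(4, 33, 128)  # (batch, num_agents, dim)
    pred = predictor(x)          # (4, 33, 80, 6)
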
/src/models/pluto/modules/map_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..layers.embedding import PointsEncoder
5 | from ..layers.fourier_embedding import FourierEmbedding
6 |
7 |
8 | class MapEncoder(nn.Module):
9 | def __init__(
10 | self,
11 | polygon_channel=6,
12 | dim=128,
13 | use_lane_boundary=False,
14 | ) -> None:
15 | super().__init__()
16 |
17 | self.dim = dim
18 | self.use_lane_boundary = use_lane_boundary
19 | self.polygon_channel = (
20 | polygon_channel + 4 if use_lane_boundary else polygon_channel
21 | )
22 |
23 | self.polygon_encoder = PointsEncoder(self.polygon_channel, dim)
24 | self.speed_limit_emb = FourierEmbedding(1, dim, 64)
25 |
26 | self.type_emb = nn.Embedding(3, dim)
27 | self.on_route_emb = nn.Embedding(2, dim)
28 | self.traffic_light_emb = nn.Embedding(4, dim)
29 | self.unknown_speed_emb = nn.Embedding(1, dim)
30 |
31 | def forward(self, data) -> torch.Tensor:
32 | polygon_center = data["map"]["polygon_center"]
33 | polygon_type = data["map"]["polygon_type"].long()
34 | polygon_on_route = data["map"]["polygon_on_route"].long()
35 | polygon_tl_status = data["map"]["polygon_tl_status"].long()
36 | polygon_has_speed_limit = data["map"]["polygon_has_speed_limit"]
37 | polygon_speed_limit = data["map"]["polygon_speed_limit"]
38 | point_position = data["map"]["point_position"]
39 | point_vector = data["map"]["point_vector"]
40 | point_orientation = data["map"]["point_orientation"]
41 | valid_mask = data["map"]["valid_mask"]
42 |
43 | if self.use_lane_boundary:
44 | polygon_feature = torch.cat(
45 | [
46 | point_position[:, :, 0] - polygon_center[..., None, :2],
47 | point_vector[:, :, 0],
48 | torch.stack(
49 | [
50 | point_orientation[:, :, 0].cos(),
51 | point_orientation[:, :, 0].sin(),
52 | ],
53 | dim=-1,
54 | ),
55 | point_position[:, :, 1] - point_position[:, :, 0],
56 | point_position[:, :, 2] - point_position[:, :, 0],
57 | ],
58 | dim=-1,
59 | )
60 | else:
61 | polygon_feature = torch.cat(
62 | [
63 | point_position[:, :, 0] - polygon_center[..., None, :2],
64 | point_vector[:, :, 0],
65 | torch.stack(
66 | [
67 | point_orientation[:, :, 0].cos(),
68 | point_orientation[:, :, 0].sin(),
69 | ],
70 | dim=-1,
71 | ),
72 | ],
73 | dim=-1,
74 | )
75 |
76 | bs, M, P, C = polygon_feature.shape
77 | valid_mask = valid_mask.view(bs * M, P)
78 | polygon_feature = polygon_feature.reshape(bs * M, P, C)
79 |
80 | x_polygon = self.polygon_encoder(polygon_feature, valid_mask).view(bs, M, -1)
81 |
82 | x_type = self.type_emb(polygon_type)
83 | x_on_route = self.on_route_emb(polygon_on_route)
84 | x_tl_status = self.traffic_light_emb(polygon_tl_status)
85 | x_speed_limit = torch.zeros(bs, M, self.dim, device=x_polygon.device)
86 | x_speed_limit[polygon_has_speed_limit] = self.speed_limit_emb(
87 | polygon_speed_limit[polygon_has_speed_limit].unsqueeze(-1)
88 | )
89 | x_speed_limit[~polygon_has_speed_limit] = self.unknown_speed_emb.weight
90 |
91 | x_polygon += x_type + x_on_route + x_tl_status + x_speed_limit
92 |
93 | return x_polygon
94 |
--------------------------------------------------------------------------------
/src/models/pluto/modules/planning_decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch import Tensor
6 |
7 | from ..layers.embedding import PointsEncoder
8 | from ..layers.fourier_embedding import FourierEmbedding
9 | from ..layers.mlp_layer import MLPLayer
10 |
11 |
12 | class DecoderLayer(nn.Module):
13 | def __init__(self, dim, num_heads, mlp_ratio, dropout) -> None:
14 | super().__init__()
15 | self.dim = dim
16 |
17 | self.r2r_attn = nn.MultiheadAttention(
18 | dim, num_heads, dropout=dropout, batch_first=True
19 | )
20 | self.m2m_attn = nn.MultiheadAttention(
21 | dim, num_heads, dropout=dropout, batch_first=True
22 | )
23 | self.cross_attn = nn.MultiheadAttention(
24 | dim, num_heads, dropout=dropout, batch_first=True
25 | )
26 |
27 | self.ffn = nn.Sequential(
28 | nn.Linear(dim, dim * mlp_ratio),
29 | nn.ReLU(inplace=True),
30 | nn.Dropout(dropout),
31 | nn.Linear(dim * mlp_ratio, dim),
32 | )
33 |
34 | self.norm1 = nn.LayerNorm(dim)
35 | self.norm2 = nn.LayerNorm(dim)
36 | self.norm3 = nn.LayerNorm(dim)
37 | self.norm4 = nn.LayerNorm(dim)
38 | self.dropout1 = nn.Dropout(dropout)
39 | self.dropout2 = nn.Dropout(dropout)
40 | self.dropout3 = nn.Dropout(dropout)
41 |
42 | def forward(
43 | self,
44 | tgt,
45 | memory,
46 | tgt_key_padding_mask: Optional[Tensor] = None,
47 | memory_key_padding_mask: Optional[Tensor] = None,
48 | m_pos: Optional[Tensor] = None,
49 | ):
50 | """
51 | tgt: (bs, R, M, dim)
52 | tgt_key_padding_mask: (bs, R)
53 | """
54 | bs, R, M, D = tgt.shape
55 |
56 | tgt = tgt.transpose(1, 2).reshape(bs * M, R, D)
57 | tgt2 = self.norm1(tgt)
58 |         tgt2 = self.r2r_attn(  # mask rows must follow the (bs*M, R) reshape order
59 |             tgt2, tgt2, tgt2, key_padding_mask=tgt_key_padding_mask.repeat_interleave(M, 0)
60 |         )[0]
61 | tgt = tgt + self.dropout1(tgt2)
62 |
63 | tgt_tmp = tgt.reshape(bs, M, R, D).transpose(1, 2).reshape(bs * R, M, D)
64 | tgt_valid_mask = ~tgt_key_padding_mask.reshape(-1)
65 | tgt_valid = tgt_tmp[tgt_valid_mask]
66 | tgt2_valid = self.norm2(tgt_valid)
67 | tgt2_valid, _ = self.m2m_attn(
68 | tgt2_valid + m_pos, tgt2_valid + m_pos, tgt2_valid
69 | )
70 | tgt_valid = tgt_valid + self.dropout2(tgt2_valid)
71 | tgt = torch.zeros_like(tgt_tmp)
72 | tgt[tgt_valid_mask] = tgt_valid
73 |
74 | tgt = tgt.reshape(bs, R, M, D).view(bs, R * M, D)
75 | tgt2 = self.norm3(tgt)
76 | tgt2 = self.cross_attn(
77 | tgt2, memory, memory, key_padding_mask=memory_key_padding_mask
78 | )[0]
79 |
80 | tgt = tgt + self.dropout2(tgt2)
81 | tgt2 = self.norm4(tgt)
82 | tgt2 = self.ffn(tgt2)
83 | tgt = tgt + self.dropout3(tgt2)
84 | tgt = tgt.reshape(bs, R, M, D)
85 |
86 | return tgt
87 |
88 |
89 | class PlanningDecoder(nn.Module):
90 | def __init__(
91 | self,
92 | num_mode,
93 | decoder_depth,
94 | dim,
95 | num_heads,
96 | mlp_ratio,
97 | dropout,
98 | future_steps,
99 | yaw_constraint=False,
100 | cat_x=False,
101 | ) -> None:
102 | super().__init__()
103 |
104 | self.num_mode = num_mode
105 | self.future_steps = future_steps
106 | self.yaw_constraint = yaw_constraint
107 | self.cat_x = cat_x
108 |
109 | self.decoder_blocks = nn.ModuleList(
110 | [
111 | DecoderLayer(dim, num_heads, mlp_ratio, dropout)
112 | for _ in range(decoder_depth)
113 | ]
114 | )
115 |
116 | self.r_pos_emb = FourierEmbedding(3, dim, 64)
117 | self.r_encoder = PointsEncoder(6, dim)
118 |
119 | self.q_proj = nn.Linear(2 * dim, dim)
120 |
121 |         self.m_emb = nn.Parameter(torch.empty(1, 1, num_mode, dim))
122 |         self.m_pos = nn.Parameter(torch.empty(1, num_mode, dim))
123 |
124 | if self.cat_x:
125 | self.cat_x_proj = nn.Linear(2 * dim, dim)
126 |
127 | self.loc_head = MLPLayer(dim, 2 * dim, self.future_steps * 2)
128 | self.yaw_head = MLPLayer(dim, 2 * dim, self.future_steps * 2)
129 | self.vel_head = MLPLayer(dim, 2 * dim, self.future_steps * 2)
130 | self.pi_head = MLPLayer(dim, dim, 1)
131 |
132 | nn.init.normal_(self.m_emb, mean=0.0, std=0.01)
133 | nn.init.normal_(self.m_pos, mean=0.0, std=0.01)
134 |
135 | def forward(self, data, enc_data):
136 | enc_emb = enc_data["enc_emb"]
137 | enc_key_padding_mask = enc_data["enc_key_padding_mask"]
138 |
139 | r_position = data["reference_line"]["position"]
140 | r_vector = data["reference_line"]["vector"]
141 | r_orientation = data["reference_line"]["orientation"]
142 | r_valid_mask = data["reference_line"]["valid_mask"]
143 | r_key_padding_mask = ~r_valid_mask.any(-1)
144 |
145 | r_feature = torch.cat(
146 | [
147 | r_position - r_position[..., 0:1, :2],
148 | r_vector,
149 | torch.stack([r_orientation.cos(), r_orientation.sin()], dim=-1),
150 | ],
151 | dim=-1,
152 | )
153 |
154 | bs, R, P, C = r_feature.shape
155 | r_valid_mask = r_valid_mask.view(bs * R, P)
156 | r_feature = r_feature.reshape(bs * R, P, C)
157 | r_emb = self.r_encoder(r_feature, r_valid_mask).view(bs, R, -1)
158 |
159 | r_pos = torch.cat([r_position[:, :, 0], r_orientation[:, :, 0, None]], dim=-1)
160 | r_emb = r_emb + self.r_pos_emb(r_pos)
161 |
162 | r_emb = r_emb.unsqueeze(2).repeat(1, 1, self.num_mode, 1)
163 | m_emb = self.m_emb.repeat(bs, R, 1, 1)
164 |
165 | q = self.q_proj(torch.cat([r_emb, m_emb], dim=-1))
166 |
167 | for blk in self.decoder_blocks:
168 | q = blk(
169 | q,
170 | enc_emb,
171 | tgt_key_padding_mask=r_key_padding_mask,
172 | memory_key_padding_mask=enc_key_padding_mask,
173 | m_pos=self.m_pos,
174 | )
175 | assert torch.isfinite(q).all()
176 |
177 | if self.cat_x:
178 | x = enc_emb[:, 0].unsqueeze(1).unsqueeze(2).repeat(1, R, self.num_mode, 1)
179 | q = self.cat_x_proj(torch.cat([q, x], dim=-1))
180 |
181 | loc = self.loc_head(q).view(bs, R, self.num_mode, self.future_steps, 2)
182 | yaw = self.yaw_head(q).view(bs, R, self.num_mode, self.future_steps, 2)
183 | vel = self.vel_head(q).view(bs, R, self.num_mode, self.future_steps, 2)
184 | pi = self.pi_head(q).squeeze(-1)
185 |
186 | traj = torch.cat([loc, yaw, vel], dim=-1)
187 |
188 | return traj, pi
189 |
--------------------------------------------------------------------------------
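
DecoderLayer chains three attention stages: self-attention across reference lines (r2r), self-attention across modes on valid rows only (m2m), then cross-attention from every (reference line, mode) query into the scene encoding. A minimal shape check, with illustrative dimensions rather than the repository's configured ones:

import torch
from src.models.pluto.modules.planning_decoder import DecoderLayer

bs, R, M, D, N = 2, 3, 6, 128, 50             # N scene tokens (illustrative)
layer = DecoderLayer(dim=D, num_heads=8, mlp_ratio=4, dropout=0.1)
q = torch.randn(bs, R, M, D)                  # one query per (reference line, mode)
memory = torch.randn(bs, N, D)                # encoder output tokens
out = layer(
    q,
    memory,
    tgt_key_padding_mask=torch.zeros(bs, R, dtype=torch.bool),    # all lines valid
    memory_key_padding_mask=torch.zeros(bs, N, dtype=torch.bool),
    m_pos=torch.randn(1, M, D),               # mode positional embedding
)
assert out.shape == (bs, R, M, D)
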
/src/models/pluto/modules/static_objects_encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ..layers.fourier_embedding import FourierEmbedding
6 |
7 |
8 | class StaticObjectsEncoder(nn.Module):
9 | def __init__(self, dim) -> None:
10 | super().__init__()
11 |
12 | self.obj_encoder = FourierEmbedding(2, dim, 64)
13 | self.type_emb = nn.Embedding(4, dim)
14 |
15 | nn.init.normal_(self.type_emb.weight, mean=0.0, std=0.01)
16 |
17 | def forward(self, data):
18 | pos = data["static_objects"]["position"]
19 | heading = data["static_objects"]["heading"]
20 | shape = data["static_objects"]["shape"]
21 | category = data["static_objects"]["category"].long()
22 | valid_mask = data["static_objects"]["valid_mask"] # [bs, N]
23 |
24 |         obj_emb_tmp = self.obj_encoder(shape) + self.type_emb(category)  # category already cast to long above
25 | obj_emb = torch.zeros_like(obj_emb_tmp)
26 | obj_emb[valid_mask] = obj_emb_tmp[valid_mask]
27 |
28 | heading = (heading + math.pi) % (2 * math.pi) - math.pi
29 | obj_pos = torch.cat([pos, heading.unsqueeze(-1)], dim=-1)
30 |
31 | return obj_emb, obj_pos, ~valid_mask
32 |
--------------------------------------------------------------------------------
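
The heading wrap on line 28 is the same modulo trick used elsewhere in the model to map angles into [-pi, pi); a quick standalone check:

import math

def wrap(a: float) -> float:
    return (a + math.pi) % (2 * math.pi) - math.pi

assert abs(wrap(3 * math.pi / 2) + math.pi / 2) < 1e-9   # 270 deg -> -90 deg
assert wrap(math.pi) == -math.pi                          # half-open interval
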
/src/models/pluto/pluto_model.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
7 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper
8 | from nuplan.planning.training.preprocessing.target_builders.ego_trajectory_target_builder import (
9 | EgoTrajectoryTargetBuilder,
10 | )
11 |
12 | from src.feature_builders.pluto_feature_builder import PlutoFeatureBuilder
13 |
14 | from .layers.fourier_embedding import FourierEmbedding
15 | from .layers.transformer import TransformerEncoderLayer
16 | from .modules.agent_encoder import AgentEncoder
17 | from .modules.agent_predictor import AgentPredictor
18 | from .modules.map_encoder import MapEncoder
19 | from .modules.static_objects_encoder import StaticObjectsEncoder
20 | from .modules.planning_decoder import PlanningDecoder
21 | from .layers.mlp_layer import MLPLayer
22 |
23 | # placeholder sampling spec: required by the nuplan interface, not used by the model
24 | trajectory_sampling = TrajectorySampling(num_poses=8, time_horizon=8, interval_length=1)
25 |
26 |
27 | class PlanningModel(TorchModuleWrapper):
28 | def __init__(
29 | self,
30 | dim=128,
31 | state_channel=6,
32 | polygon_channel=6,
33 | history_channel=9,
34 | history_steps=21,
35 | future_steps=80,
36 | encoder_depth=4,
37 | decoder_depth=4,
38 | drop_path=0.2,
39 | dropout=0.1,
40 | num_heads=8,
41 | num_modes=6,
42 | use_ego_history=False,
43 | state_attn_encoder=True,
44 | state_dropout=0.75,
45 | use_hidden_proj=False,
46 | cat_x=False,
47 | ref_free_traj=False,
48 | feature_builder: PlutoFeatureBuilder = PlutoFeatureBuilder(),
49 | ) -> None:
50 | super().__init__(
51 | feature_builders=[feature_builder],
52 | target_builders=[EgoTrajectoryTargetBuilder(trajectory_sampling)],
53 | future_trajectory_sampling=trajectory_sampling,
54 | )
55 |
56 | self.dim = dim
57 | self.history_steps = history_steps
58 | self.future_steps = future_steps
59 | self.use_hidden_proj = use_hidden_proj
60 | self.num_modes = num_modes
61 | self.radius = feature_builder.radius
62 | self.ref_free_traj = ref_free_traj
63 |
64 | self.pos_emb = FourierEmbedding(3, dim, 64)
65 |
66 | self.agent_encoder = AgentEncoder(
67 | state_channel=state_channel,
68 | history_channel=history_channel,
69 | dim=dim,
70 | hist_steps=history_steps,
71 | drop_path=drop_path,
72 | use_ego_history=use_ego_history,
73 | state_attn_encoder=state_attn_encoder,
74 | state_dropout=state_dropout,
75 | )
76 |
77 | self.map_encoder = MapEncoder(
78 | dim=dim,
79 | polygon_channel=polygon_channel,
80 | use_lane_boundary=True,
81 | )
82 |
83 | self.static_objects_encoder = StaticObjectsEncoder(dim=dim)
84 |
85 | self.encoder_blocks = nn.ModuleList(
86 | TransformerEncoderLayer(dim=dim, num_heads=num_heads, drop_path=dp)
87 | for dp in [x.item() for x in torch.linspace(0, drop_path, encoder_depth)]
88 | )
89 | self.norm = nn.LayerNorm(dim)
90 |
91 | self.agent_predictor = AgentPredictor(dim=dim, future_steps=future_steps)
92 | self.planning_decoder = PlanningDecoder(
93 | num_mode=num_modes,
94 | decoder_depth=decoder_depth,
95 | dim=dim,
96 | num_heads=num_heads,
97 | mlp_ratio=4,
98 | dropout=dropout,
99 | cat_x=cat_x,
100 | future_steps=future_steps,
101 | )
102 |
103 | if use_hidden_proj:
104 | self.hidden_proj = nn.Sequential(
105 | nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)
106 | )
107 |
108 | if self.ref_free_traj:
109 | self.ref_free_decoder = MLPLayer(dim, 2 * dim, future_steps * 4)
110 |
111 | self.apply(self._init_weights)
112 |
113 | def _init_weights(self, m):
114 | if isinstance(m, nn.Linear):
115 | torch.nn.init.xavier_uniform_(m.weight)
116 | if isinstance(m, nn.Linear) and m.bias is not None:
117 | nn.init.constant_(m.bias, 0)
118 | elif isinstance(m, nn.LayerNorm):
119 | nn.init.constant_(m.bias, 0)
120 | nn.init.constant_(m.weight, 1.0)
121 | elif isinstance(m, nn.BatchNorm1d):
122 | nn.init.ones_(m.weight)
123 | nn.init.zeros_(m.bias)
124 | elif isinstance(m, nn.Embedding):
125 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
126 |
127 | def forward(self, data):
128 | agent_pos = data["agent"]["position"][:, :, self.history_steps - 1]
129 | agent_heading = data["agent"]["heading"][:, :, self.history_steps - 1]
130 | agent_mask = data["agent"]["valid_mask"][:, :, : self.history_steps]
131 | polygon_center = data["map"]["polygon_center"]
132 | polygon_mask = data["map"]["valid_mask"]
133 |
134 | bs, A = agent_pos.shape[0:2]
135 |
136 | position = torch.cat([agent_pos, polygon_center[..., :2]], dim=1)
137 | angle = torch.cat([agent_heading, polygon_center[..., 2]], dim=1)
138 | angle = (angle + math.pi) % (2 * math.pi) - math.pi
139 | pos = torch.cat([position, angle.unsqueeze(-1)], dim=-1)
140 |
141 | agent_key_padding = ~(agent_mask.any(-1))
142 | polygon_key_padding = ~(polygon_mask.any(-1))
143 | key_padding_mask = torch.cat([agent_key_padding, polygon_key_padding], dim=-1)
144 |
145 | x_agent = self.agent_encoder(data)
146 | x_polygon = self.map_encoder(data)
147 | x_static, static_pos, static_key_padding = self.static_objects_encoder(data)
148 |
149 | x = torch.cat([x_agent, x_polygon, x_static], dim=1)
150 |
151 | pos = torch.cat([pos, static_pos], dim=1)
152 | pos_embed = self.pos_emb(pos)
153 |
154 | key_padding_mask = torch.cat([key_padding_mask, static_key_padding], dim=-1)
155 | x = x + pos_embed
156 |
157 | for blk in self.encoder_blocks:
158 | x = blk(x, key_padding_mask=key_padding_mask, return_attn_weights=False)
159 | x = self.norm(x)
160 |
161 | prediction = self.agent_predictor(x[:, 1:A])
162 |
163 | ref_line_available = data["reference_line"]["position"].shape[1] > 0
164 |
165 | if ref_line_available:
166 | trajectory, probability = self.planning_decoder(
167 | data, {"enc_emb": x, "enc_key_padding_mask": key_padding_mask}
168 | )
169 | else:
170 | trajectory, probability = None, None
171 |
172 | out = {
173 | "trajectory": trajectory,
174 | "probability": probability, # (bs, R, M)
175 | "prediction": prediction, # (bs, A-1, T, 2)
176 | }
177 |
178 | if self.use_hidden_proj:
179 | out["hidden"] = self.hidden_proj(x[:, 0])
180 |
181 | if self.ref_free_traj:
182 | ref_free_traj = self.ref_free_decoder(x[:, 0]).reshape(
183 | bs, self.future_steps, 4
184 | )
185 | out["ref_free_trajectory"] = ref_free_traj
186 |
187 | if not self.training:
188 | if self.ref_free_traj:
189 | ref_free_traj_angle = torch.arctan2(
190 | ref_free_traj[..., 3], ref_free_traj[..., 2]
191 | )
192 | ref_free_traj = torch.cat(
193 | [ref_free_traj[..., :2], ref_free_traj_angle.unsqueeze(-1)], dim=-1
194 | )
195 | out["output_ref_free_trajectory"] = ref_free_traj
196 |
197 | output_prediction = torch.cat(
198 | [
199 | prediction[..., :2] + agent_pos[:, 1:A, None],
200 | torch.atan2(prediction[..., 3], prediction[..., 2]).unsqueeze(-1)
201 | + agent_heading[:, 1:A, None, None],
202 | prediction[..., 4:6],
203 | ],
204 | dim=-1,
205 | )
206 | out["output_prediction"] = output_prediction
207 |
208 | if trajectory is not None:
209 | r_padding_mask = ~data["reference_line"]["valid_mask"].any(-1)
210 | probability.masked_fill_(r_padding_mask.unsqueeze(-1), -1e6)
211 |
212 | angle = torch.atan2(trajectory[..., 3], trajectory[..., 2])
213 | out_trajectory = torch.cat(
214 | [trajectory[..., :2], angle.unsqueeze(-1)], dim=-1
215 | )
216 |
217 | bs, R, M, T, _ = out_trajectory.shape
218 | flattened_probability = probability.reshape(bs, R * M)
219 | best_trajectory = out_trajectory.reshape(bs, R * M, T, -1)[
220 | torch.arange(bs), flattened_probability.argmax(-1)
221 | ]
222 |
223 | out["output_trajectory"] = best_trajectory
224 | out["candidate_trajectories"] = out_trajectory
225 | else:
226 |                 out["output_trajectory"] = out["output_ref_free_trajectory"]  # assumes ref_free_traj=True
227 | out["probability"] = torch.zeros(1, 0, 0)
228 | out["candidate_trajectories"] = torch.zeros(
229 | 1, 0, 0, self.future_steps, 3
230 | )
231 |
232 | return out
233 |
--------------------------------------------------------------------------------
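
At inference the model picks a single output trajectory by flattening the (reference line, mode) grid and taking the most probable candidate. That selection, isolated with illustrative shapes:

import torch

bs, R, M, T = 2, 3, 6, 80
probability = torch.randn(bs, R, M)
trajectories = torch.randn(bs, R, M, T, 3)    # (x, y, heading) per step

flat_prob = probability.reshape(bs, R * M)
best = trajectories.reshape(bs, R * M, T, 3)[torch.arange(bs), flat_prob.argmax(-1)]
assert best.shape == (bs, T, 3)
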
/src/optim/warmup_cos_lr.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class WarmupCosLR(_LRScheduler):
7 | def __init__(
8 | self, optimizer, min_lr, lr, warmup_epochs, epochs, last_epoch=-1, verbose=False
9 | ) -> None:
10 | self.min_lr = min_lr
11 | self.lr = lr
12 | self.epochs = epochs
13 | self.warmup_epochs = warmup_epochs
14 | super(WarmupCosLR, self).__init__(optimizer, last_epoch, verbose)
15 |
16 | def state_dict(self):
17 | """Returns the state of the scheduler as a :class:`dict`.
18 |
19 | It contains an entry for every variable in self.__dict__ which
20 | is not the optimizer.
21 | """
22 | return {
23 | key: value for key, value in self.__dict__.items() if key != "optimizer"
24 | }
25 |
26 | def load_state_dict(self, state_dict):
27 | """Loads the schedulers state.
28 |
29 | Args:
30 | state_dict (dict): scheduler state. Should be an object returned
31 | from a call to :meth:`state_dict`.
32 | """
33 | self.__dict__.update(state_dict)
34 |
35 | def get_init_lr(self):
36 | lr = self.lr / self.warmup_epochs
37 | return lr
38 |
39 | def get_lr(self):
40 | if self.last_epoch < self.warmup_epochs:
41 | lr = self.lr * (self.last_epoch + 1) / self.warmup_epochs
42 | else:
43 | lr = self.min_lr + 0.5 * (self.lr - self.min_lr) * (
44 | 1
45 | + math.cos(
46 | math.pi
47 | * (self.last_epoch - self.warmup_epochs)
48 | / (self.epochs - self.warmup_epochs)
49 | )
50 | )
51 | if "lr_scale" in self.optimizer.param_groups[0]:
52 | return [lr * group["lr_scale"] for group in self.optimizer.param_groups]
53 |
54 | return [lr for _ in self.optimizer.param_groups]
55 |
--------------------------------------------------------------------------------
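
A usage sketch for WarmupCosLR with epoch-based stepping (the optimizer and hyperparameters are illustrative, not taken from the training configs): the learning rate ramps linearly up to lr over the warmup epochs, then follows a cosine decay toward min_lr.

import torch
from src.optim.warmup_cos_lr import WarmupCosLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosLR(optimizer, min_lr=1e-6, lr=1e-3, warmup_epochs=3, epochs=25)

for epoch in range(25):
    ...  # train one epoch
    scheduler.step()  # warmup for 3 epochs, then cosine decay toward min_lr
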
/src/planners/ml_planner_utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Deque, List, Tuple
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import numpy.typing as npt
7 | import torch
8 | from nuplan.common.actor_state.ego_state import EgoState
9 | from nuplan.common.actor_state.state_representation import StateSE2
10 | from nuplan.planning.simulation.planner.ml_planner.transform_utils import (
11 | _get_fixed_timesteps,
12 | _get_velocity_and_acceleration,
13 | _se2_vel_acc_to_ego_state,
14 | )
15 |
16 |
17 | def normalize_angle(angle):
18 | return (angle + np.pi) % (2 * np.pi) - np.pi
19 |
20 |
21 | def global_trajectory_to_states(
22 | global_trajectory: npt.NDArray[np.float32],
23 | ego_history: Deque[EgoState],
24 | future_horizon: float,
25 | step_interval: float,
26 | include_ego_state: bool = True,
27 | ):
28 | ego_state = ego_history[-1]
29 | timesteps = _get_fixed_timesteps(ego_state, future_horizon, step_interval)
30 | global_states = [StateSE2.deserialize(pose) for pose in global_trajectory]
31 |
32 | velocities, accelerations = _get_velocity_and_acceleration(
33 | global_states, ego_history, timesteps
34 | )
35 | agent_states = [
36 | _se2_vel_acc_to_ego_state(
37 | state,
38 | velocity,
39 | acceleration,
40 | timestep,
41 | ego_state.car_footprint.vehicle_parameters,
42 | )
43 | for state, velocity, acceleration, timestep in zip(
44 | global_states, velocities, accelerations, timesteps
45 | )
46 | ]
47 |
48 | if include_ego_state:
49 | agent_states.insert(0, ego_state)
50 | else:
51 | init_state = deepcopy(agent_states[0])
52 | init_state._time_point = ego_state.time_point
53 | agent_states.insert(0, init_state)
54 |
55 | return agent_states
56 |
57 |
58 | def load_checkpoint(checkpoint: str):
59 | ckpt = torch.load(checkpoint, map_location=torch.device("cpu"))
60 | state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
61 | return state_dict
62 |
--------------------------------------------------------------------------------
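
load_checkpoint strips the "model." prefix that PyTorch Lightning adds when the LightningModule keeps the network in an attribute named model. A sketch of restoring planner weights (the path is hypothetical, and constructing PlanningModel with its default feature builder is assumed to work in your environment):

from src.models.pluto.pluto_model import PlanningModel
from src.planners.ml_planner_utils import load_checkpoint

model = PlanningModel()
model.load_state_dict(load_checkpoint("/path/to/pluto.ckpt"))  # hypothetical path
model.eval()
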
/src/post_processing/common/enum.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | from enum import IntEnum
6 |
7 | import numpy as np
8 | from nuplan.common.actor_state.ego_state import EgoState
9 |
10 |
11 | class StateIndex:
12 | _X = 0
13 | _Y = 1
14 | _HEADING = 2
15 | _VELOCITY_X = 3
16 | _VELOCITY_Y = 4
17 | _ACCELERATION_X = 5
18 | _ACCELERATION_Y = 6
19 | _STEERING_ANGLE = 7
20 | _STEERING_RATE = 8
21 | _ANGULAR_VELOCITY = 9
22 | _ANGULAR_ACCELERATION = 10
23 |
24 | @classmethod
25 | def size(cls):
26 | valid_attributes = [
27 | attribute
28 | for attribute in dir(cls)
29 | if attribute.startswith("_")
30 | and not attribute.startswith("__")
31 | and not callable(getattr(cls, attribute))
32 | ]
33 | return len(valid_attributes)
34 |
35 | @classmethod
36 | @property
37 | def X(cls):
38 | return cls._X
39 |
40 | @classmethod
41 | @property
42 | def Y(cls):
43 | return cls._Y
44 |
45 | @classmethod
46 | @property
47 | def HEADING(cls):
48 | return cls._HEADING
49 |
50 | @classmethod
51 | @property
52 | def VELOCITY_X(cls):
53 | return cls._VELOCITY_X
54 |
55 | @classmethod
56 | @property
57 | def VELOCITY_Y(cls):
58 | return cls._VELOCITY_Y
59 |
60 | @classmethod
61 | @property
62 | def ACCELERATION_X(cls):
63 | return cls._ACCELERATION_X
64 |
65 | @classmethod
66 | @property
67 | def ACCELERATION_Y(cls):
68 | return cls._ACCELERATION_Y
69 |
70 | @classmethod
71 | @property
72 | def STEERING_ANGLE(cls):
73 | return cls._STEERING_ANGLE
74 |
75 | @classmethod
76 | @property
77 | def STEERING_RATE(cls):
78 | return cls._STEERING_RATE
79 |
80 | @classmethod
81 | @property
82 | def ANGULAR_VELOCITY(cls):
83 | return cls._ANGULAR_VELOCITY
84 |
85 | @classmethod
86 | @property
87 | def ANGULAR_ACCELERATION(cls):
88 | return cls._ANGULAR_ACCELERATION
89 |
90 | @classmethod
91 | @property
92 | def POINT(cls):
93 | # assumes X, Y have subsequent indices
94 | return slice(cls._X, cls._Y + 1)
95 |
96 | @classmethod
97 | @property
98 | def STATE_SE2(cls):
99 | # assumes X, Y, HEADING have subsequent indices
100 | return slice(cls._X, cls._HEADING + 1)
101 |
102 | @classmethod
103 | @property
104 | def VELOCITY_2D(cls):
105 | # assumes velocity X, Y have subsequent indices
106 | return slice(cls._VELOCITY_X, cls._VELOCITY_Y + 1)
107 |
108 | @classmethod
109 | @property
110 | def ACCELERATION_2D(cls):
111 | # assumes acceleration X, Y have subsequent indices
112 | return slice(cls._ACCELERATION_X, cls._ACCELERATION_Y + 1)
113 |
114 |
115 | class DynamicStateIndex(IntEnum):
116 | ACCELERATION_X = 0
117 | STEERING_RATE = 1
118 |
119 |
120 | class BBCoordsIndex(IntEnum):
121 | FRONT_LEFT = 0
122 | REAR_LEFT = 1
123 | REAR_RIGHT = 2
124 | FRONT_RIGHT = 3
125 | CENTER = 4
126 |
127 |
128 | class EgoAreaIndex(IntEnum):
129 | MULTIPLE_LANES = 0
130 | NON_DRIVABLE_AREA = 1
131 | ONCOMING_TRAFFIC = 2
132 |
133 |
134 | class MultiMetricIndex(IntEnum):
135 | NO_COLLISION = 0
136 | DRIVABLE_AREA = 1
137 | DRIVING_DIRECTION = 2
138 |
139 |
140 | class WeightedMetricIndex(IntEnum):
141 | PROGRESS = 0
142 | SPEED_LIMIT = 1
143 | COMFORTABLE = 2
144 | TTC = 3
145 |
146 |
147 | class CollisionType(IntEnum):
148 | """Enum for the types of collisions of interest."""
149 |
150 | STOPPED_EGO_COLLISION = 0
151 | STOPPED_TRACK_COLLISION = 1
152 | ACTIVE_FRONT_COLLISION = 2
153 | ACTIVE_REAR_COLLISION = 3
154 | ACTIVE_LATERAL_COLLISION = 4
155 |
156 |
157 | def ego_state_to_state_array(ego_state: EgoState):
158 | """
159 | Converts an ego state into an array representation (drops time-stamps and vehicle parameters)
160 | :param ego_state: ego state class
161 | :return: array containing ego state values
162 | """
163 | state_array = np.zeros(StateIndex.size(), dtype=np.float64)
164 |
165 | state_array[StateIndex.STATE_SE2] = ego_state.rear_axle.serialize()
166 | state_array[
167 | StateIndex.VELOCITY_2D
168 | ] = ego_state.dynamic_car_state.rear_axle_velocity_2d.array
169 | state_array[
170 | StateIndex.ACCELERATION_2D
171 | ] = ego_state.dynamic_car_state.rear_axle_acceleration_2d.array
172 |
173 | state_array[StateIndex.STEERING_ANGLE] = ego_state.tire_steering_angle
174 | state_array[
175 | StateIndex.STEERING_RATE
176 | ] = ego_state.dynamic_car_state.tire_steering_rate
177 |
178 | state_array[
179 | StateIndex.ANGULAR_VELOCITY
180 | ] = ego_state.dynamic_car_state.angular_velocity
181 | state_array[
182 | StateIndex.ANGULAR_ACCELERATION
183 | ] = ego_state.dynamic_car_state.angular_acceleration
184 |
185 | return state_array
186 |
--------------------------------------------------------------------------------
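
StateIndex stacks @classmethod on @property. That chaining works on CPython 3.9-3.12 but was deprecated in 3.11 and removed in 3.13, so this class assumes an older interpreter. A small usage example, mirroring ego_state_to_state_array:

import numpy as np
from src.post_processing.common.enum import StateIndex

state = np.zeros(StateIndex.size(), dtype=np.float64)
state[StateIndex.VELOCITY_2D] = [5.0, 0.0]    # vx, vy through the 2D slice
speed = np.hypot(state[StateIndex.VELOCITY_X], state[StateIndex.VELOCITY_Y])
assert speed == 5.0
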
/src/post_processing/common/geometry.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | from typing import Dict, Any
6 |
7 | import numpy as np
8 | from nuplan.common.actor_state.state_representation import StateSE2
9 | from shapely import Polygon, LineString
10 |
11 | from .enum import CollisionType, StateIndex
12 | from nuplan.planning.simulation.observation.idm.utils import is_agent_behind
13 |
14 |
15 | def compute_agents_vertices(
16 | center: np.ndarray,
17 | angle: np.ndarray,
18 | shape: np.ndarray,
19 | ) -> np.ndarray:
20 | """
21 | Args:
22 |         center: (N, T, 2)
23 |         angle: (N, T)
24 |         shape: (N, 2) or (N, T, 2) [width, length]
25 | Returns:
26 | 4 corners of oriented box (FL, RL, RR, FR)
27 | vertices: (N, T, 4, 2)
28 | """
29 | # Extracting dimensions
30 | N, T = center.shape[0], center.shape[1]
31 |
32 | # Reshaping the arrays for calculations
33 | center = center.reshape(N * T, 2)
34 | angle = angle.reshape(N * T)
35 |
36 | if shape.ndim == 2:
37 | shape = (shape / 2).repeat(T, axis=0)
38 | else:
39 | shape = (shape / 2).reshape(N * T, 2)
40 |
41 | # Calculating half width and half_l
42 | half_w = shape[:, 0]
43 | half_l = shape[:, 1]
44 |
45 | # Calculating cos and sin of angles
46 | cos_angle = np.cos(angle)[:, None]
47 | sin_angle = np.sin(angle)[:, None]
48 | rot_mat = np.stack([cos_angle, sin_angle, -sin_angle, cos_angle], axis=-1).reshape(
49 | N * T, 2, 2
50 | )
51 |
52 | offset_width = np.stack([half_w, half_w, -half_w, -half_w], axis=-1)
53 | offset_length = np.stack([half_l, -half_l, -half_l, half_l], axis=-1)
54 |
55 | vertices = np.stack([offset_length, offset_width], axis=-1)
56 | vertices = np.matmul(vertices, rot_mat) + center[:, None]
57 |
58 | # Calculating vertices
59 | vertices = vertices.reshape(N, T, 4, 2)
60 |
61 | return vertices
62 |
63 |
64 | def ego_rear_to_center(rear_xy, heading, rear_to_center=1.461):
65 | direction = np.stack([np.cos(heading), np.sin(heading)], axis=-1)
66 | center = rear_xy + direction * rear_to_center
67 | return center
68 |
69 |
70 | def get_sub_polygon(polygon: Polygon, ratio=0.4):
71 | vertices = np.array(polygon.exterior.coords)
72 | return Polygon(
73 | [
74 | vertices[0],
75 | vertices[0] * (1 - ratio) + vertices[1] * ratio,
76 | vertices[3] * (1 - ratio) + vertices[2] * ratio,
77 | vertices[3],
78 | ]
79 | )
80 |
81 |
82 | def get_collision_type(
83 | state: np.ndarray,
84 | ego_polygon: Polygon,
85 | object_info: Dict[str, Any],
86 | stopped_speed_threshold: float = 5e-02,
87 | ) -> CollisionType:
88 | """
89 | Classify collision between ego and the track.
90 |     :param state: ego state array at the current timestamp; ego_polygon: ego footprint
91 |     :param object_info: dict with the track's "pose", "velocity" and "polygon"
92 | :param stopped_speed_threshold: Threshold for 0 speed due to noise.
93 | :return Collision type.
94 | """
95 |
96 | ego_speed = np.hypot(state[StateIndex.VELOCITY_X], state[StateIndex.VELOCITY_Y])
97 |
98 | is_ego_stopped = float(ego_speed) <= stopped_speed_threshold
99 |
100 | object_pos: np.ndarray = object_info["pose"]
101 | object_velocity: np.ndarray = object_info["velocity"]
102 | object_polygon: Polygon = object_info["polygon"]
103 |
104 | tracked_object_center = StateSE2(*object_pos)
105 |
106 | ego_rear_axle_pose: StateSE2 = StateSE2(*state[StateIndex.STATE_SE2])
107 |
108 | # Collisions at (close-to) zero ego speed
109 | if is_ego_stopped:
110 | collision_type = CollisionType.STOPPED_EGO_COLLISION
111 |
112 | # Collisions at (close-to) zero track speed
113 | elif np.linalg.norm(object_velocity) <= stopped_speed_threshold:
114 | collision_type = CollisionType.STOPPED_TRACK_COLLISION
115 |
116 | # Rear collision when both ego and track are not stopped
117 | elif is_agent_behind(ego_rear_axle_pose, tracked_object_center):
118 | collision_type = CollisionType.ACTIVE_REAR_COLLISION
119 |
120 | # Front bumper collision when both ego and track are not stopped
121 | # elif get_sub_polygon(ego_polygon).intersects(object_polygon):
122 | # collision_type = CollisionType.ACTIVE_FRONT_COLLISION
123 | elif LineString(
124 | [
125 | ego_polygon.exterior.coords[0],
126 | ego_polygon.exterior.coords[3],
127 | ]
128 | ).intersects(object_polygon):
129 | collision_type = CollisionType.ACTIVE_FRONT_COLLISION
130 |
131 | # Lateral collision when both ego and track are not stopped
132 | else:
133 | collision_type = CollisionType.ACTIVE_LATERAL_COLLISION
134 |
135 | return collision_type
136 |
--------------------------------------------------------------------------------
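
A worked example of compute_agents_vertices for a single 2 m wide, 4 m long box at the origin with zero heading (values follow from the offset construction above):

import numpy as np
from src.post_processing.common.geometry import compute_agents_vertices

center = np.zeros((1, 1, 2))       # (N=1, T=1, 2)
angle = np.zeros((1, 1))           # (N=1, T=1)
shape = np.array([[2.0, 4.0]])     # (N=1, 2) [width, length]
verts = compute_agents_vertices(center, angle, shape)
# verts[0, 0] -> [[2, 1], [-2, 1], [-2, -1], [2, -1]]  i.e. FL, RL, RR, FR
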
/src/post_processing/emergency_brake.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | from typing import Optional
6 |
7 | import numpy as np
8 | import numpy.typing as npt
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.actor_state.state_representation import (
11 | StateSE2,
12 | StateVector2D,
13 | TimePoint,
14 | )
15 | from nuplan.planning.simulation.trajectory.interpolated_trajectory import (
16 | InterpolatedTrajectory,
17 | )
18 | from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
19 |
20 | from .forward_simulation.forward_simulator import ForwardSimulator
21 | from .common.enum import StateIndex
22 |
23 |
24 | class EmergencyBrake:
25 | def __init__(
26 | self,
27 | trajectory_sampling: TrajectorySampling = TrajectorySampling(
28 | num_poses=80, interval_length=0.1
29 | ),
30 | time_to_infraction_threshold: float = 2.0,
31 | max_ego_speed: float = 8.0,
32 | max_long_accel: float = 2.40,
33 | min_long_accel: float = -2.40,
34 | emergency_decel: float = -4.05,
35 | ):
36 | # trajectory parameters
37 | self._trajectory_sampling = trajectory_sampling
38 |
39 | # braking parameters
40 | self._max_ego_speed: float = max_ego_speed # [m/s]
41 | self._max_long_accel: float = max_long_accel # [m/s^2]
42 | self._min_long_accel: float = min_long_accel # [m/s^2]
43 | self._emergency_decel: float = emergency_decel # [m/s^2]
44 |
45 | # braking condition parameters
46 | self._time_to_infraction_threshold: float = time_to_infraction_threshold
47 |
48 | def brake_if_emergency(
49 | self,
50 | ego_state: EgoState,
51 | time_to_at_fault_collision: float,
52 | ego_trajectory: npt.NDArray[np.float64],
53 | ) -> Optional[InterpolatedTrajectory]:
54 | trajectory = None
55 | ego_speed: float = ego_state.dynamic_car_state.speed
56 |
57 | time_to_infraction = time_to_at_fault_collision
58 | min_brake_time = max(ego_speed / abs(self._min_long_accel) + 0.5, 3.0)
59 |
60 | if time_to_infraction <= min_brake_time and ego_speed <= self._max_ego_speed:
61 | print("Emergency Brake")
62 | min_reaction_time = ego_speed / abs(self._emergency_decel) + 0.5
63 | is_soft_brake_possible = time_to_infraction > min_reaction_time
64 | trajectory = self._generate_ebrake_trajectory(
65 | ego_trajectory, ego_state, soft_brake=is_soft_brake_possible
66 | )
67 |
68 | return trajectory
69 |
70 | def _generate_ebrake_trajectory(
71 | self, origin_trajectory: np.ndarray, ego_state: EgoState, soft_brake=False
72 | ):
73 | simulator = ForwardSimulator(
74 | dt=self._trajectory_sampling.interval_length,
75 | num_frames=self._trajectory_sampling.num_poses,
76 | estop=True,
77 | soft_brake=soft_brake,
78 | )
79 | rollout = simulator.forward(origin_trajectory[None, ...], ego_state)[0]
80 |
81 | ego_states, current_time_point = [], ego_state.time_point
82 | delta_t = TimePoint(int(self._trajectory_sampling.interval_length * 1e6))
83 |
84 | for state in rollout:
85 | ego_states.append(
86 | EgoState.build_from_rear_axle(
87 | rear_axle_pose=StateSE2(*state[:3]),
88 | rear_axle_velocity_2d=StateVector2D(*state[StateIndex.VELOCITY_2D]),
89 | rear_axle_acceleration_2d=StateVector2D(
90 | *state[StateIndex.ACCELERATION_2D]
91 | ),
92 | tire_steering_angle=state[StateIndex.STEERING_ANGLE],
93 | time_point=current_time_point,
94 | vehicle_parameters=ego_state.car_footprint.vehicle_parameters,
95 | )
96 | )
97 | current_time_point += delta_t
98 |
99 | return InterpolatedTrajectory(ego_states)
100 |
--------------------------------------------------------------------------------
/src/post_processing/forward_simulation/batch_kinematic_bicycle.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | import copy
6 |
7 | import numpy as np
8 | import numpy.typing as npt
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.actor_state.state_representation import TimePoint
11 | from nuplan.common.actor_state.vehicle_parameters import (
12 | VehicleParameters,
13 | get_pacifica_parameters,
14 | )
15 | from nuplan.common.geometry.compute import principal_value
16 | from ..common.enum import DynamicStateIndex, StateIndex
17 |
18 |
19 | def forward_integrate(
20 | init: npt.NDArray[np.float64],
21 | delta: npt.NDArray[np.float64],
22 | sampling_time: TimePoint,
23 | ) -> npt.NDArray[np.float64]:
24 | """
25 |     Performs a simple Euler integration.
26 | :param init: Initial state
27 | :param delta: The rate of change of the state.
28 | :param sampling_time: The time duration to propagate for.
29 | :return: The result of integration
30 | """
31 | return init + delta * sampling_time.time_s
32 |
33 |
34 | class BatchKinematicBicycleModel:
35 | """
36 | A batch-wise operating class describing the kinematic motion model where the rear axle is the point of reference.
37 | """
38 |
39 | def __init__(
40 | self,
41 | vehicle: VehicleParameters = get_pacifica_parameters(),
42 | max_steering_angle: float = np.pi / 3,
43 | accel_time_constant: float = 0.2,
44 | steering_angle_time_constant: float = 0.05,
45 | ):
46 | """
47 | Construct BatchKinematicBicycleModel.
48 | :param vehicle: Vehicle parameters.
49 | :param max_steering_angle: [rad] Maximum absolute value steering angle allowed by model.
50 | :param accel_time_constant: low pass filter time constant for acceleration in s
51 | :param steering_angle_time_constant: low pass filter time constant for steering angle in s
52 | """
53 | self._vehicle = vehicle
54 | self._max_steering_angle = max_steering_angle
55 | self._accel_time_constant = accel_time_constant
56 | self._steering_angle_time_constant = steering_angle_time_constant
57 |
58 | def get_state_dot(self, states: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
59 | """
60 | Calculates the changing rate of state array representation.
61 | :param states: array describing the state of the ego-vehicle
62 | :return: change rate across several state values
63 | """
64 | state_dots = np.zeros(states.shape, dtype=np.float64)
65 |
66 | longitudinal_speeds = states[:, StateIndex.VELOCITY_X]
67 |
68 | state_dots[:, StateIndex.X] = longitudinal_speeds * np.cos(
69 | states[:, StateIndex.HEADING]
70 | )
71 | state_dots[:, StateIndex.Y] = longitudinal_speeds * np.sin(
72 | states[:, StateIndex.HEADING]
73 | )
74 | state_dots[:, StateIndex.HEADING] = (
75 | longitudinal_speeds
76 | * np.tan(states[:, StateIndex.STEERING_ANGLE])
77 | / self._vehicle.wheel_base
78 | )
79 |
80 | state_dots[:, StateIndex.VELOCITY_2D] = states[:, StateIndex.ACCELERATION_2D]
81 | state_dots[:, StateIndex.ACCELERATION_2D] = 0.0
82 |
83 | state_dots[:, StateIndex.STEERING_ANGLE] = states[:, StateIndex.STEERING_RATE]
84 |
85 | return state_dots
86 |
87 | def _update_commands(
88 | self,
89 | states: npt.NDArray[np.float64],
90 | command_states: npt.NDArray[np.float64],
91 | sampling_time: TimePoint,
92 |     ) -> npt.NDArray[np.float64]:
93 |         """
94 |         Applies a first-order control delay (low-pass filter) to the acceleration and steering commands.
95 |
96 |         :param states: ego state array representation
97 |         :param command_states: desired dynamic states (acceleration, steering rate) for propagation
98 |         :param sampling_time: the time duration to propagate for
99 |         :return: propagating state array including the updated dynamic state
100 | """
101 |
102 | propagating_state: npt.NDArray[np.float64] = copy.deepcopy(states)
103 |
104 | dt_control = sampling_time.time_s
105 |
106 | accel = states[:, StateIndex.ACCELERATION_X]
107 | steering_angle = states[:, StateIndex.STEERING_ANGLE]
108 |
109 | ideal_accel_x = command_states[:, DynamicStateIndex.ACCELERATION_X]
110 | ideal_steering_angle = (
111 | dt_control * command_states[:, DynamicStateIndex.STEERING_RATE]
112 | + steering_angle
113 | )
114 |
115 | updated_accel_x = (
116 | dt_control
117 | / (dt_control + self._accel_time_constant)
118 | * (ideal_accel_x - accel)
119 | + accel
120 | )
121 | updated_steering_angle = (
122 | dt_control
123 | / (dt_control + self._steering_angle_time_constant)
124 | * (ideal_steering_angle - steering_angle)
125 | + steering_angle
126 | )
127 | updated_steering_rate = (updated_steering_angle - steering_angle) / dt_control
128 |
129 | propagating_state[:, StateIndex.ACCELERATION_X] = updated_accel_x
130 | propagating_state[:, StateIndex.ACCELERATION_Y] = 0.0
131 | propagating_state[:, StateIndex.STEERING_RATE] = updated_steering_rate
132 |
133 | return propagating_state
134 |
135 | def propagate_state(
136 | self,
137 | states: npt.NDArray[np.float64],
138 | command_states: npt.NDArray[np.float64],
139 | sampling_time: TimePoint,
140 | ) -> npt.NDArray[np.float64]:
141 | """
142 | Propagates ego state array forward with motion model.
143 | :param states: state array representation of the ego-vehicle
144 | :param command_states: command array representation of controller
145 | :param sampling_time: time to propagate [s]
146 |         :return: updated state array representation of the ego-vehicle
147 | """
148 |
149 | assert len(states) == len(
150 | command_states
151 | ), "Batch size of states and command_states does not match!"
152 |
153 | propagating_state = self._update_commands(states, command_states, sampling_time)
154 | output_state = copy.deepcopy(states)
155 |
156 | # Compute state derivatives
157 | state_dot = self.get_state_dot(propagating_state)
158 |
159 | output_state[:, StateIndex.X] = forward_integrate(
160 | states[:, StateIndex.X], state_dot[:, StateIndex.X], sampling_time
161 | )
162 | output_state[:, StateIndex.Y] = forward_integrate(
163 | states[:, StateIndex.Y], state_dot[:, StateIndex.Y], sampling_time
164 | )
165 |
166 | output_state[:, StateIndex.HEADING] = principal_value(
167 | forward_integrate(
168 | states[:, StateIndex.HEADING],
169 | state_dot[:, StateIndex.HEADING],
170 | sampling_time,
171 | )
172 | )
173 |
174 | output_state[:, StateIndex.VELOCITY_X] = forward_integrate(
175 | states[:, StateIndex.VELOCITY_X],
176 | state_dot[:, StateIndex.VELOCITY_X],
177 | sampling_time,
178 | )
179 |
180 | # Lateral velocity is always zero in kinematic bicycle model
181 | output_state[:, StateIndex.VELOCITY_Y] = 0.0
182 |
183 | # Integrate steering angle and clip to bounds
184 | output_state[:, StateIndex.STEERING_ANGLE] = np.clip(
185 | forward_integrate(
186 | propagating_state[:, StateIndex.STEERING_ANGLE],
187 | state_dot[:, StateIndex.STEERING_ANGLE],
188 | sampling_time,
189 | ),
190 | -self._max_steering_angle,
191 | self._max_steering_angle,
192 | )
193 |
194 | output_state[:, StateIndex.ANGULAR_VELOCITY] = (
195 | output_state[:, StateIndex.VELOCITY_X]
196 | * np.tan(output_state[:, StateIndex.STEERING_ANGLE])
197 | / self._vehicle.wheel_base
198 | )
199 |
200 | output_state[:, StateIndex.ACCELERATION_2D] = state_dot[
201 | :, StateIndex.VELOCITY_2D
202 | ]
203 |
204 | output_state[:, StateIndex.ANGULAR_ACCELERATION] = (
205 | output_state[:, StateIndex.ANGULAR_VELOCITY]
206 | - states[:, StateIndex.ANGULAR_VELOCITY]
207 | ) / sampling_time.time_s
208 |
209 | output_state[:, StateIndex.STEERING_RATE] = state_dot[
210 | :, StateIndex.STEERING_ANGLE
211 | ]
212 |
213 | return output_state
214 |
--------------------------------------------------------------------------------
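
A sketch of one 0.1 s propagation step from rest under a 1 m/s^2 acceleration command (assumes the nuplan devkit is installed; note the first-order lag in _update_commands means the realized acceleration stays below the command):

import numpy as np
from nuplan.common.actor_state.state_representation import TimePoint
from src.post_processing.common.enum import DynamicStateIndex, StateIndex
from src.post_processing.forward_simulation.batch_kinematic_bicycle import (
    BatchKinematicBicycleModel,
)

model = BatchKinematicBicycleModel()
states = np.zeros((1, StateIndex.size()))            # batch of one, at rest
commands = np.zeros((1, 2))
commands[:, DynamicStateIndex.ACCELERATION_X] = 1.0  # commanded acceleration
next_states = model.propagate_state(states, commands, TimePoint(int(0.1 * 1e6)))
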
/src/post_processing/forward_simulation/forward_simulator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from nuplan.common.actor_state.ego_state import EgoState
3 | from nuplan.common.actor_state.state_representation import TimeDuration, TimePoint
4 | from nuplan.planning.simulation.simulation_time_controller.simulation_iteration import (
5 | SimulationIteration,
6 | )
7 |
8 | from .batch_kinematic_bicycle import BatchKinematicBicycleModel
9 | from .batch_lqr import BatchLQRTracker
10 | from ..common.enum import StateIndex, ego_state_to_state_array
11 |
12 |
13 | class ForwardSimulator:
14 | def __init__(
15 | self,
16 | dt: float = 0.1,
17 | num_frames: int = 40,
18 | estop: bool = False,
19 | soft_brake: bool = False,
20 | ) -> None:
21 | self.dt = dt
22 | self.interval = int(dt * 10)
23 | self.num_frames = num_frames
24 | self.motion_model = BatchKinematicBicycleModel()
25 | self.tracker = BatchLQRTracker(
26 | discretization_time=dt,
27 | tracking_horizon=int(1 / dt),
28 | estop=estop,
29 | soft_brake=soft_brake,
30 | )
31 |
32 | def forward(self, candidate_trajectories: np.ndarray, init_ego_state: EgoState):
33 | """
34 | Args:
35 | candidate_trajectories: (N, 80+1, S), sampled at 10 Hz
36 | """
37 | N = candidate_trajectories.shape[0]
38 | rollout_states = np.zeros(
39 | (N, self.num_frames + 1, StateIndex.size()), dtype=np.float64
40 | )
41 | rollout_states[:, 0] = ego_state_to_state_array(init_ego_state)
42 |
43 | t_now = init_ego_state.time_point
44 | delta_t = TimeDuration.from_s(self.dt)
45 |
46 | current_iteration = SimulationIteration(t_now, 0)
47 | next_iteration = SimulationIteration(t_now + delta_t, 1)
48 |
49 | self.tracker.update(candidate_trajectories[:, :: self.interval])
50 |
51 | for t in range(1, self.num_frames + 1):
52 | sampling_time: TimePoint = (
53 | next_iteration.time_point - current_iteration.time_point
54 | )
55 | command_states = self.tracker.track_trajectory(
56 | current_iteration,
57 | next_iteration,
58 | rollout_states[:, t - 1],
59 | )
60 |
61 | rollout_states[:, t] = self.motion_model.propagate_state(
62 | states=rollout_states[:, t - 1],
63 | command_states=command_states,
64 | sampling_time=sampling_time,
65 | )
66 |
67 | current_iteration = next_iteration
68 | next_iteration = SimulationIteration(
69 | current_iteration.time_point + delta_t, 1 + t
70 | )
71 |
72 | return rollout_states
73 |
--------------------------------------------------------------------------------
/src/post_processing/observation/world_from_prediction.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | from typing import Dict, List, Optional
6 |
7 | import numpy as np
8 | import shapely.creation
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.maps.abstract_map_objects import LaneGraphEdgeMapObject
11 | from nuplan.common.maps.maps_datatypes import (
12 | TrafficLightStatusData,
13 | TrafficLightStatusType,
14 | )
15 | from nuplan.planning.simulation.observation.observation_type import DetectionsTracks
16 |
17 | from src.scenario_manager.occupancy_map import OccupancyMap
18 |
19 | from ..common.geometry import compute_agents_vertices
20 |
21 |
22 | class WorldFromPrediction:
23 | def __init__(self, dt=0.1, num_frames=40, base_radius=50) -> None:
24 | self.dt = dt
25 | self.num_frames = num_frames
26 | self.interval = int(dt // 0.1)
27 |
28 |         # TODO: radius should be determined by ego velocity
29 | self.radius = max(base_radius * dt * num_frames / 4, base_radius)
30 |
31 | self.occupancy_map: Optional[List[OccupancyMap]] = None
32 | self.drivable_area: Optional[OccupancyMap] = None
33 | self.objects_info: Optional[Dict[str, np.ndarray]] = None
34 |
35 | self.red_light_prefix = "red_light"
36 |
37 | self.collided_tokens = None
38 | self._static_object_tokens = None
39 | self._agent_tokens = None
40 | self._ego_state = None
41 |
42 | def __getitem__(self, idx: int) -> OccupancyMap:
43 | assert 0 <= idx < len(self.occupancy_map), "index out of range"
44 | return self.occupancy_map[idx]
45 |
46 | def __len__(self) -> int:
47 | return len(self.occupancy_map)
48 |
49 | def update(
50 | self,
51 | ego_state: EgoState,
52 | detections: DetectionsTracks,
53 | traffic_light_data: List[TrafficLightStatusData],
54 | agents_info: Dict[str, np.ndarray],
55 | route_lane_dict: Dict[str, LaneGraphEdgeMapObject],
56 | ):
57 | self._ego_state = ego_state
58 | self.occupancy_map: List[OccupancyMap] = []
59 |
60 | tl_tokens, tl_polygons = self._get_route_red_traffic_lights(
61 | traffic_light_data, route_lane_dict
62 | )
63 | statics_tokens, statics_polygon = self._get_static_obstacles(
64 | ego_state, detections
65 | )
66 | agents_tokens, agents_vertices = self._get_dynamic_agents_from_prediction(
67 | agents_info
68 | )
69 | has_agents = len(agents_tokens) > 0
70 | agents_polygon = np.array([], dtype=np.object_)
71 |
72 | for i in range(self.num_frames + 1):
73 | if has_agents:
74 | agents_polygon = agents_vertices[:, i]
75 | agents_polygon = shapely.creation.polygons(agents_polygon)
76 |
77 | frame_tokens = statics_tokens + agents_tokens + tl_tokens
78 | frame_polygons = np.concatenate(
79 | [statics_polygon, agents_polygon, tl_polygons], axis=0
80 | )
81 | frame_occupancy_map = OccupancyMap(
82 | tokens=frame_tokens, geometries=frame_polygons
83 | )
84 | self.occupancy_map.append(frame_occupancy_map)
85 |
86 | # update initial ego collision status
87 | self.collided_tokens = []
88 | ego_polygon = ego_state.car_footprint.geometry
89 | intersect_tokens = self.occupancy_map[0].intersects(ego_polygon)
90 | for token in intersect_tokens:
91 | if token.startswith(self.red_light_prefix):
92 | if not ego_polygon.within(self.occupancy_map[0][token]):
93 | continue
94 | self.collided_tokens.append(token)
95 |
96 | self._static_object_tokens = set(statics_tokens)
97 | self._agent_tokens = set(agents_tokens)
98 |
99 | def _get_static_obstacles(self, ego_state: EgoState, detections: DetectionsTracks):
100 | self._static_objects = {}
101 | tokens, polygons = [], []
102 |
103 | for static_obstacle in detections.tracked_objects.get_static_objects():
104 | if (
105 | np.linalg.norm(static_obstacle.center.array - ego_state.center.array)
106 | > self.radius
107 | ):
108 | continue
109 | if len(self.drivable_area.intersects(static_obstacle.box.geometry)) > 0:
110 | tokens.append(static_obstacle.track_token)
111 | polygons.append(static_obstacle.box.geometry)
112 | self._static_objects[static_obstacle.track_token] = static_obstacle
113 |
114 | if len(tokens) == 0:
115 | polygons = np.array([], dtype=np.object_)
116 |
117 | return tokens, polygons
118 |
119 | def _get_dynamic_agents_from_prediction(
120 | self,
121 | agents_info: Dict[str, np.ndarray],
122 | ):
123 | tokens = agents_info["tokens"]
124 | shape = agents_info["shape"]
125 | category = agents_info["category"]
126 | velocity = agents_info["velocity"]
127 | predictions = agents_info["predictions"][:, :: self.interval]
128 |
129 | agents_vertices = compute_agents_vertices(
130 | center=predictions[..., :2], angle=predictions[..., 2], shape=shape
131 | )
132 |
133 | self._agents_info = {}
134 |
135 | for i in range(len(tokens)):
136 | self._agents_info[tokens[i]] = {
137 | "shape": shape[i],
138 | "velocity": velocity[i],
139 | "prediction": predictions[i],
140 | }
141 |
142 | return tokens, agents_vertices
143 |
144 | def _get_route_red_traffic_lights(
145 | self,
146 | traffic_light_data: List[TrafficLightStatusData],
147 | route_lane_dict: Dict[str, LaneGraphEdgeMapObject],
148 | ):
149 | tokens, polygons = [], []
150 |
151 | for data in traffic_light_data:
152 | if data.status != TrafficLightStatusType.RED:
153 | continue
154 | lane_connector_id = str(data.lane_connector_id)
155 | if lane_connector_id in route_lane_dict.keys():
156 | lane_connector = route_lane_dict[lane_connector_id]
157 | tokens.append(f"{self.red_light_prefix}_{lane_connector_id}")
158 | polygons.append(lane_connector.polygon)
159 |
160 | if len(tokens) == 0:
161 | polygons = np.array([], dtype=np.object_)
162 |
163 | return tokens, polygons
164 |
165 | def get_object_at_frame(self, token, frame_idx):
166 | if token in self._static_object_tokens:
167 | return {
168 | "is_agent": False,
169 | "pose": self._static_objects[token].center,
170 | "velocity": np.zeros(2),
171 | "polygon": self.occupancy_map[frame_idx][token],
172 | }
173 | else:
174 | return {
175 | "is_agent": True,
176 | "pose": self._agents_info[token]["prediction"][frame_idx],
177 | "velocity": self._agents_info[token]["velocity"][frame_idx],
178 | "polygon": self.occupancy_map[frame_idx][token],
179 | }
180 |
--------------------------------------------------------------------------------
/src/scenario_manager/cost_map_manager.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Dict, List, Set
3 |
4 | import cv2
5 | import numpy as np
6 | from nuplan.common.actor_state.state_representation import Point2D
7 | from nuplan.common.actor_state.static_object import StaticObject
8 | from nuplan.common.actor_state.tracked_objects import TrackedObjects
9 | from nuplan.common.maps.abstract_map import AbstractMap
10 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer
11 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
12 | from scipy import ndimage
13 | from shapely import Polygon
14 |
15 | DA = [SemanticMapLayer.LANE, SemanticMapLayer.LANE_CONNECTOR]
16 |
17 |
18 | class CostMapManager:
19 | def __init__(
20 | self,
21 | origin: np.ndarray,
22 | angle: float,
23 | map_api: AbstractMap,
24 | height: int = 500,
25 | width: int = 500,
26 | resolution: float = 0.2,
27 | ) -> None:
28 | self.map_api = map_api
29 | self.height = height
30 | self.width = width
31 | self.resolution = resolution
32 | self.resolution_hw = np.array([resolution, -resolution], dtype=np.float32)
33 | self.origin = origin
34 | self.angle = angle
35 | self.offset = np.array([height / 2, width / 2], dtype=np.float32)
36 | self.rot_mat = np.array(
37 | [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
38 | dtype=np.float64,
39 | )
40 |
41 | @classmethod
42 | def from_scenario(cls, scenario: AbstractScenario):
43 | ego_state = scenario.initial_ego_state
44 | origin = ego_state.rear_axle.point.array
45 | angle = ego_state.rear_axle.heading
46 |
47 | return cls(origin=origin, angle=angle, map_api=scenario.map_api)
48 |
49 | def build_cost_maps(
50 | self,
51 | static_objects: list[StaticObject],
52 | agents: Dict[str, np.ndarray] = None,
53 | agents_polygon: List[Polygon] = None,
54 | route_roadblock_ids: Set[str] = None,
55 | ):
56 | drivable_area_mask = np.zeros((self.height, self.width), dtype=np.uint8)
57 | speed_limit_mask = np.zeros((self.height, self.width), dtype=np.float32)
58 |
59 | radius = max(self.height, self.width) * self.resolution / 2
60 | da_objects_dict = self.map_api.get_proximal_map_objects(
61 | Point2D(*self.origin), radius, DA
62 | )
63 | da_objects = itertools.chain.from_iterable(da_objects_dict.values())
64 |
65 | for obj in da_objects:
66 | self.fill_polygon(drivable_area_mask, obj.polygon, value=1)
67 |
68 | speed_limit_mps = obj.speed_limit_mps if obj.speed_limit_mps else 50
69 | self.fill_polygon(speed_limit_mask, obj.polygon, value=speed_limit_mps)
70 |
71 |         for static_obj in static_objects:
72 |             if np.linalg.norm(static_obj.center.array - self.origin, axis=-1) > radius:
73 |                 continue
74 |             self.fill_convex_polygon(
75 |                 drivable_area_mask, static_obj.box.geometry, value=0
76 |             )
77 |
78 | if agents is not None:
79 | # parking vehicles as static obstacles
80 | position = agents["position"]
81 | valid_mask = agents["valid_mask"]
82 | for pos, mask, polygon in zip(position, valid_mask, agents_polygon):
83 | if mask.sum() < 50:
84 | continue
85 | pos = pos[mask]
86 | displacement = np.linalg.norm(pos[-1] - pos[0])
87 | if displacement < 1.0:
88 | self.fill_convex_polygon(drivable_area_mask, polygon, value=0)
89 |
90 | distance = ndimage.distance_transform_edt(drivable_area_mask)
91 | inv_distance = ndimage.distance_transform_edt(1 - drivable_area_mask)
92 | drivable_area_sdf = distance - inv_distance
93 | drivable_area_sdf *= self.resolution
94 |
95 | return {
96 |             "cost_maps": drivable_area_sdf[:, :, None].astype(np.float16),  # (H, W, C)
97 | }
98 |
99 | def global_to_pixel(self, coord: np.ndarray):
100 | coord = np.matmul(coord - self.origin, self.rot_mat)
101 | coord = coord / self.resolution_hw + self.offset
102 | return coord
103 |
104 | def fill_polygon(self, mask, polygon, value=1):
105 | polygon = self.global_to_pixel(np.stack(polygon.exterior.coords.xy, axis=1))
106 | cv2.fillPoly(mask, [np.round(polygon).astype(np.int32)], value)
107 |
108 | def fill_convex_polygon(self, mask, polygon, value=1):
109 | polygon = self.global_to_pixel(np.stack(polygon.exterior.coords.xy, axis=1))
110 | cv2.fillConvexPoly(mask, np.round(polygon).astype(np.int32), value)
111 |
112 | def fill_polyline(self, mask, polyline, value=1):
113 | polyline = self.global_to_pixel(polyline)
114 | cv2.polylines(
115 | mask,
116 | [np.round(polyline.reshape(-1, 1, 2)).astype(np.int32)],
117 | isClosed=False,
118 | color=value,
119 | thickness=1,
120 | )
121 |
--------------------------------------------------------------------------------
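
The cost map is a signed distance field: positive inside the drivable area, negative outside, scaled from pixels to meters by the resolution. The core construction from build_cost_maps, on a toy mask:

import numpy as np
from scipy import ndimage

drivable = np.zeros((8, 8), dtype=np.uint8)
drivable[2:6, 2:6] = 1                       # a drivable square
sdf = (
    ndimage.distance_transform_edt(drivable)
    - ndimage.distance_transform_edt(1 - drivable)
)
sdf_m = sdf * 0.2                            # pixel units -> meters at 0.2 m/px
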
/src/scenario_manager/occupancy_map.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 |
5 | from enum import Enum
6 | from typing import Any, Dict, List
7 |
8 | import numpy as np
9 | import numpy.typing as npt
10 | import shapely.vectorized
11 | from nuplan.planning.simulation.occupancy_map.abstract_occupancy_map import Geometry
12 | from shapely.strtree import STRtree
13 |
14 |
15 | class OccupancyType(Enum):
16 | DYNAMIC = 0, "dynamic"
17 | STATIC = 1, "static"
18 | RED_LIGHT = 2, "red_light"
19 |
20 |
21 | class OccupancyMap:
22 | def __init__(
23 | self,
24 | tokens: List[str],
25 | geometries: npt.NDArray[np.object_],
26 | types: List[Enum] = None,
27 | node_capacity: int = 10,
28 | attribute: Dict[str, Any] = None,
29 | ):
30 | self._tokens: List[str] = tokens
31 | self._types: List[Enum] = types
32 | self._token_to_idx: Dict[str, int] = {
33 | token: idx for idx, token in enumerate(tokens)
34 | }
35 |
36 | self._geometries = geometries
37 | self._attribute = attribute
38 | self._node_capacity = node_capacity
39 | self._str_tree = STRtree(self._geometries, node_capacity)
40 |
41 | def __getitem__(self, token) -> Geometry:
42 | """
43 | Retrieves geometry of token.
44 | :param token: geometry identifier
45 | :return: Geometry of token
46 | """
47 | return self._geometries[self._token_to_idx[token]]
48 |
49 | def __len__(self) -> int:
50 | """
51 | Number of geometries in the occupancy map
52 | :return: int
53 | """
54 | return len(self._tokens)
55 |
56 | def get_type(self, token: str) -> Enum:
57 | """
58 | Retrieves type of token.
59 | :param token: geometry identifier
60 | :return: type of token
61 | """
62 | return self._types[self._token_to_idx[token]]
63 |
64 | @property
65 | def tokens(self) -> List[str]:
66 | """
67 | Getter for track tokens in occupancy map
68 | :return: list of strings
69 | """
70 | return self._tokens
71 |
72 | @property
73 | def token_to_idx(self) -> Dict[str, int]:
74 | """
75 | Getter for track tokens in occupancy map
76 | :return: dictionary of tokens and indices
77 | """
78 | return self._token_to_idx
79 |
80 | def intersects(self, geometry: Geometry) -> List[str]:
81 | """
82 | Searches for intersecting geometries in the occupancy map
83 | :param geometry: geometries to query
84 | :return: list of tokens for intersecting geometries
85 | """
86 | indices = self.query(geometry, predicate="intersects")
87 | return [self._tokens[idx] for idx in indices]
88 |
89 | def get_subset_by_intersection(self, geometry: Geometry) -> "OccupancyMap":
90 | indices = self.query(geometry, predicate="intersects")
91 | polygons = [self._geometries[i] for i in indices]
92 |
93 |         return OccupancyMap([self._tokens[i] for i in indices], polygons)
94 |
95 | def get_subset_by_type(self, type):
96 | assert self._types is not None, "OccupancyMap: No types defined!"
97 | indices = [i for i, t in enumerate(self._types) if t == type]
98 | polygons = [self._geometries[i] for i in indices]
99 |
100 |         return OccupancyMap([self._tokens[i] for i in indices], polygons)
101 |
102 | def query(self, geometry: Geometry, predicate=None):
103 | """
104 |         Directly calls shapely's query function on the STR-tree
105 | :param geometry: geometries to query
106 | :param predicate: see shapely, defaults to None
107 | :return: query output
108 | """
109 | return self._str_tree.query(geometry, predicate=predicate)
110 |
111 | def points_in_polygons(
112 | self, points: npt.NDArray[np.float64]
113 | ) -> npt.NDArray[np.bool_]:
114 | """
115 |         Determines whether input-points are in polygons of the occupancy map
116 | :param points: input-points
117 | :return: boolean array of shape (polygons, input-points)
118 | """
119 | output = np.zeros((len(self._geometries), len(points)), dtype=bool)
120 | for i, polygon in enumerate(self._geometries):
121 | output[i] = shapely.vectorized.contains(polygon, points[:, 0], points[:, 1])
122 |
123 | return output
124 |
125 | def points_in_polygons_with_attribute(
126 | self, points: npt.NDArray[np.float64], attribute_name: str
127 | ):
128 | """
129 |         Determines whether input-points are in polygons of the occupancy map
130 |         :param points: input-points; attribute_name: key into the attribute dict
131 |         :return: boolean containment array and attribute array, each of shape (polygons, input-points)
132 | """
133 | output = np.zeros((len(self._geometries), len(points)), dtype=bool)
134 | attribute = np.zeros((len(self._geometries), len(points)))
135 | for i, polygon in enumerate(self._geometries):
136 | output[i] = shapely.vectorized.contains(polygon, points[:, 0], points[:, 1])
137 | attribute[i] = self._attribute[attribute_name][i]
138 |
139 | return output, attribute
140 |
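A minimal usage sketch for the class above (editor's example: the tokens and
boxes are made up; shapely >= 2.0 is assumed so that STRtree.query accepts a
predicate keyword):

    import numpy as np
    from shapely.geometry import Point, box

    tokens = ["veh_0", "veh_1"]
    geometries = [box(0, 0, 2, 2), box(10, 10, 12, 12)]
    occupancy = OccupancyMap(tokens, geometries)

    occupancy.intersects(Point(1, 1).buffer(0.5))        # -> ["veh_0"]
    len(occupancy)                                       # -> 2
    occupancy.points_in_polygons(
        np.array([[1.0, 1.0], [11.0, 11.0]])
    )                                                    # -> [[True, False], [False, True]]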
--------------------------------------------------------------------------------
/src/scenario_manager/scenario_manager.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from typing import List
4 |
5 | import numpy as np
6 | from nuplan.common.actor_state.ego_state import EgoState
7 | from nuplan.common.actor_state.state_representation import Point2D, StateSE2
8 | from nuplan.common.actor_state.tracked_objects import TrackedObjects
9 | from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType
10 | from nuplan.common.maps.maps_datatypes import (
11 | SemanticMapLayer,
12 | TrafficLightStatusData,
13 | TrafficLightStatusType,
14 | )
15 | from nuplan.planning.simulation.observation.idm.utils import (
16 | create_path_from_se2,
17 | path_to_linestring,
18 | )
19 | from nuplan.planning.simulation.path.utils import trim_path
20 | from shapely.geometry import Point, Polygon
21 | from shapely.geometry.base import CAP_STYLE
22 | from shapely.ops import unary_union
23 |
24 | from .occupancy_map import OccupancyMap, OccupancyType
25 | from .route_manager import RouteManager
26 |
27 | DRIVABLE_LAYERS = {
28 | SemanticMapLayer.ROADBLOCK,
29 | SemanticMapLayer.ROADBLOCK_CONNECTOR,
30 | SemanticMapLayer.CARPARK_AREA,
31 | }
32 |
33 | STATIC_OBJECT_TYPES = {
34 | TrackedObjectType.CZONE_SIGN,
35 | TrackedObjectType.BARRIER,
36 | TrackedObjectType.TRAFFIC_CONE,
37 | TrackedObjectType.GENERIC_OBJECT,
38 | }
39 |
40 | DYNAMIC_OBJECT_TYPES = {
41 | TrackedObjectType.VEHICLE,
42 | TrackedObjectType.BICYCLE,
43 | TrackedObjectType.PEDESTRIAN,
44 | }
45 |
46 |
47 | class ScenarioManager:
48 | def __init__(
49 | self,
50 | map_api,
51 | ego_state: EgoState,
52 | route_roadblocks_ids: List[str],
53 | radius=50,
54 | ) -> None:
55 | self._map_api = map_api
56 | self._radius = radius
57 | self._drivable_area_map: OccupancyMap = None # [lanes, lane connectors]
58 | self._obstacle_map: OccupancyMap = None # [agents, red lights, etc]
59 |
60 | self._ego_state = ego_state
61 | self._ego_path = None
62 | self._ego_path_linestring = None
63 | self._ego_trimmed_path = None
64 | self._ego_progress = None
65 |
66 | self._route_manager = RouteManager(map_api, route_roadblocks_ids, radius)
67 |
68 | @property
69 | def drivable_area_map(self):
70 | return self._drivable_area_map
71 |
72 | def get_route_roadblock_ids(self, process=True) -> List[str]:
73 | if not self._route_manager.initialized:
74 | self._route_manager.load_route(self._ego_state, process)
75 | return self._route_manager.route_roadblock_ids
76 |
77 | def get_route_lane_dicts(self):
78 | assert self._route_manager.initialized
79 | return self._route_manager._route_lane_dict
80 |
81 | def update_ego_state(self, ego_state: EgoState):
82 | self._ego_state = ego_state
83 |
84 | def update_drivable_area_map(self):
85 | """
86 |         Builds the occupancy map of the drivable area
87 |         around the current ego position (uses self._ego_state).
88 | """
89 |
90 | position: Point2D = self._ego_state.center.point
91 | drivable_area = self._map_api.get_proximal_map_objects(
92 | position, self._radius, DRIVABLE_LAYERS
93 | )
94 |
95 | drivable_polygons: List[Polygon] = []
96 | drivable_polygon_ids: List[str] = []
97 | lane_speed_limit = []
98 |
99 | for road in [SemanticMapLayer.ROADBLOCK, SemanticMapLayer.ROADBLOCK_CONNECTOR]:
100 | for roadblock in drivable_area[road]:
101 | for lane in roadblock.interior_edges:
102 | drivable_polygons.append(lane.polygon)
103 | drivable_polygon_ids.append(lane.id)
104 | speed_limit = lane.speed_limit_mps if lane.speed_limit_mps else 50
105 | lane_speed_limit.append(speed_limit)
106 |
107 | for carpark in drivable_area[SemanticMapLayer.CARPARK_AREA]:
108 | drivable_polygons.append(carpark.polygon)
109 | drivable_polygon_ids.append(carpark.id)
110 |             speed_limit = getattr(carpark, "speed_limit_mps", None) or 50
111 | lane_speed_limit.append(speed_limit)
112 |
113 | self._drivable_area_map = OccupancyMap(
114 | drivable_polygon_ids,
115 | drivable_polygons,
116 | attribute={"speed_limit": lane_speed_limit},
117 | )
118 | self._route_manager.update_drivable_area_map(self._drivable_area_map)
119 |
120 | def update_obstacle_map(
121 | self,
122 | detections: TrackedObjects,
123 | traffic_light_status: List[TrafficLightStatusData],
124 | ):
125 | """
126 |         Builds occupancy map of obstacles from tracked objects within radius
127 |         and red-light lane connectors along the route.
128 | """
129 | tokens = []
130 | types = []
131 | polygons = []
132 |
133 | for obj in detections.tracked_objects:
134 | if (
135 | np.linalg.norm(self._ego_state.center.array - obj.center.array)
136 | < self._radius
137 | ):
138 | obj_type = (
139 | OccupancyType.DYNAMIC
140 | if obj.tracked_object_type in DYNAMIC_OBJECT_TYPES
141 | else OccupancyType.STATIC
142 | )
143 | tokens.append(obj.track_token)
144 | types.append(obj_type)
145 | polygons.append(obj.box.geometry)
146 |
147 | for data in traffic_light_status:
148 | if (
149 | data.status == TrafficLightStatusType.RED
150 | and str(data.lane_connector_id) in self._route_manager.route_lane_ids
151 | ):
152 | tokens.append(data.lane_connector_id)
153 | types.append(OccupancyType.RED_LIGHT)
154 | polygons.append(
155 | self._map_api.get_map_object(
156 | str(data.lane_connector_id), SemanticMapLayer.LANE_CONNECTOR
157 | ).polygon
158 | )
159 |
160 | self._obstacle_map = OccupancyMap(tokens, polygons, types)
161 |
162 | def update_ego_path(self, length=50):
163 | ego_path: List[StateSE2] = self._route_manager.get_ego_path(self._ego_state)
164 | self._ego_path = create_path_from_se2(ego_path)
165 | self._ego_path_linestring = path_to_linestring(ego_path)
166 |
167 | start_progress = self._ego_path.get_start_progress()
168 | end_progress = self._ego_path.get_end_progress()
169 |
170 | with warnings.catch_warnings():
171 | # https://github.com/shapely/shapely/issues/1796
172 | warnings.simplefilter("ignore")
173 | self._ego_progress = self._ego_path_linestring.project(
174 | Point([self._ego_state.center.x, self._ego_state.center.y])
175 | )
176 |
177 | trimmed_path = trim_path(
178 | self._ego_path,
179 | max(start_progress, min(self._ego_progress, end_progress)),
180 | min(self._ego_progress + length, end_progress),
181 | )
182 | self._ego_trimmed_path = trimmed_path
183 |
184 | np_path = np.array([p.point.array for p in trimmed_path])
185 | return np_path
186 |
187 | def get_leading_objects(self):
188 | expanded_ego_path = path_to_linestring(self._ego_trimmed_path).buffer(
189 | self._ego_state.car_footprint.width / 2, cap_style=CAP_STYLE.square
190 | )
191 | expanded_ego_path = unary_union(
192 | [expanded_ego_path, self._ego_state.car_footprint.geometry]
193 | )
194 | intersecting_objects = self._obstacle_map.intersects(expanded_ego_path)
195 |
196 | if len(intersecting_objects) == 0:
197 | return []
198 |
199 | leading_objects = []
200 | ego_polygon = self._ego_state.car_footprint.geometry
201 |
202 | for obj_token in intersecting_objects:
203 | leading_objects.append(
204 | (
205 | obj_token,
206 | self._obstacle_map.get_type(obj_token),
207 | self._obstacle_map[obj_token].distance(ego_polygon),
208 | )
209 | )
210 |
211 | self.leading_objects = sorted(leading_objects, key=lambda x: x[2])
212 |
213 | return self.leading_objects
214 |
215 | def get_occupancy_object(self, token: str):
216 | return self._obstacle_map[token]
217 |
218 | def get_ego_path_points(self, start_progress, end_progress):
219 | start_progress += self._ego_progress
220 | end_progress += self._ego_progress
221 | return np.array(
222 | [
223 | np.array([p.x, p.y, p.heading], dtype=np.float64)
224 | for p in self._ego_trimmed_path
225 | if p.progress >= start_progress and p.progress <= end_progress
226 | ]
227 | )
228 |
229 | def get_reference_lines(self, length=100):
230 | return self._route_manager.get_reference_lines(self._ego_state, length=length)
231 |
232 | def get_cached_reference_lines(self):
233 | if self._route_manager.reference_lines:
234 | return self._route_manager.reference_lines
235 | else:
236 | raise ValueError("Reference lines not cached")
237 |
238 | def object_in_drivable_area(self, polygon: Polygon):
239 | return len(self._drivable_area_map.intersects(polygon)) > 0
240 |
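A sketch of the intended per-step call order (editor's example; map_api,
ego_state, detections, and traffic_lights are placeholders for the objects a
nuPlan planner receives at runtime):

    manager = ScenarioManager(map_api, ego_state, route_roadblock_ids, radius=50)
    manager.get_route_roadblock_ids(process=True)   # load and correct the route once

    # each simulation step:
    manager.update_ego_state(ego_state)
    manager.update_drivable_area_map()
    manager.update_obstacle_map(detections, traffic_lights)
    manager.update_ego_path(length=50)
    leading = manager.get_leading_objects()         # [(token, type, distance), ...] sorted by distance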
--------------------------------------------------------------------------------
/src/scenario_manager/utils/bfs_roadblock.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import Dict, Optional, Tuple, Union, List
3 |
4 | from nuplan.common.maps.abstract_map import AbstractMap
5 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject
6 |
7 |
8 | class BreadthFirstSearchRoadBlock:
9 | """
10 | A class that performs iterative breadth first search. The class operates on the roadblock graph.
11 | """
12 |
13 | def __init__(
14 |         self, start_roadblock_id: str, map_api: Optional[AbstractMap], forward_search: bool = True
15 | ):
16 | """
17 | Constructor of BreadthFirstSearchRoadBlock class
18 | :param start_roadblock_id: roadblock id where graph starts
19 | :param map_api: map class in nuPlan
20 | :param forward_search: whether to search in driving direction, defaults to True
21 | """
22 | self._map_api: Optional[AbstractMap] = map_api
23 | self._queue = deque([self.id_to_roadblock(start_roadblock_id), None])
24 | self._parent: Dict[str, Optional[RoadBlockGraphEdgeMapObject]] = dict()
25 | self._forward_search = forward_search
26 |
27 | # lazy loaded
28 | self._target_roadblock_ids: List[str] = None
29 |
30 | def search(
31 | self, target_roadblock_id: Union[str, List[str]], max_depth: int
32 |     ) -> Tuple[Tuple[List[RoadBlockGraphEdgeMapObject], List[str]], bool]:
33 | """
34 | Apply BFS to find route to target roadblock.
35 |         :param target_roadblock_id: id or list of ids of the target roadblock(s)
36 |         :param max_depth: maximum search depth
37 |         :return: tuple of (route, route ids) and whether a path was found
38 | """
39 |
40 | if isinstance(target_roadblock_id, str):
41 | target_roadblock_id = [target_roadblock_id]
42 | self._target_roadblock_ids = target_roadblock_id
43 |
44 | start_edge = self._queue[0]
45 |
46 | # Initial search states
47 | path_found: bool = False
48 | end_edge: RoadBlockGraphEdgeMapObject = start_edge
49 | end_depth: int = 1
50 | depth: int = 1
51 |
52 | self._parent[start_edge.id + f"_{depth}"] = None
53 |
54 | while self._queue:
55 | current_edge = self._queue.popleft()
56 |
57 | # Early exit condition
58 | if self._check_end_condition(depth, max_depth):
59 | break
60 |
61 | # Depth tracking
62 | if current_edge is None:
63 | depth += 1
64 | self._queue.append(None)
65 | if self._queue[0] is None:
66 | break
67 | continue
68 |
69 | # Goal condition
70 | if self._check_goal_condition(current_edge, depth, max_depth):
71 | end_edge = current_edge
72 | end_depth = depth
73 | path_found = True
74 | break
75 |
76 | neighbors = (
77 | current_edge.outgoing_edges if self._forward_search else current_edge.incoming_edges
78 | )
79 |
80 | # Populate queue
81 | for next_edge in neighbors:
82 |
83 | self._queue.append(next_edge)
84 | self._parent[next_edge.id + f"_{depth + 1}"] = current_edge
85 | end_edge = next_edge
86 | end_depth = depth + 1
87 |
88 | return self._construct_path(end_edge, end_depth), path_found
89 |
90 | def id_to_roadblock(self, id: str) -> RoadBlockGraphEdgeMapObject:
91 | """
92 | Retrieves roadblock from map-api based on id
93 | :param id: id of roadblock
94 | :return: roadblock class
95 | """
96 | block = self._map_api._get_roadblock(id)
97 | block = block or self._map_api._get_roadblock_connector(id)
98 | return block
99 |
100 | @staticmethod
101 | def _check_end_condition(depth: int, max_depth: int) -> bool:
102 | """
103 |         Check if the search should end regardless of whether the goal condition is met.
104 |         :param depth: The current depth to check.
105 |         :param max_depth: The maximum depth to check against.
106 |         :return: whether the current depth exceeds the maximum depth.
107 | """
108 | return depth > max_depth
109 |
110 | def _check_goal_condition(
111 | self,
112 | current_edge: RoadBlockGraphEdgeMapObject,
113 | depth: int,
114 | max_depth: int,
115 | ) -> bool:
116 | """
117 | Check if the current edge is at the target roadblock at the given depth.
118 | :param current_edge: edge to check.
119 | :param depth: current depth to check.
120 | :param max_depth: maximum depth the edge should be at.
121 |         :return: True if the edge is contained in the target roadblock, False otherwise.
122 | """
123 | return current_edge.id in self._target_roadblock_ids and depth <= max_depth
124 |
125 | def _construct_path(
126 | self, end_edge: RoadBlockGraphEdgeMapObject, depth: int
127 |     ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]:
128 | """
129 |         Constructs the path once the goal has been found.
130 |         :param end_edge: The end edge from which to backtrack to the start edge.
131 |         :param depth: The depth of the end edge.
132 |         :return: tuple of the path as a list of RoadBlockGraphEdgeMapObject and its ids
133 | """
134 | path = [end_edge]
135 | path_id = [end_edge.id]
136 |
137 | while self._parent[end_edge.id + f"_{depth}"] is not None:
138 | path.append(self._parent[end_edge.id + f"_{depth}"])
139 | path_id.append(path[-1].id)
140 | end_edge = self._parent[end_edge.id + f"_{depth}"]
141 | depth -= 1
142 |
143 | if self._forward_search:
144 | path.reverse()
145 | path_id.reverse()
146 |
147 | return (path, path_id)
148 |
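A usage sketch (editor's example; map_api and the roadblock ids are
placeholders):

    search = BreadthFirstSearchRoadBlock("123", map_api, forward_search=True)
    (path, path_id), path_found = search.search(["456", "789"], max_depth=30)
    if path_found:
        print(path_id)  # roadblock ids from start towards the first reached target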
--------------------------------------------------------------------------------
/src/scenario_manager/utils/dijkstra.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 | import numpy as np
3 | from nuplan.common.maps.abstract_map_objects import (
4 | LaneGraphEdgeMapObject,
5 | RoadBlockGraphEdgeMapObject,
6 | )
7 |
8 |
9 | class Dijkstra:
10 | """
11 |     A class that performs Dijkstra's shortest-path search on the lane-level graph.
12 |     The goal condition is met when a lane inside the target roadblock or roadblock connector is reached.
13 | """
14 |
15 | def __init__(self, start_edge: LaneGraphEdgeMapObject, candidate_lane_edge_ids: List[str]):
16 | """
17 | Constructor for the Dijkstra class.
18 | :param start_edge: The starting edge for the search
19 |         :param candidate_lane_edge_ids: The candidate lane ids that may be included in the search.
20 | """
21 | self._queue = list([start_edge])
22 | self._parent: Dict[str, Optional[LaneGraphEdgeMapObject]] = dict()
23 | self._candidate_lane_edge_ids = candidate_lane_edge_ids
24 |
25 | def search(
26 | self, target_roadblock: RoadBlockGraphEdgeMapObject
27 | ) -> Tuple[List[LaneGraphEdgeMapObject], bool]:
28 | """
29 |         Performs Dijkstra's shortest-path search to find a route to the target roadblock.
30 | :param target_roadblock: The target roadblock the path should end at.
31 | :return:
32 | - A route starting from the given start edge
33 | - A bool indicating if the route is successfully found. Successful means that there exists a path
34 | from the start edge to an edge contained in the end roadblock.
35 |               If unsuccessful, the lowest-cost path among the deepest expanded edges is returned.
36 | """
37 | start_edge = self._queue[0]
38 |
39 | # Initial search states
40 | path_found: bool = False
41 | end_edge: LaneGraphEdgeMapObject = start_edge
42 |
43 | self._parent[start_edge.id] = None
44 | self._frontier = [start_edge.id]
45 | self._dist = [1]
46 | self._depth = [1]
47 |
48 | self._expanded = []
49 | self._expanded_id = []
50 | self._expanded_dist = []
51 | self._expanded_depth = []
52 |
53 | while len(self._queue) > 0:
54 | dist, idx = min((val, idx) for (idx, val) in enumerate(self._dist))
55 | current_edge = self._queue[idx]
56 | current_depth = self._depth[idx]
57 |
58 | del self._dist[idx], self._queue[idx], self._frontier[idx], self._depth[idx]
59 |
60 | if self._check_goal_condition(current_edge, target_roadblock):
61 | end_edge = current_edge
62 | path_found = True
63 | break
64 |
65 | self._expanded.append(current_edge)
66 | self._expanded_id.append(current_edge.id)
67 | self._expanded_dist.append(dist)
68 | self._expanded_depth.append(current_depth)
69 |
70 | # Populate queue
71 | for next_edge in current_edge.outgoing_edges:
72 |                 if next_edge.id not in self._candidate_lane_edge_ids:
73 | continue
74 |
75 | alt = dist + self._edge_cost(next_edge)
76 | if next_edge.id not in self._expanded_id and next_edge.id not in self._frontier:
77 | self._parent[next_edge.id] = current_edge
78 | self._queue.append(next_edge)
79 | self._frontier.append(next_edge.id)
80 | self._dist.append(alt)
81 | self._depth.append(current_depth + 1)
82 | end_edge = next_edge
83 |
84 | elif next_edge.id in self._frontier:
85 | next_edge_idx = self._frontier.index(next_edge.id)
86 | current_cost = self._dist[next_edge_idx]
87 | if alt < current_cost:
88 | self._parent[next_edge.id] = current_edge
89 | self._dist[next_edge_idx] = alt
90 | self._depth[next_edge_idx] = current_depth + 1
91 |
92 | if not path_found:
93 | # filter max depth
94 | max_depth = max(self._expanded_depth)
95 | idx_max_depth = list(np.where(np.array(self._expanded_depth) == max_depth)[0])
96 | dist_at_max_depth = [self._expanded_dist[i] for i in idx_max_depth]
97 |
98 | dist, _idx = min((val, idx) for (idx, val) in enumerate(dist_at_max_depth))
99 | end_edge = self._expanded[idx_max_depth[_idx]]
100 |
101 | return self._construct_path(end_edge), path_found
102 |
103 | @staticmethod
104 | def _edge_cost(lane: LaneGraphEdgeMapObject) -> float:
105 | """
106 | Edge cost of given lane.
107 | :param lane: lane class
108 | :return: length of lane
109 | """
110 | return lane.baseline_path.length
111 |
112 | @staticmethod
113 | def _check_end_condition(depth: int, target_depth: int) -> bool:
114 | """
115 |         Check if the search should end regardless of whether the goal condition is met.
116 | :param depth: The current depth to check.
117 | :param target_depth: The target depth to check against.
118 | :return: True if:
119 | - The current depth exceeds the target depth.
120 | """
121 | return depth > target_depth
122 |
123 | @staticmethod
124 | def _check_goal_condition(
125 | current_edge: LaneGraphEdgeMapObject,
126 | target_roadblock: RoadBlockGraphEdgeMapObject,
127 | ) -> bool:
128 | """
129 | Check if the current edge is at the target roadblock at the given depth.
130 | :param current_edge: The edge to check.
131 | :param target_roadblock: The target roadblock the edge should be contained in.
132 | :return: whether the current edge is in the target roadblock
133 | """
134 | return current_edge.get_roadblock_id() == target_roadblock.id
135 |
136 | def _construct_path(self, end_edge: LaneGraphEdgeMapObject) -> List[LaneGraphEdgeMapObject]:
137 | """
138 |         Constructs the path by backtracking from the end edge to the start edge.
139 |         :param end_edge: The end edge from which to backtrack.
140 |         :return: The constructed path as a list of LaneGraphEdgeMapObject
141 | """
142 | path = [end_edge]
143 | while self._parent[end_edge.id] is not None:
144 | node = self._parent[end_edge.id]
145 | path.append(node)
146 | end_edge = node
147 | path.reverse()
148 |
149 | return path
150 |
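A usage sketch (editor's example; start_lane, candidate_lane_edge_ids, and
target_roadblock are placeholders for nuPlan map objects):

    dijkstra = Dijkstra(start_lane, candidate_lane_edge_ids)
    lane_path, path_found = dijkstra.search(target_roadblock)
    # lane_path is the lowest-cost lane sequence; if path_found is False, it is
    # instead the cheapest path among the deepest expanded lanes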
--------------------------------------------------------------------------------
/src/scenario_manager/utils/route_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import numpy as np
4 | from nuplan.common.actor_state.ego_state import EgoState
5 | from nuplan.common.actor_state.state_representation import StateSE2
6 | from nuplan.common.maps.abstract_map import AbstractMap
7 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject
8 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer
9 | from nuplan.planning.simulation.occupancy_map.strtree_occupancy_map import (
10 | STRTreeOccupancyMapFactory,
11 | )
12 |
13 | from .bfs_roadblock import BreadthFirstSearchRoadBlock
14 |
15 |
16 | def normalize_angle(angle: float) -> float:
17 | return (angle + np.pi) % (2 * np.pi) - np.pi
18 |
19 |
20 | def get_current_roadblock_candidates(
21 | ego_state: EgoState,
22 | map_api: AbstractMap,
23 | route_roadblocks_dict: Dict[str, RoadBlockGraphEdgeMapObject],
24 | heading_error_thresh: float = np.pi / 4,
25 | displacement_error_thresh: float = 3,
26 | ) -> Tuple[RoadBlockGraphEdgeMapObject, List[RoadBlockGraphEdgeMapObject]]:
27 | """
28 |     Determines a set of roadblock candidates where the ego is located
29 | :param ego_state: class containing ego state
30 | :param map_api: map object
31 | :param route_roadblocks_dict: dictionary of on-route roadblocks
32 | :param heading_error_thresh: maximum heading error, defaults to np.pi/4
33 | :param displacement_error_thresh: maximum displacement, defaults to 3
34 | :return: tuple of most promising roadblock and other candidates
35 | """
36 | ego_pose: StateSE2 = ego_state.rear_axle
37 | roadblock_candidates = []
38 |
39 | layers = [SemanticMapLayer.ROADBLOCK, SemanticMapLayer.ROADBLOCK_CONNECTOR]
40 | roadblock_dict = map_api.get_proximal_map_objects(
41 | point=ego_pose.point, radius=2.5, layers=layers
42 | )
43 | roadblock_candidates = (
44 | roadblock_dict[SemanticMapLayer.ROADBLOCK]
45 | + roadblock_dict[SemanticMapLayer.ROADBLOCK_CONNECTOR]
46 | )
47 |
48 | if not roadblock_candidates:
49 | for layer in layers:
50 | roadblock_id_, distance = map_api.get_distance_to_nearest_map_object(
51 | point=ego_pose.point, layer=layer
52 | )
53 | roadblock = map_api.get_map_object(roadblock_id_, layer)
54 |
55 | if roadblock:
56 | roadblock_candidates.append(roadblock)
57 |
58 | on_route_candidates, on_route_candidate_displacement_errors = [], []
59 | candidates, candidate_displacement_errors = [], []
60 |
61 | roadblock_displacement_errors = []
62 | roadblock_heading_errors = []
63 |
64 | for idx, roadblock in enumerate(roadblock_candidates):
65 | lane_displacement_error, lane_heading_error = np.inf, np.inf
66 |
67 | for lane in roadblock.interior_edges:
68 | lane_discrete_path: List[StateSE2] = lane.baseline_path.discrete_path
69 | lane_discrete_points = np.array(
70 | [state.point.array for state in lane_discrete_path], dtype=np.float64
71 | )
72 | lane_state_distances = (
73 | (lane_discrete_points - ego_pose.point.array[None, ...]) ** 2.0
74 | ).sum(axis=-1) ** 0.5
75 | argmin = np.argmin(lane_state_distances)
76 |
77 | heading_error = np.abs(
78 | normalize_angle(lane_discrete_path[argmin].heading - ego_pose.heading)
79 | )
80 | displacement_error = lane_state_distances[argmin]
81 |
82 | if displacement_error < lane_displacement_error:
83 | lane_heading_error, lane_displacement_error = (
84 | heading_error,
85 | displacement_error,
86 | )
87 |
88 | if (
89 | heading_error < heading_error_thresh
90 | and displacement_error < displacement_error_thresh
91 | ):
92 | if roadblock.id in route_roadblocks_dict.keys():
93 | on_route_candidates.append(roadblock)
94 | on_route_candidate_displacement_errors.append(displacement_error)
95 | else:
96 | candidates.append(roadblock)
97 | candidate_displacement_errors.append(displacement_error)
98 |
99 | roadblock_displacement_errors.append(lane_displacement_error)
100 | roadblock_heading_errors.append(lane_heading_error)
101 |
102 | if on_route_candidates: # prefer on-route roadblocks
103 | return (
104 | on_route_candidates[np.argmin(on_route_candidate_displacement_errors)],
105 | on_route_candidates,
106 | )
107 | elif candidates: # fallback to most promising candidate
108 | return candidates[np.argmin(candidate_displacement_errors)], candidates
109 |
110 | # otherwise, just find any close roadblock
111 | return (
112 | roadblock_candidates[np.argmin(roadblock_displacement_errors)],
113 | roadblock_candidates,
114 | )
115 |
116 |
117 | def route_roadblock_correction(
118 | ego_state: EgoState,
119 | map_api: AbstractMap,
120 | route_roadblock_ids: List[str],
121 | search_depth_backward: int = 15,
122 | search_depth_forward: int = 30,
123 | ) -> List[str]:
124 | """
125 | Applies several methods to correct route roadblocks.
126 | :param ego_state: class containing ego state
127 | :param map_api: map object
128 |     :param route_roadblock_ids: list of roadblock ids of the route
129 |     :param search_depth_backward: depth of backward BFS search, defaults to 15
130 |     :param search_depth_forward: depth of forward BFS search, defaults to 30
131 | :return: list of roadblock id's of corrected route
132 | """
133 |
134 | route_roadblock_dict = {}
135 | for id_ in route_roadblock_ids:
136 | block = map_api.get_map_object(id_, SemanticMapLayer.ROADBLOCK)
137 | block = block or map_api.get_map_object(
138 | id_, SemanticMapLayer.ROADBLOCK_CONNECTOR
139 | )
140 | route_roadblock_dict[id_] = block
141 |
142 | starting_block, starting_block_candidates = get_current_roadblock_candidates(
143 | ego_state, map_api, route_roadblock_dict
144 | )
145 | starting_block_ids = [roadblock.id for roadblock in starting_block_candidates]
146 |
147 | route_roadblocks = list(route_roadblock_dict.values())
148 | route_roadblock_ids = list(route_roadblock_dict.keys())
149 |
150 | # Fix 1: when agent starts off-route
151 | if starting_block.id not in route_roadblock_ids:
152 | # Backward search if current roadblock not in route
153 | graph_search = BreadthFirstSearchRoadBlock(
154 | route_roadblock_ids[0], map_api, forward_search=False
155 | )
156 | (path, path_id), path_found = graph_search.search(
157 | starting_block_ids, max_depth=search_depth_backward
158 | )
159 |
160 | if path_found:
161 | route_roadblocks[:0] = path[:-1]
162 | route_roadblock_ids[:0] = path_id[:-1]
163 |
164 | else:
165 | # Forward search to any route roadblock
166 | graph_search = BreadthFirstSearchRoadBlock(
167 | starting_block.id, map_api, forward_search=True
168 | )
169 | (path, path_id), path_found = graph_search.search(
170 | route_roadblock_ids[:3], max_depth=search_depth_forward
171 | )
172 |
173 | if path_found:
174 | end_roadblock_idx = np.argmax(
175 | np.array(route_roadblock_ids) == path_id[-1]
176 | )
177 |
178 | route_roadblocks = route_roadblocks[end_roadblock_idx + 1 :]
179 | route_roadblock_ids = route_roadblock_ids[end_roadblock_idx + 1 :]
180 |
181 | route_roadblocks[:0] = path
182 | route_roadblock_ids[:0] = path_id
183 |
184 | # Fix 2: check if roadblocks are linked, search for links if not
185 | roadblocks_to_append = {}
186 | for i in range(len(route_roadblocks) - 1):
187 | next_incoming_block_ids = [
188 | _roadblock.id for _roadblock in route_roadblocks[i + 1].incoming_edges
189 | ]
190 | is_incoming = route_roadblock_ids[i] in next_incoming_block_ids
191 |
192 | if is_incoming:
193 | continue
194 |
195 | graph_search = BreadthFirstSearchRoadBlock(
196 | route_roadblock_ids[i], map_api, forward_search=True
197 | )
198 | (path, path_id), path_found = graph_search.search(
199 | route_roadblock_ids[i + 1], max_depth=search_depth_forward
200 | )
201 |
202 | if path_found and path and len(path) >= 3:
203 | path, path_id = path[1:-1], path_id[1:-1]
204 | roadblocks_to_append[i] = (path, path_id)
205 |
206 | # append missing intermediate roadblocks
207 | offset = 1
208 | for i, (path, path_id) in roadblocks_to_append.items():
209 | route_roadblocks[i + offset : i + offset] = path
210 | route_roadblock_ids[i + offset : i + offset] = path_id
211 | offset += len(path)
212 |
213 | # Fix 3: cut route-loops
214 | route_roadblocks, route_roadblock_ids = remove_route_loops(
215 | route_roadblocks, route_roadblock_ids
216 | )
217 |
218 | return route_roadblock_ids
219 |
220 |
221 | def remove_route_loops(
222 | route_roadblocks: List[RoadBlockGraphEdgeMapObject],
223 | route_roadblock_ids: List[str],
224 | ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]:
225 |     """
226 |     Removes the tail of the route if a later roadblock intersects an earlier one (forming a loop).
227 |     :param route_roadblocks: input route roadblocks
228 |     :param route_roadblock_ids: input route roadblock ids
229 |     :return: tuple of roadblocks and ids of the route without loops
230 | """
231 |
232 | roadblock_occupancy_map = None
233 | loop_idx = None
234 |
235 | for idx, roadblock in enumerate(route_roadblocks):
236 |         # loops only occur at intersections, thus only roadblock connectors are checked.
237 | if str(roadblock.__class__.__name__) == "NuPlanRoadBlockConnector":
238 | if not roadblock_occupancy_map:
239 | roadblock_occupancy_map = STRTreeOccupancyMapFactory.get_from_geometry(
240 | [roadblock.polygon], [roadblock.id]
241 | )
242 | continue
243 |
244 | strtree, index_by_id = roadblock_occupancy_map._build_strtree()
245 | indices = strtree.query(roadblock.polygon)
246 | if len(indices) > 0:
247 | for geom in strtree.geometries.take(indices):
248 | area = geom.intersection(roadblock.polygon).area
249 | if area > 1:
250 | loop_idx = idx
251 | break
252 |                 if loop_idx is not None:
253 | break
254 |
255 | roadblock_occupancy_map.insert(roadblock.id, roadblock.polygon)
256 |
257 |     if loop_idx is not None:
258 | route_roadblocks = route_roadblocks[:loop_idx]
259 | route_roadblock_ids = route_roadblock_ids[:loop_idx]
260 |
261 | return route_roadblocks, route_roadblock_ids
262 |
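A usage sketch (editor's example; scenario stands for a nuPlan
AbstractScenario):

    corrected_ids = route_roadblock_correction(
        ego_state=scenario.initial_ego_state,
        map_api=scenario.map_api,
        route_roadblock_ids=scenario.get_route_roadblock_ids(),
    )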
--------------------------------------------------------------------------------
/src/utils/collision_checker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nuplan.common.actor_state.vehicle_parameters import (
3 | VehicleParameters,
4 | get_pacifica_parameters,
5 | )
6 |
7 |
8 | class CollisionChecker:
9 | def __init__(
10 | self,
11 | vehicle: VehicleParameters = get_pacifica_parameters(),
12 | ) -> None:
13 | self._vehicle = vehicle
14 | self._sdc_half_length = vehicle.length / 2
15 | self._sdc_half_width = vehicle.width / 2
16 |
17 | self._sdc_normalized_corners = torch.stack(
18 | [
19 | torch.tensor([vehicle.length / 2, vehicle.width / 2]),
20 | torch.tensor([vehicle.length / 2, -vehicle.width / 2]),
21 | torch.tensor([-vehicle.length / 2, -vehicle.width / 2]),
22 | torch.tensor([-vehicle.length / 2, vehicle.width / 2]),
23 | ],
24 | dim=0,
25 | )
26 |
27 | def to_device(self, device):
28 | self._sdc_normalized_corners = self._sdc_normalized_corners.to(device)
29 |
30 | def build_bbox_from_center(self, center, heading, width, length):
31 | """
32 | params:
33 | center: [bs, N, (x, y)]
34 | heading: [bs, N]
35 | width: [bs, N]
36 | length: [bs, N]
37 | return:
38 |             corners: [bs, N, 4, (x, y)]
39 |             heading_vec, tangent_vec: [bs, N, (x, y)]
40 |         """
41 |         cos = torch.cos(heading)
42 |         sin = torch.sin(heading)
43 |
44 |         heading_vec = torch.stack([cos, sin], dim=-1) * length.unsqueeze(-1) / 2
45 |         tangent_vec = torch.stack([-sin, cos], dim=-1) * width.unsqueeze(-1) / 2
46 |
47 |         corners = torch.stack(
48 |             [
49 |                 center + heading_vec + tangent_vec,
50 |                 center - heading_vec + tangent_vec,
51 |                 center - heading_vec - tangent_vec,
52 |                 center + heading_vec - tangent_vec,
53 |             ],
54 |             dim=-2,
55 |         )
56 |
57 |         return corners, heading_vec, tangent_vec
58 |
59 | def collision_check(self, ego_state, objects, objects_width, objects_length):
60 |         """Performs a batch-wise collision check using the Separating Axis Theorem.
61 |         params:
62 |             ego_state: [bs, (x, y, theta)], center pose of the ego
63 |             objects: [bs, N, (x, y, theta)], center poses of the objects; objects_width, objects_length: [bs, N]
64 |         returns:
65 |             is_collided: [bs, N]
66 |         """
67 |
68 | bs, N = objects.shape[:2]
69 |
70 | # rotate object to ego's local frame
71 | cos, sin = torch.cos(ego_state[:, 2]), torch.sin(ego_state[:, 2])
72 | rotate_mat = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(bs, 2, 2)
73 |
74 | rotated_objects = objects.clone()
75 | rotated_objects[..., :2] = torch.matmul(
76 | rotated_objects[..., :2] - ego_state[:, :2].unsqueeze(1), rotate_mat
77 | )
78 | rotated_objects[..., 2] -= ego_state[..., 2].unsqueeze(1)
79 |
80 | # [bs, N, 4, 2], [bs, N, 2], [bs, N, 2]
81 | object_corners, axis1, axis2 = self.build_bbox_from_center(
82 | rotated_objects[..., :2],
83 | rotated_objects[..., 2],
84 | objects_width,
85 | objects_length,
86 | )
87 |
88 | ego_corners = self._sdc_normalized_corners.reshape(1, 1, 4, 2).repeat(
89 | bs, N, 1, 1
90 | ) # [bs, N, 4, 2]
91 |
92 | all_corners = torch.concat(
93 | [object_corners, ego_corners], dim=-2
94 | ) # [bs, N, 8, 2]
95 |
96 | x_projection = object_corners[..., 0]
97 | y_projection = object_corners[..., 1]
98 | axis1_projection = torch.matmul(all_corners, axis1.unsqueeze(-1)).squeeze(-1)
99 | axis2_projection = torch.matmul(all_corners, axis2.unsqueeze(-1)).squeeze(-1)
100 |
101 | x_separated = (x_projection.max(-1)[0] < -self._sdc_half_length) | (
102 | x_projection.min(-1)[0] > self._sdc_half_length
103 | )
104 | y_separated = (y_projection.max(-1)[0] < -self._sdc_half_width) | (
105 | y_projection.min(-1)[0] > self._sdc_half_width
106 | )
107 | axis1_separated = (
108 | axis1_projection[..., :4].max(-1)[0] < axis1_projection[..., 4:].min(-1)[0]
109 | ) | (
110 | axis1_projection[..., :4].min(-1)[0] > axis1_projection[..., 4:].max(-1)[0]
111 | )
112 | axis2_separated = (
113 | axis2_projection[..., :4].max(-1)[0] < axis2_projection[..., 4:].min(-1)[0]
114 | ) | (
115 | axis2_projection[..., :4].min(-1)[0] > axis2_projection[..., 4:].max(-1)[0]
116 | )
117 |
118 | collision = ~(x_separated | y_separated | axis1_separated | axis2_separated)
119 |
120 | return collision
121 |
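A worked example (editor's sketch; the poses and sizes are made up, and the
default Pacifica footprint is roughly 5.2 m x 2.3 m):

    import torch

    checker = CollisionChecker()
    ego_state = torch.tensor([[0.0, 0.0, 0.0]])   # [bs=1, (x, y, theta)]
    objects = torch.tensor([[[1.0, 0.0, 0.0],     # overlaps the ego box
                             [20.0, 0.0, 0.0]]])  # clearly separated
    width = torch.full((1, 2), 2.0)
    length = torch.full((1, 2), 4.0)
    checker.collision_check(ego_state, objects, width, length)
    # -> tensor([[ True, False]])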
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import pprint
4 | from pathlib import Path
5 | import numpy
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | import cv2
10 |
11 |
12 | def to_tensor(data):
13 | if isinstance(data, dict):
14 | return {k: to_tensor(v) for k, v in data.items()}
15 | elif isinstance(data, numpy.ndarray):
16 | if data.dtype == numpy.float64:
17 | return torch.from_numpy(data).float()
18 | else:
19 | return torch.from_numpy(data)
20 | elif isinstance(data, numpy.number):
21 | return torch.tensor(data).float()
22 | elif isinstance(data, list):
23 | return data
24 | elif isinstance(data, int):
25 | return torch.tensor(data)
26 | elif isinstance(data, tuple):
27 | return to_tensor(data[0])
28 | else:
29 | print(type(data), data)
30 | raise NotImplementedError
31 |
32 |
33 | def to_numpy(data):
34 | if isinstance(data, dict):
35 | return {k: to_numpy(v) for k, v in data.items()}
36 | elif isinstance(data, torch.Tensor):
37 | if data.requires_grad:
38 | return data.detach().cpu().numpy()
39 | else:
40 | return data.cpu().numpy()
41 | else:
42 | print(type(data), data)
43 | raise NotImplementedError
44 |
45 |
46 | def enable_grad(data):
47 | if isinstance(data, dict):
48 | return {k: enable_grad(v) for k, v in data.items()}
49 |     elif isinstance(data, torch.Tensor):
50 |         if data.dtype == torch.float32:
51 |             data.requires_grad = True
52 |         return data
53 |     raise NotImplementedError
54 |
55 |
56 | def to_device(data, device):
57 | if isinstance(data, dict):
58 | return {k: to_device(v, device) for k, v in data.items()}
59 | elif isinstance(data, torch.Tensor):
60 | return data.to(device)
61 | else:
62 | raise NotImplementedError
63 |
64 |
65 | def print_dict_tensor(data, prefix=""):
66 | for k, v in data.items():
67 | if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray):
68 | print(f"{prefix}{k}: {v.shape}")
69 | elif isinstance(v, dict):
70 | print(f"{prefix}{k}:")
71 |             print_dict_tensor(v, prefix + "    ")
72 |
73 |
74 | def print_simulation_results(file=None):
75 | if file is not None:
76 | df = pd.read_parquet(file)
77 | else:
78 | root = Path(os.getcwd()) / "aggregator_metric"
79 | result = list(root.glob("*.parquet"))
80 | result = max(result, key=lambda item: item.stat().st_ctime)
81 | df = pd.read_parquet(result)
82 | final_score = df[df["scenario"] == "final_score"]
83 | final_score = final_score.to_dict(orient="records")[0]
84 | pprint.PrettyPrinter(indent=4).pprint(final_score)
85 |
86 |
87 | def load_checkpoint(checkpoint):
88 | ckpt = torch.load(checkpoint, map_location=torch.device("cpu"))
89 | state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
90 | return state_dict
91 |
92 |
93 | def safe_index(ls, value):
94 | try:
95 | return ls.index(value)
96 | except ValueError:
97 | return None
98 |
99 |
100 | def shift_and_rotate_img(img, shift, angle, resolution, cval=-200):
101 | """
102 | img: (H, W, C)
103 |     shift: (H_shift, W_shift, 0), in world units (divided by resolution)
104 |     angle: float, rotation angle in radians
105 |     resolution: float, size of one pixel in world units
106 | """
107 | rows, cols = img.shape[:2]
108 | shift = shift / resolution
109 | translation_matrix = np.float32([[1, 0, shift[1]], [0, 1, shift[0]]])
110 | translated_img = cv2.warpAffine(
111 | img, translation_matrix, (cols, rows), borderValue=cval
112 | )
113 | M = cv2.getRotationMatrix2D((cols / 2, rows / 2), math.degrees(angle), 1)
114 | rotated_img = cv2.warpAffine(translated_img, M, (cols, rows), borderValue=cval)
115 | if len(img.shape) == 3 and len(rotated_img.shape) == 2:
116 | rotated_img = rotated_img[..., np.newaxis]
117 | return rotated_img.astype(np.float32)
118 |
119 |
120 | def crop_img_from_center(img, crop_size):
121 | h, w = img.shape[:2]
122 | h_crop, w_crop = crop_size
123 | h_start = (h - h_crop) // 2
124 | w_start = (w - w_crop) // 2
125 | return img[h_start : h_start + h_crop, w_start : w_start + w_crop].astype(
126 | np.float32
127 | )
128 |
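A round-trip sketch for the helpers above (editor's example; the feature dict
is made up):

    features = {"agent": {"position": np.zeros((32, 2), dtype=np.float64)}}
    tensors = to_tensor(features)         # float64 arrays become float32 tensors
    tensors = to_device(tensors, "cpu")
    arrays = to_numpy(tensors)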
--------------------------------------------------------------------------------