├── .gitignore
├── README.md
├── config
│   ├── custom_trainer
│   │   └── pluto_trainer.yaml
│   ├── data_augmentation
│   │   └── contrastive_scenario_generator.yaml
│   ├── default_simulation.yaml
│   ├── default_training.yaml
│   ├── lightning
│   │   └── custom_lightning.yaml
│   ├── model
│   │   └── pluto_model.yaml
│   ├── planner
│   │   └── pluto_planner.yaml
│   ├── scenario_filter
│   │   ├── mini_demo_scenario.yaml
│   │   ├── random14_benchmark.yaml
│   │   ├── training_scenarios_1M.yaml
│   │   ├── training_scenarios_tiny.yaml
│   │   ├── val14_benchmark.yaml
│   │   └── val_demo_scenario.yaml
│   └── training
│       └── train_pluto.yaml
├── requirements.txt
├── run_simulation.py
├── run_training.py
├── script
│   ├── run_pluto_planner.sh
│   └── setup_env.sh
└── src
    ├── custom_training
    │   ├── custom_datamodule.py
    │   └── custom_training_builder.py
    ├── data_augmentation
    │   └── contrastive_scenario_generator.py
    ├── feature_builders
    │   ├── common.py
    │   ├── nuplan_scenario_render.py
    │   └── pluto_feature_builder.py
    ├── features
    │   └── pluto_feature.py
    ├── metrics
    │   ├── __init__.py
    │   ├── min_ade.py
    │   ├── min_fde.py
    │   ├── mr.py
    │   ├── prediction_avg_ade.py
    │   ├── prediction_avg_fde.py
    │   └── utils.py
    ├── models
    │   └── pluto
    │       ├── layers
    │       │   ├── common_layers.py
    │       │   ├── embedding.py
    │       │   ├── fourier_embedding.py
    │       │   ├── mlp_layer.py
    │       │   └── transformer.py
    │       ├── loss
    │       │   └── esdf_collision_loss.py
    │       ├── modules
    │       │   ├── agent_encoder.py
    │       │   ├── agent_predictor.py
    │       │   ├── map_encoder.py
    │       │   ├── planning_decoder.py
    │       │   └── static_objects_encoder.py
    │       ├── pluto_model.py
    │       └── pluto_trainer.py
    ├── optim
    │   └── warmup_cos_lr.py
    ├── planners
    │   ├── ml_planner_utils.py
    │   └── pluto_planner.py
    ├── post_processing
    │   ├── common
    │   │   ├── enum.py
    │   │   └── geometry.py
    │   ├── emergency_brake.py
    │   ├── evaluation
    │   │   └── comfort_metrics.py
    │   ├── forward_simulation
    │   │   ├── batch_kinematic_bicycle.py
    │   │   ├── batch_lqr.py
    │   │   ├── batch_lqr_utils.py
    │   │   └── forward_simulator.py
    │   ├── observation
    │   │   └── world_from_prediction.py
    │   └── trajectory_evaluator.py
    ├── scenario_manager
    │   ├── cost_map_manager.py
    │   ├── occupancy_map.py
    │   ├── route_manager.py
    │   ├── scenario_manager.py
    │   └── utils
    │       ├── bfs_roadblock.py
    │       ├── dijkstra.py
    │       └── route_utils.py
    └── utils
        ├── collision_checker.py
        ├── utils.py
        └── vis.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | parts/ 18 | sdist/ 19 | var/ 20 | wheels/ 21 | share/python-wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | .DS_Store 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | *.py,cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | cover/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | .pybuilder/ 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | # For a library or package, you might want to ignore these files since the code is 86 | # intended to run in multiple environments; otherwise, check them in: 87 | # .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # poetry 97 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 98 | # This is especially recommended for binary packages to ensure reproducibility, and is more 99 | # commonly ignored for libraries. 100 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 101 | #poetry.lock 102 | 103 | # pdm 104 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 105 | #pdm.lock 106 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 107 | # in version control. 108 | # https://pdm.fming.dev/#use-with-ide 109 | .pdm.toml 110 | 111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .venv 123 | # env/ 124 | venv/ 125 | ENV/ 126 | env.bak/ 127 | venv.bak/ 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | 147 | # pytype static type analyzer 148 | .pytype/ 149 | 150 | # Cython debug symbols 151 | cython_debug/ 152 | 153 | # PyCharm 154 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 155 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 156 | # and can be added to the global gitignore or merged into this file. For a more nuclear 157 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
158 | #.idea/ 159 | 160 | .vscode 161 | wandb/ 162 | outputs/ 163 | *.ckpt 164 | checkpoints/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PLUTO 2 | 3 | This is the official repository of 4 | 5 | **PLUTO: Pushing the Limit of Imitation Learning-based Planning for Autonomous Driving**, 6 | 7 | [Jie Cheng](https://jchengai.github.io/), [Yingbing Chen](https://sites.google.com/view/chenyingbing-homepage), and [Qifeng Chen](https://cqf.io/) 8 | 9 | 10 |

11 | 12 | [arXiv PDF](https://arxiv.org/abs/2404.14327) 13 | 14 | 15 | 16 | 17 |

18 | 19 | ## Setup Environment 20 | 21 | ### Setup dataset 22 | 23 | Set up the nuPlan dataset following the [official documentation](https://nuplan-devkit.readthedocs.io/en/latest/dataset_setup.html). 24 | 25 | ### Setup conda environment 26 | 27 | ``` 28 | conda create -n pluto python=3.9 29 | conda activate pluto 30 | 31 | # install nuplan-devkit 32 | git clone https://github.com/motional/nuplan-devkit.git && cd nuplan-devkit 33 | pip install -e . 34 | pip install -r ./requirements.txt 35 | 36 | # setup pluto 37 | cd .. 38 | git clone https://github.com/jchengai/pluto.git && cd pluto 39 | sh ./script/setup_env.sh 40 | ``` 41 | 42 | ## Feature Cache 43 | 44 | Preprocess the dataset to accelerate training. It is recommended to run a small sanity check to make sure everything is set up correctly. 45 | 46 | ``` 47 | python run_training.py \ 48 | py_func=cache +training=train_pluto \ 49 | scenario_builder=nuplan_mini \ 50 | cache.cache_path=/nuplan/exp/sanity_check \ 51 | cache.cleanup_cache=true \ 52 | scenario_filter=training_scenarios_tiny \ 53 | worker=sequential 54 | ``` 55 | 56 | Then preprocess the whole nuPlan training set (this will take some time). You may need to change `cache.cache_path` to suit your setup: 57 | 58 | ``` 59 | export PYTHONPATH=$PYTHONPATH:$(pwd) 60 | 61 | python run_training.py \ 62 | py_func=cache +training=train_pluto \ 63 | scenario_builder=nuplan \ 64 | cache.cache_path=/nuplan/exp/cache_pluto_1M \ 65 | cache.cleanup_cache=true \ 66 | scenario_filter=training_scenarios_1M \ 67 | worker.threads_per_node=40 68 | ``` 69 | 70 | ## Training 71 | 72 | (The training part is not fully tested.) 73 | 74 | As before, it is recommended to run a sanity check first: 75 | 76 | ``` 77 | CUDA_VISIBLE_DEVICES=0 python run_training.py \ 78 | py_func=train +training=train_pluto \ 79 | worker=single_machine_thread_pool worker.max_workers=4 \ 80 | scenario_builder=nuplan cache.cache_path=/nuplan/exp/sanity_check cache.use_cache_without_dataset=true \ 81 | data_loader.params.batch_size=4 data_loader.params.num_workers=1 82 | ``` 83 | 84 | Training on the full dataset (without CIL): 85 | 86 | ``` 87 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_training.py \ 88 | py_func=train +training=train_pluto \ 89 | worker=single_machine_thread_pool worker.max_workers=32 \ 90 | scenario_builder=nuplan cache.cache_path=/nuplan/exp/cache_pluto_1M cache.use_cache_without_dataset=true \ 91 | data_loader.params.batch_size=32 data_loader.params.num_workers=16 \ 92 | lr=1e-3 epochs=25 warmup_epochs=3 weight_decay=0.0001 \ 93 | wandb.mode=online wandb.project=nuplan wandb.name=pluto 94 | ``` 95 | 96 | - Add the options `model.use_hidden_proj=true +custom_trainer.use_contrast_loss=true` to enable CIL. 97 | 98 | - You can remove the wandb-related configuration if you prefer TensorBoard. 99 | 100 | 101 | ## Checkpoint 102 | 103 | Download and place the checkpoint in the `pluto/checkpoints` folder.
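For example, assuming the checkpoint listed in the table below was downloaded to the current working directory (a sketch; adjust the source path to wherever the file actually lives), it can be placed with:

```
mkdir -p checkpoints
mv ./pluto_1M_aux_cil.ckpt ./checkpoints/
```

The file name matches the one passed to `./script/run_pluto_planner.sh` in the simulation section below.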
104 | 105 | | Model | Download | 106 | | ---------------- | -------- | 107 | | Pluto-1M-aux-cil | [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchengai_connect_ust_hk/EaFpLwwHFYVKsPVLH2nW5nEBNbPS7gqqu_Rv2V1dzODO-Q?e=LAZQcI) | 108 | 109 | 110 | ## Run Pluto-planner simulation 111 | 112 | Run a simulation for a random scenario in the nuPlan-mini split: 113 | 114 | ``` 115 | sh ./script/run_pluto_planner.sh pluto_planner nuplan_mini mini_demo_scenario pluto_1M_aux_cil.ckpt /dir_to_save_the_simulation_result_video 116 | ``` 117 | 118 | The rendered simulation video will be saved to the specified directory (replace `/dir_to_save_the_simulation_result_video` with your own path). 119 | 120 | ## To Do 121 | 122 | The code is being cleaned up and will be released gradually. 123 | 124 | - [ ] improve docs 125 | - [x] training code 126 | - [x] visualization 127 | - [x] pluto-planner & checkpoint 128 | - [x] feature builder & model 129 | - [x] initial repo & paper 130 | 131 | ## Citation 132 | 133 | If you find this repo useful, please consider giving us a star 🌟 and citing our related paper. 134 | 135 | ```bibtex 136 | @article{cheng2024pluto, 137 | title={PLUTO: Pushing the Limit of Imitation Learning-based Planning for Autonomous Driving}, 138 | author={Cheng, Jie and Chen, Yingbing and Chen, Qifeng}, 139 | journal={arXiv preprint arXiv:2404.14327}, 140 | year={2024} 141 | } 142 | ``` -------------------------------------------------------------------------------- /config/custom_trainer/pluto_trainer.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.pluto.pluto_trainer.LightningTrainer -------------------------------------------------------------------------------- /config/data_augmentation/contrastive_scenario_generator.yaml: -------------------------------------------------------------------------------- 1 | contrastive_augmentation: 2 | _target_: src.data_augmentation.contrastive_scenario_generator.ContrastiveScenarioGenerator 3 | _convert_: "all" 4 | 5 | history_steps: 21 6 | low: [-1.0, -0.75, -0.35, -1, -0.5, -0.2, -0.1] 7 | high: [1.0, 0.75, 0.35, 1, 0.5, 0.2, 0.1] 8 | max_interaction_horizon: 40 9 | -------------------------------------------------------------------------------- /config/default_simulation.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${output_dir} 4 | output_subdir: ${output_dir}/code/hydra # Store hydra's config breakdown here for debugging 5 | searchpath: # Only <exp_dir> in these paths are discoverable 6 | - pkg://nuplan.planning.script.config.common 7 | - pkg://nuplan.planning.script.config.simulation 8 | - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/ 9 | - config/simulation 10 | - config/scenario_filter 11 | 12 | defaults: 13 | # Add ungrouped items 14 | - default_experiment 15 | - default_common 16 | - default_submission 17 | 18 | - simulation_metric: 19 | - default_metrics 20 | - callback: 21 | - simulation_log_callback 22 | - main_callback: 23 | - time_callback 24 | - metric_file_callback 25 | - metric_aggregator_callback 26 | - metric_summary_callback 27 | - splitter: nuplan 28 | 29 | # Hyperparameters need to be specified 30 | - observation: null 31 | - ego_controller: null 32 | - planner: null 33 | - simulation_time_controller: step_simulation_time_controller 34 | - metric_aggregator: 35 | - default_weighted_average 36 | 37 | - override hydra/job_logging: none # Disable hydra's logging 38 | -
override hydra/hydra_logging: none # Disable hydra's logging 39 | 40 | experiment_name: 'simulation' 41 | aggregated_metric_folder_name: 'aggregator_metric' # Aggregated metric folder name 42 | aggregator_save_path: ${output_dir}/${aggregated_metric_folder_name} 43 | 44 | 45 | # Progress Visualization 46 | enable_simulation_progress_bar: true # Show for every simulation its progress 47 | 48 | # Simulation Setup 49 | simulation_history_buffer_duration: 2.0 # [s] The look back duration to initialize the simulation history buffer with 50 | 51 | # Number (or fractional, e.g., 0.25) of GPUs available for single simulation (per scenario and planner). 52 | # This number can also be < 1 because we allow multiple models to be loaded into a single GPU. 53 | # In case this number is 0 or null, no GPU is used for simulation and all cpu cores are leveraged 54 | # Note that the user has to make sure that if a number < 1 is chosen, the model will fit 1 / num_gpus into GPU memory 55 | number_of_gpus_allocated_per_simulation: 1 56 | 57 | # This number specifies the number of CPU threads that are used for simulation 58 | # In case this is null, then each simulation will use unlimited resources. 59 | # That will typically swamp the host computer, leading to slowdowns and failure. 60 | number_of_cpus_allocated_per_simulation: 1 61 | 62 | # Set false to disable metric computation 63 | run_metric: true 64 | 65 | # Set to rerun metrics with existing simulation logs without setting run_metric to false. 66 | simulation_log_main_path: null 67 | 68 | # If false, continue running the simulation even if a scenario has failed 69 | exit_on_failure: false 70 | 71 | # Maximum number of workers to be used for running simulation callbacks outside the main process 72 | max_callback_workers: 4 73 | 74 | # Disable callback parallelization when using the Sequential worker. By default, when running with the sequential worker, 75 | # on_simulation_end callbacks are not submitted to a parallel worker. 76 | disable_callback_parallelization: true 77 | 78 | # Distributed processing mode. If multi-node simulation is enabled, this parameter selects how the scenarios are distributed 79 | # to each node.
The modes are: 80 | # - SCENARIO_BASED: Works in two stages, first getting a list of all scenarios to process, then breaking up that 81 | # list and distributing across the workers 82 | # - LOG_FILE_BASED: Works in a single stage, breaking up the scenarios based on what log file they are in and 83 | # distributing the number of log files evenly across all workers 84 | # - SINGLE_NODE: Does no distribution, processes all scenarios in config 85 | distributed_mode: 'SINGLE_NODE' -------------------------------------------------------------------------------- /config/default_training.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${output_dir} 4 | output_subdir: ${output_dir}/code/hydra # Store hydra's config breakdown here for debugging 5 | searchpath: # Only <exp_dir> in these paths are discoverable 6 | - pkg://nuplan.planning.script.config.common 7 | - pkg://nuplan.planning.script.config.training 8 | - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/ 9 | - config/training 10 | 11 | 12 | defaults: 13 | - default_experiment 14 | - default_common 15 | 16 | # Trainer and callbacks 17 | - lightning: custom_lightning 18 | - callbacks: default_callbacks 19 | 20 | # Optimizer settings 21 | - optimizer: adam # [adam, adamw] supported optimizers 22 | - lr_scheduler: null # [one_cycle_lr] supported lr_schedulers 23 | - warm_up_lr_scheduler: null # [linear_warm_up, constant_warm_up] supported warm up lr schedulers 24 | 25 | # Data Loading 26 | - data_loader: default_data_loader 27 | - splitter: ??? 28 | 29 | # Objectives and metrics 30 | - objective: 31 | - training_metric: 32 | - data_augmentation: null 33 | - data_augmentation_scheduler: null # [default_augmentation_schedulers, stepwise_augmentation_probability_scheduler, stepwise_noise_parameter_scheduler] supported data augmentation schedulers 34 | - scenario_type_weights: default_scenario_type_weights 35 | - custom_trainer: null 36 | 37 | nuplan_trainer: false 38 | experiment_name: 'training' 39 | objective_aggregate_mode: ??? # How to aggregate multiple objectives, can be 'mean', 'max', 'sum' 40 | 41 | # Cache parameters 42 | cache: 43 | cache_path: # Local/remote path to store all preprocessed artifacts from the data pipeline 44 | use_cache_without_dataset: false # Load all existing features from a local/remote cache without loading the dataset 45 | force_feature_computation: false # Recompute features even if a cache exists 46 | cleanup_cache: false # Cleanup cached data in the cache_path, this ensures that new data are generated if the same cache_path is passed 47 | 48 | # Mandatory parameters 49 | py_func: ???
# Function to be run inside main (can be "train", "validate", "test", "cache") 50 | epochs: 25 51 | warmup_epochs: 3 52 | lr: 1e-3 53 | weight_decay: 0.0001 54 | checkpoint: 55 | 56 | # wandb settings 57 | wandb: 58 | mode: disable 59 | project: nuplan-pluto 60 | name: ${experiment_name} 61 | log_model: all 62 | artifact: 63 | run_id: 64 | -------------------------------------------------------------------------------- /config/lightning/custom_lightning.yaml: -------------------------------------------------------------------------------- 1 | distributed_training: 2 | equal_variance_scaling_strategy: true # scales lr and betas either linearly if false (multiply by num GPUs) or with equal_variance if true (multiply by square root of num GPUs) 3 | 4 | trainer: 5 | checkpoint: 6 | resume_training: false # load the model from the last epoch and resume training 7 | save_top_k: 5 # save the top K models in terms of performance 8 | monitor: loss/val_loss # metric to monitor for performance 9 | mode: min # minimize/maximize metric 10 | 11 | params: 12 | # max_time: 00:16:00:00 # training time before the process is terminated 13 | 14 | max_epochs: ${epochs} # maximum number of training epochs 15 | # check_val_every_n_epoch: 1 # run validation set every n training epochs 16 | val_check_interval: 1.0 # [%] run validation set every X% of training set 17 | 18 | limit_train_batches: # how much of training dataset to check (float = fraction, int = num_batches) 19 | limit_val_batches: # how much of validation dataset to check (float = fraction, int = num_batches) 20 | limit_test_batches: # how much of test dataset to check (float = fraction, int = num_batches) 21 | 22 | devices: -1 # number of GPUs to utilize (-1 means all available GPUs) 23 | accelerator: gpu # distribution method 24 | precision: 32 # floating point precision 25 | # amp_level: O2 # AMP optimization level 26 | # num_nodes: 1 # Number of nodes used for training 27 | 28 | # auto_scale_batch_size: false 29 | # auto_lr_find: false # tunes LR before beginning training 30 | # terminate_on_nan: true # terminates training if a nan is encountered in loss/weights 31 | 32 | num_sanity_val_steps: 1 # number of validation steps to run before training begins 33 | fast_dev_run: false # runs 1 batch of train/val/test for sanity 34 | 35 | # accumulate_grad_batches: 1 # accumulates gradients every n batches 36 | # track_grad_norm: -1 # logs the p-norm for inspection 37 | gradient_clip_val: 5.0 # value to clip gradients 38 | gradient_clip_algorithm: norm # [value, norm] method to clip gradients 39 | sync_batchnorm: true 40 | strategy: ddp_find_unused_parameters_false 41 | 42 | # checkpoint_callback: true # enable default checkpoint 43 | 44 | overfitting: 45 | enable: false # run an overfitting test instead of training 46 | 47 | params: 48 | max_epochs: 150 # number of epochs to overfit the same batches 49 | overfit_batches: 1 # number of batches to overfit 50 | -------------------------------------------------------------------------------- /config/model/pluto_model.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.pluto.pluto_model.PlanningModel 2 | _convert_: "all" 3 | 4 | dim: 128 5 | state_channel: 6 6 | polygon_channel: 6 7 | history_channel: 9 8 | history_steps: 21 9 | future_steps: 80 10 | encoder_depth: 4 11 | decoder_depth: 4 12 | drop_path: 0.2 13 | dropout: 0.1 14 | num_heads: 4 15 | num_modes: 12 16 | state_dropout: 0.75 17 | use_ego_history: false 18 | state_attn_encoder: true 19 |
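# Note: the step counts above presumably mirror the feature_builder horizons below:
# with sample_interval 0.1 s, history_horizon 2 s corresponds to 2 / 0.1 + 1 = 21
# history_steps (20 past frames plus the current one), and future_horizon 8 s to
# 8 / 0.1 = 80 future_steps; the two blocks should be kept consistent.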
use_hidden_proj: false 20 | 21 | feature_builder: 22 | _target_: src.feature_builders.pluto_feature_builder.PlutoFeatureBuilder 23 | _convert_: "all" 24 | radius: 120 25 | history_horizon: 2 26 | future_horizon: 8 27 | sample_interval: 0.1 28 | max_agents: 48 29 | build_reference_line: true 30 | -------------------------------------------------------------------------------- /config/planner/pluto_planner.yaml: -------------------------------------------------------------------------------- 1 | pluto_planner: 2 | _target_: src.planners.pluto_planner.PlutoPlanner 3 | _convert_: "all" 4 | 5 | render: false 6 | eval_dt: 0.1 7 | eval_num_frames: 40 8 | candidate_subsample_ratio: 1.0 9 | candidate_min_num: 10 10 | learning_based_score_weight: 0.3 11 | 12 | planner: 13 | _target_: src.models.pluto.pluto_model.PlanningModel 14 | _convert_: "all" 15 | 16 | dim: 128 17 | state_channel: 6 18 | polygon_channel: 6 19 | history_channel: 9 20 | history_steps: 21 21 | future_steps: 80 22 | encoder_depth: 4 23 | decoder_depth: 4 24 | drop_path: 0.2 25 | dropout: 0.1 26 | num_heads: 4 27 | num_modes: 12 28 | state_dropout: 0.75 29 | use_ego_history: false 30 | state_attn_encoder: true 31 | use_hidden_proj: true 32 | cat_x: true 33 | ref_free_traj: true 34 | 35 | feature_builder: 36 | _target_: src.feature_builders.pluto_feature_builder.PlutoFeatureBuilder 37 | _convert_: "all" 38 | radius: 120 39 | history_horizon: 2 40 | future_horizon: 8 41 | sample_interval: 0.1 42 | max_agents: 48 43 | build_reference_line: true 44 | 45 | planner_ckpt: 46 | -------------------------------------------------------------------------------- /config/scenario_filter/mini_demo_scenario.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: 'all' 3 | 4 | scenario_types: null # List of scenario types to include 5 | scenario_tokens: null # List of scenario tokens to include 6 | 7 | log_names: null # Filter scenarios by log names 8 | map_names: null # Filter scenarios by map names 9 | 10 | num_scenarios_per_type: 1 # Number of scenarios per type 11 | limit_total_scenarios: 1 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type 12 | timestamp_threshold_s: 15 # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps 13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount 14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below 15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above 16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise. 
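# Note: with num_scenarios_per_type: 1 and limit_total_scenarios: 1 above, plus
# shuffle: true below, this filter presumably selects exactly one randomly chosen
# scenario, matching the README's "random scenario in the nuPlan-mini split" demo.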
17 | 18 | 19 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios 20 | remove_invalid_goals: True # Whether to remove scenarios where the mission goal is invalid 21 | shuffle: true # Whether to shuffle the scenarios 22 | -------------------------------------------------------------------------------- /config/scenario_filter/random14_benchmark.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: all 3 | scenario_types: 4 | - starting_left_turn 5 | - starting_right_turn 6 | - starting_straight_traffic_light_intersection_traversal 7 | - stopping_with_lead 8 | - high_lateral_acceleration 9 | - high_magnitude_speed 10 | - low_magnitude_speed 11 | - traversing_pickup_dropoff 12 | - waiting_for_pedestrian_to_cross 13 | - behind_long_vehicle 14 | - stationary_in_traffic 15 | - near_multiple_vehicles 16 | - changing_lane 17 | - following_lane_with_lead 18 | 19 | scenario_tokens: 20 | - "09ca86cc9ae65428" 21 | - d77598de12fb5f46 22 | - "011118ec4f9952bc" 23 | - e9faae87fb83540d 24 | - "7b28cbcdcae35e7e" 25 | - "80c56ad735545de3" 26 | - "9031cf5175f253cf" 27 | - "22022119f3bd53f3" 28 | - ac67e458c88d5d78 29 | - "9cab2c90252c549b" 30 | - a639bd510b0d5b0a 31 | - "9e507708119c5596" 32 | - "4f9321ffbcb95b55" 33 | - a0f50ac13caa51ac 34 | - "843eb1eee80b529e" 35 | - "5051ccd75314515c" 36 | - "1939feda817551d5" 37 | - "281b1fe38bdc5467" 38 | - c9d8621d7c2255e6 39 | - "8b8cdc059d585494" 40 | - "6c6ddf6740d355ef" 41 | - "74e550e274f959a4" 42 | - "0393fbf49cd55295" 43 | - f00be015e344557c 44 | - c00c5e8e89e3571c 45 | - "33ef46dbce1c577f" 46 | - "69b83d7c62d65006" 47 | - "7cc26b3af3b0557f" 48 | - dd117e2e2dfa560d 49 | - "7280988648d05358" 50 | - "9a81702c38d757b8" 51 | - a88d3779eb535274 52 | - "27fb96d75d4759a2" 53 | - "5dec4c8bfd49502f" 54 | - "942f6862e6055f62" 55 | - "72038bf480a55042" 56 | - e3342e2e8b535405 57 | - "07874767a6d55ad7" 58 | - f1cd02f2371c5c67 59 | - "2e56997063c057a0" 60 | - "382951755d8e5e77" 61 | - "2f69576b55305cc8" 62 | - f9e6e7fc604b5ff9 63 | - "24cc8621c8595a92" 64 | - "0c88901b6cfb53d0" 65 | - "93c583b46398560e" 66 | - "7835a781f6bd5688" 67 | - "4aeaa8bebf7a57f1" 68 | - "33624e1585945551" 69 | - e5d69d5a0b135831 70 | - "711c54430cae590f" 71 | - "29b77c53a46956af" 72 | - "7d094d765a8f5ffa" 73 | - a1d2f54d8ec1564c 74 | - af3681f002005504 75 | - a26d34c327365e71 76 | - "15a3339778d058db" 77 | - "81ecd0ddcc4e5c1f" 78 | - "770a7cb9f1585e38" 79 | - "7f50df6eea5c54e9" 80 | - b7485487d49853ca 81 | - cb153e6cd19851ba 82 | - "4e0c8ec081805052" 83 | - b1f9cfc8ac885d91 84 | - "4eaa9a0e22c95491" 85 | - "812679cd23f45d5d" 86 | - "7be9d726f1be5003" 87 | - dc98be51b7155407 88 | - aff9a658a9e450cb 89 | - "845645e7c5665590" 90 | - d3dfcbb4c7be5e8b 91 | - "33a3eb36641857e6" 92 | - a489a4faafbd5c67 93 | - "1ce0f41f0d9754f5" 94 | - "2df19b9bf2fe5c88" 95 | - "0908906ecbed58df" 96 | - d7a3ae0854575652 97 | - "29a854ae1aee58ff" 98 | - "64986687478f5338" 99 | - cada319ebf2b505a 100 | - a874cefaca8e5b69 101 | - "5e33493023b75f33" 102 | - "356d80474d1252c0" 103 | - "0eb1eb9046925cf8" 104 | - bc2485198080598a 105 | - b342971fb07451f9 106 | - "098144ba356652a5" 107 | - "33b5947ae3d75a80" 108 | - b086f30875e65add 109 | - b0f8ae874f2b5e78 110 | - "38e301b3cbad5707" 111 | - e5ef6199dc3b5909 112 | - "4236a1f3c9e35f9d" 113 | - ccdca0ce28565318 114 | - "2664bb314a155875" 115 | - d8633b984d75530e 116 | - d21b22d6c0405ad7 117 | - 
f3895453b6c35e51 118 | - dca16f92a1015a7f 119 | - "9fd5ec2b453d556e" 120 | - "619533b026ce59f1" 121 | - d5a09eb525e8592a 122 | - "48b0081cd2385220" 123 | - "35dd043cd64e5537" 124 | - cc4bc91994df5424 125 | - "522cda9d26b1526a" 126 | - "8160c5b3a5c9555e" 127 | - "704c05ebeb5959e6" 128 | - "78c5b673402b5861" 129 | - "34a52b2f68f357df" 130 | - a614ad720a76576c 131 | - b6f238f681bf5a09 132 | - b3b422a58ff35545 133 | - f0be7cf0e03e58db 134 | - c1027e45d3b956e4 135 | - "99c06a975b465903" 136 | - e51649556f215dd9 137 | - b526edea27bc5f93 138 | - e9d360ea046554d6 139 | - e7fc2f835ea95ece 140 | - "2fd784c6bf4f572c" 141 | - "28d5868f9d6e5035" 142 | - "6c37bc8ac424562b" 143 | - "2f8c75dfca3f50b9" 144 | - ead9590e094d5a8f 145 | - "5af50221d55658e4" 146 | - "65af9ce73917587a" 147 | - "0cc3d9bb137b5a2b" 148 | - "7e4a2d822a2e55e0" 149 | - bd5783916d995801 150 | - fa10175cddfa52fa 151 | - e095899901a75a56 152 | - "936b015ebca15109" 153 | - "2f149371d38e5806" 154 | - "40bcab2def635436" 155 | - a438e0b5c26351dd 156 | - ee479ca0620c5077 157 | - c3bdac749eeb5206 158 | - "3c0e98e1b77b5d44" 159 | - c28d662b14b05c83 160 | - "2be476b6b8db5693" 161 | - "2574c2dbef5850f1" 162 | - "144473d8f03a57f4" 163 | - bff3663d750c5ee1 164 | - "541899555b645329" 165 | - "0d05468e475657b8" 166 | - "9eb227aec6dc5e23" 167 | - "8fd7ccd629615f24" 168 | - ef202cd018ce55e1 169 | - "7dea428154ad5d8a" 170 | - "6582d5ea9e2e5fa8" 171 | - "1f5dc80046035f11" 172 | - "1a278d6ef4435e69" 173 | - "52e1da615adf57c4" 174 | - "0612d791f0825acd" 175 | - "88c66c9df39e5193" 176 | - "9bb232eead8e525d" 177 | - db0204e9603f59e9 178 | - da64279ff31552d1 179 | - "462120b01b425182" 180 | - e257ace5273058ae 181 | - "660d375c109f5eed" 182 | - "9bdc15b26c455633" 183 | - c7be4ae92b455fa3 184 | - "84008abb23955152" 185 | - "957e6a7dec135b62" 186 | - "29686b1e8e6859ca" 187 | - "56726e1f05ae5164" 188 | - b85af6bc62cf5e03 189 | - "341ca59d11c15742" 190 | - eaef31406a205542 191 | - "63e48fa44d7c5aa4" 192 | - "713972db4cf35d09" 193 | - "3452d2d6fb655c2b" 194 | - "8efa8475e103598e" 195 | - e5fd3465d00b57ae 196 | - "8de10fd86b825304" 197 | - "477820688c4f5683" 198 | - "8ec9c713f4fc5d52" 199 | - "9fb968b3cea6507a" 200 | - a762a67013d4550e 201 | - "714dffb0115b5071" 202 | - "17e261bca6435550" 203 | - "561e8b2703f651f5" 204 | - "96074214b9645952" 205 | - "094db8d50cf95250" 206 | - "61d9e27ce6ee5a38" 207 | - daaaa1360bcf5451 208 | - c167ed0e96e455c4 209 | - "724cc15d7fc856df" 210 | - c700656703ec57bc 211 | - d6dde3e83ebe5b53 212 | - fa75094937835cf0 213 | - adc1c85043525786 214 | - e4bac3a9053955e2 215 | - "2fa51997ae9554cb" 216 | - "3ee6e16b65f956a8" 217 | - "1039c939079950f4" 218 | - "2e82153ebbed5d67" 219 | - bb0e6c0c719a5e7f 220 | - "7dc66aff7ac85f5f" 221 | - fd31b918c10b56e2 222 | - "0fa049c07b5b57c0" 223 | - bac6f1145eef5bf2 224 | - "6a1ffac82183523e" 225 | - "9f006885686852e7" 226 | - "3acc225a64965bd8" 227 | - a2ec5056da3c5c67 228 | - f69d0b59d08d512b 229 | - "3474e3d6c7485935" 230 | - c139820a97455f83 231 | - "93b3c95f3fea5812" 232 | - a09bce6cb8eb591b 233 | - c4f2b1f170eb5649 234 | - "9ba0cad0c3e7580f" 235 | - "1772467d58bb5dec" 236 | - "07db0457ad745d8e" 237 | - "94ce7eff6722572c" 238 | - b83ee0e949e45520 239 | - b497992774675304 240 | - "255e819738da54df" 241 | - "1b1aafb5916e5534" 242 | - "93e79e172de3521b" 243 | - c757a4faf12f52be 244 | - "155505762e1d5cc0" 245 | - "51933e2f44775d6a" 246 | - "071b206ea836546e" 247 | - f5d7ec43b9415862 248 | - "0ef918a26914564e" 249 | - "1c8be883a97a575e" 250 | - bd345bdaa3715b71 251 | - c998657ba1535677 252 | - 
c1c5dec6bab3598a 253 | - "7721f05072a85928" 254 | - "95495f60a1b659f3" 255 | - "1173af87b1e551ce" 256 | - "399986448051533a" 257 | - c299dae1f0745da8 258 | - beb18cddda575028 259 | - fcedb0da569c518b 260 | - a88f9e60e43c5ebc 261 | - f5a6ac1890c75f87 262 | - e7aa83d247dd5cdc 263 | - cfbbb77238bd541f 264 | - "62db31b428315ae8" 265 | - "0409c3925f245965" 266 | - "70086024fed658cf" 267 | - ceec28e1943f5d76 268 | - "3094b2c29265536a" 269 | log_names: null 270 | map_names: null 271 | num_scenarios_per_type: null 272 | limit_total_scenarios: null 273 | timestamp_threshold_s: 15 274 | ego_displacement_minimum_m: null 275 | ego_start_speed_threshold: null 276 | ego_stop_speed_threshold: null 277 | speed_noise_tolerance: null 278 | expand_scenarios: null 279 | remove_invalid_goals: true 280 | shuffle: false 281 | -------------------------------------------------------------------------------- /config/scenario_filter/training_scenarios_1M.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: 'all' 3 | 4 | scenario_types: null # List of scenario types to include 5 | scenario_tokens: null # List of scenario tokens to include 6 | 7 | log_names: null # Filter scenarios by log names 8 | map_names: null # Filter scenarios by map names 9 | 10 | num_scenarios_per_type: null # Number of scenarios per type 11 | limit_total_scenarios: 1000000 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type 12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps 13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount 14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below 15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above 16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise. 
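# Note: limit_total_scenarios above is an integer and therefore an absolute count
# (int = num, float = fraction); this is presumably the 1M-scenario training split
# that the README's full cache command (cache.cache_path=/nuplan/exp/cache_pluto_1M)
# refers to.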
17 | 18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios 19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid 20 | shuffle: false # Whether to shuffle the scenarios 21 | -------------------------------------------------------------------------------- /config/scenario_filter/training_scenarios_tiny.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: 'all' 3 | 4 | scenario_types: null # List of scenario types to include 5 | scenario_tokens: null # List of scenario tokens to include 6 | 7 | log_names: null # Filter scenarios by log names 8 | map_names: null # Filter scenarios by map names 9 | 10 | num_scenarios_per_type: null # Number of scenarios per type 11 | limit_total_scenarios: 50 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type 12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps 13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount 14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below 15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above 16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise. 17 | 18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios 19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid 20 | shuffle: true # Whether to shuffle the scenarios -------------------------------------------------------------------------------- /config/scenario_filter/val_demo_scenario.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: "all" 3 | 4 | scenario_types: null # List of scenario types to include 5 | scenario_tokens: 6 | - c556f4bc6f165a76 7 | 8 | log_names: # Filter scenarios by log names 9 | - 2021.07.24.00.36.59_veh-47_06810_07310 10 | 11 | map_names: null # Filter scenarios by map names 12 | 13 | num_scenarios_per_type: null # Number of scenarios per type 14 | limit_total_scenarios: null # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type 15 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps 16 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount 17 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below 18 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above 19 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise. 
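# To replay a different scenario, swap in your own token together with the log that
# contains it; both filters are applied jointly, so a mismatched token/log pair would
# presumably yield an empty scenario set. A sketch (the values are placeholders):
# scenario_tokens:
#   - <your_scenario_token>
# log_names:
#   - <log_file_containing_that_token>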
20 | 21 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios 22 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid 23 | shuffle: false # Whether to shuffle the scenarios 24 | -------------------------------------------------------------------------------- /config/training/train_pluto.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | job_name: pluto 3 | py_func: train 4 | objective_aggregate_mode: mean 5 | 6 | defaults: 7 | - override /data_augmentation: 8 | - contrastive_scenario_generator 9 | - override /splitter: nuplan 10 | - override /model: pluto_model 11 | - override /scenario_filter: training_scenarios_tiny 12 | - override /custom_trainer: pluto_trainer 13 | - override /lightning: custom_lightning 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | pytorch-lightning==2.0.1 3 | torchmetrics==0.10.2 4 | tensorboard 5 | wandb==0.14.2 6 | numba 7 | rich==13.3.4 -------------------------------------------------------------------------------- /run_simulation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from typing import List, Optional, Union 7 | 8 | import hydra 9 | import pandas as pd 10 | import pytorch_lightning as pl 11 | from nuplan.common.utils.s3_utils import is_s3_path 12 | from nuplan.planning.script.builders.simulation_builder import build_simulations 13 | from nuplan.planning.script.builders.simulation_callback_builder import ( 14 | build_callbacks_worker, 15 | build_simulation_callbacks, 16 | ) 17 | from nuplan.planning.script.utils import ( 18 | run_runners, 19 | set_default_path, 20 | set_up_common_builder, 21 | ) 22 | from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner 23 | from omegaconf import DictConfig, OmegaConf 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | # If set, use the env. variable to overwrite the default dataset and experiment paths 29 | set_default_path() 30 | 31 | # If set, use the env. variable to overwrite the Hydra config 32 | CONFIG_PATH = os.getenv("NUPLAN_HYDRA_CONFIG_PATH", "config/simulation") 33 | 34 | 35 | def print_simulation_results(file=None): 36 | if file is not None: 37 | df = pd.read_parquet(file) 38 | else: 39 | root = Path(os.getcwd()) / "aggregator_metric" 40 | result = list(root.glob("*.parquet")) 41 | result = max(result, key=lambda item: item.stat().st_ctime) 42 | df = pd.read_parquet(result) 43 | final_score = df[df["scenario"] == "final_score"] 44 | final_score = final_score.to_dict(orient="records")[0] 45 | pprint.PrettyPrinter(indent=4).pprint(final_score) 46 | 47 | 48 | def run_simulation( 49 | cfg: DictConfig, 50 | planners: Optional[Union[AbstractPlanner, List[AbstractPlanner]]] = None, 51 | ) -> None: 52 | """ 53 | Execute all available challenges simultaneously on the same scenario. Helper function for main to allow planner to 54 | be specified via config or directly passed as argument. 55 | :param cfg: Configuration that is used to run the experiment. 56 | Already contains the changes merged from the experiment's config to default config. 
57 | :param planners: Pre-built planner(s) to run in simulation. Can either be a single planner or list of planners. 58 | """ 59 | # Fix random seed 60 | pl.seed_everything(cfg.seed, workers=True) 61 | 62 | profiler_name = "building_simulation" 63 | common_builder = set_up_common_builder(cfg=cfg, profiler_name=profiler_name) 64 | 65 | # Build simulation callbacks 66 | callbacks_worker_pool = build_callbacks_worker(cfg) 67 | callbacks = build_simulation_callbacks( 68 | cfg=cfg, output_dir=common_builder.output_dir, worker=callbacks_worker_pool 69 | ) 70 | 71 | # Remove planner from config to make sure run_simulation does not receive multiple planner specifications. 72 | if planners and "planner" in cfg.keys(): 73 | logger.info("Using pre-instantiated planner. Ignoring planner in config") 74 | OmegaConf.set_struct(cfg, False) 75 | cfg.pop("planner") 76 | OmegaConf.set_struct(cfg, True) 77 | 78 | # Construct simulations 79 | if isinstance(planners, AbstractPlanner): 80 | planners = [planners] 81 | 82 | runners = build_simulations( 83 | cfg=cfg, 84 | callbacks=callbacks, 85 | worker=common_builder.worker, 86 | pre_built_planners=planners, 87 | callbacks_worker=callbacks_worker_pool, 88 | ) 89 | 90 | if common_builder.profiler: 91 | # Stop simulation construction profiling 92 | common_builder.profiler.save_profiler(profiler_name) 93 | 94 | logger.info("Running simulation...") 95 | run_runners( 96 | runners=runners, 97 | common_builder=common_builder, 98 | cfg=cfg, 99 | profiler_name="running_simulation", 100 | ) 101 | logger.info("Finished running simulation!") 102 | 103 | 104 | def clean_up_s3_artifacts() -> None: 105 | """ 106 | Cleanup lingering s3 artifacts that are written locally. 107 | This happens because some minor write-to-s3 functionality isn't yet implemented. 108 | """ 109 | # Lingering artifacts get written locally to a 's3:' directory. Hydra changes 110 | # the working directory to a subdirectory of this, so we search the working 111 | # path for it. 112 | working_path = os.getcwd() 113 | s3_dirname = "s3:" 114 | s3_ind = working_path.find(s3_dirname) 115 | if s3_ind != -1: 116 | local_s3_path = working_path[: working_path.find(s3_dirname) + len(s3_dirname)] 117 | rmtree(local_s3_path) 118 | 119 | 120 | @hydra.main(config_path="./config", config_name="default_simulation") 121 | def main(cfg: DictConfig) -> None: 122 | """ 123 | Execute all available challenges simultaneously on the same scenario. Calls run_simulation to allow planner to 124 | be specified via config or directly passed as argument. 125 | :param cfg: Configuration that is used to run the experiment. 126 | Already contains the changes merged from the experiment's config to default config. 127 | """ 128 | assert ( 129 | cfg.simulation_log_main_path is None 130 | ), "Simulation_log_main_path must not be set when running simulation."
131 | 132 | run_simulation(cfg=cfg) 133 | 134 | if is_s3_path(Path(cfg.output_dir)): 135 | clean_up_s3_artifacts() 136 | 137 | print_simulation_results() 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /run_training.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional 3 | 4 | import hydra 5 | import numpy 6 | import pytorch_lightning as pl 7 | from nuplan.planning.script.builders.folder_builder import ( 8 | build_training_experiment_folder, 9 | ) 10 | from nuplan.planning.script.builders.logging_builder import build_logger 11 | from nuplan.planning.script.builders.worker_pool_builder import build_worker 12 | from nuplan.planning.script.profiler_context_manager import ProfilerContextManager 13 | from nuplan.planning.script.utils import set_default_path 14 | from nuplan.planning.training.experiments.caching import cache_data 15 | from omegaconf import DictConfig 16 | 17 | from src.custom_training import ( 18 | TrainingEngine, 19 | build_training_engine, 20 | update_config_for_training, 21 | ) 22 | 23 | logging.getLogger("numba").setLevel(logging.WARNING) 24 | logger = logging.getLogger(__name__) 25 | 26 | # If set, use the env. variable to overwrite the default dataset and experiment paths 27 | set_default_path() 28 | 29 | # If set, use the env. variable to overwrite the Hydra config 30 | CONFIG_PATH = "./config" 31 | CONFIG_NAME = "default_training" 32 | 33 | 34 | @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME) 35 | def main(cfg: DictConfig) -> Optional[TrainingEngine]: 36 | """ 37 | Main entrypoint for training/validation experiments. 38 | :param cfg: omegaconf dictionary 39 | """ 40 | pl.seed_everything(cfg.seed, workers=True) 41 | 42 | # Configure logger 43 | build_logger(cfg) 44 | 45 | # Override configs based on setup, and print config 46 | update_config_for_training(cfg) 47 | 48 | # Create output storage folder 49 | build_training_experiment_folder(cfg=cfg) 50 | 51 | # Build worker 52 | worker = build_worker(cfg) 53 | 54 | if cfg.py_func == "train": 55 | # Build training engine 56 | with ProfilerContextManager( 57 | cfg.output_dir, cfg.enable_profiling, "build_training_engine" 58 | ): 59 | engine = build_training_engine(cfg, worker) 60 | 61 | # Run training 62 | logger.info("Starting training...") 63 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "training"): 64 | engine.trainer.fit( 65 | model=engine.model, 66 | datamodule=engine.datamodule, 67 | ckpt_path=cfg.checkpoint, 68 | ) 69 | return engine 70 | elif cfg.py_func == "validate": 71 | # Build training engine 72 | with ProfilerContextManager( 73 | cfg.output_dir, cfg.enable_profiling, "build_training_engine" 74 | ): 75 | engine = build_training_engine(cfg, worker) 76 | 77 | # Run validation 78 | logger.info("Starting validation...") 79 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "validate"): 80 | engine.trainer.validate( 81 | model=engine.model, 82 | datamodule=engine.datamodule, 83 | ckpt_path=cfg.checkpoint, 84 | ) 85 | return engine 86 | elif cfg.py_func == "test": 87 | # Build training engine 88 | with ProfilerContextManager( 89 | cfg.output_dir, cfg.enable_profiling, "build_training_engine" 90 | ): 91 | engine = build_training_engine(cfg, worker) 92 | 93 | # Test model 94 | logger.info("Starting testing...") 95 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "testing"): 96 |
engine.trainer.test(model=engine.model, datamodule=engine.datamodule) 97 | return engine 98 | elif cfg.py_func == "cache": 99 | # Precompute and cache all features 100 | logger.info("Starting caching...") 101 | with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "caching"): 102 | cache_data(cfg=cfg, worker=worker) 103 | return None 104 | else: 105 | raise NameError(f"Function {cfg.py_func} does not exist") 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /script/run_pluto_planner.sh: -------------------------------------------------------------------------------- 1 | cwd=$(pwd) 2 | CKPT_ROOT="$cwd/checkpoints" 3 | 4 | PLANNER=$1 5 | BUILDER=$2 6 | FILTER=$3 7 | CKPT=$4 8 | VIDEO_SAVE_DIR=$5 9 | 10 | CHALLENGE="closed_loop_nonreactive_agents" 11 | # CHALLENGE="closed_loop_reactive_agents" 12 | # CHALLENGE="open_loop_boxes" 13 | 14 | python run_simulation.py \ 15 | +simulation=$CHALLENGE \ 16 | planner=$PLANNER \ 17 | scenario_builder=$BUILDER \ 18 | scenario_filter=$FILTER \ 19 | worker=sequential \ 20 | verbose=true \ 21 | experiment_uid="pluto_planner/$FILTER" \ 22 | planner.pluto_planner.render=true \ 23 | planner.pluto_planner.planner_ckpt="$CKPT_ROOT/$CKPT" \ 24 | +planner.pluto_planner.save_dir=$VIDEO_SAVE_DIR 25 | 26 | -------------------------------------------------------------------------------- /script/setup_env.sh: -------------------------------------------------------------------------------- 1 | pip install torch==2.0.1 torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118 2 | pip3 install natten==0.14.6 -f https://shi-labs.com/natten/wheels/cu118/torch2.0.0/index.html 3 | pip install -r ./requirements.txt 4 | -------------------------------------------------------------------------------- /src/custom_training/custom_training_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from typing import cast 7 | 8 | import pytorch_lightning as pl 9 | from hydra.utils import instantiate 10 | from nuplan.planning.script.builders.data_augmentation_builder import ( 11 | build_agent_augmentor, 12 | ) 13 | from nuplan.planning.script.builders.model_builder import build_torch_module_wrapper 14 | from nuplan.planning.script.builders.objectives_builder import build_objectives 15 | from nuplan.planning.script.builders.scenario_builder import build_scenarios 16 | from nuplan.planning.script.builders.splitter_builder import build_splitter 17 | from nuplan.planning.script.builders.training_metrics_builder import ( 18 | build_training_metrics, 19 | ) 20 | from nuplan.planning.training.modeling.lightning_module_wrapper import ( 21 | LightningModuleWrapper, 22 | ) 23 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper 24 | from nuplan.planning.training.preprocessing.feature_preprocessor import ( 25 | FeaturePreprocessor, 26 | ) 27 | from nuplan.planning.utils.multithreading.worker_pool import WorkerPool 28 | from omegaconf import DictConfig, OmegaConf 29 | from pytorch_lightning.callbacks import ( 30 | LearningRateMonitor, 31 | ModelCheckpoint, 32 | RichModelSummary, 33 | RichProgressBar, 34 | ) 35 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 36 | from pytorch_lightning.loggers.wandb import WandbLogger 37 | 38 | from .custom_datamodule import 
CustomDataModule 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def update_config_for_training(cfg: DictConfig) -> None: 44 | """ 45 | Updates the config based on some conditions. 46 | :param cfg: omegaconf dictionary that is used to run the experiment. 47 | """ 48 | # Make the configuration editable. 49 | OmegaConf.set_struct(cfg, False) 50 | 51 | if cfg.cache.cache_path is None: 52 | logger.warning("Parameter cache_path is not set, caching is disabled") 53 | else: 54 | if not str(cfg.cache.cache_path).startswith("s3://"): 55 | if cfg.cache.cleanup_cache and Path(cfg.cache.cache_path).exists(): 56 | rmtree(cfg.cache.cache_path) 57 | 58 | Path(cfg.cache.cache_path).mkdir(parents=True, exist_ok=True) 59 | 60 | if cfg.lightning.trainer.overfitting.enable: 61 | cfg.data_loader.params.num_workers = 0 62 | 63 | OmegaConf.resolve(cfg) 64 | 65 | # Finalize the configuration and make it non-editable. 66 | OmegaConf.set_struct(cfg, True) 67 | 68 | # Log the final configuration after all overrides, interpolations and updates. 69 | if cfg.log_config: 70 | logger.info( 71 | f"Creating experiment name [{cfg.experiment}] in group [{cfg.group}] with config..." 72 | ) 73 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 74 | 75 | 76 | @dataclass(frozen=True) 77 | class TrainingEngine: 78 | """Lightning training engine dataclass wrapping the lightning trainer, model and datamodule.""" 79 | 80 | trainer: pl.Trainer # Trainer for models 81 | model: pl.LightningModule # Module describing NN model, loss, metrics, visualization 82 | datamodule: pl.LightningDataModule # Loading data 83 | 84 | def __repr__(self) -> str: 85 | """ 86 | :return: String representation of class without expanding the fields. 87 | """ 88 | return f"<{type(self).__module__}.{type(self).__qualname__} object at {hex(id(self))}>" 89 | 90 | 91 | def build_lightning_datamodule( 92 | cfg: DictConfig, worker: WorkerPool, model: TorchModuleWrapper 93 | ) -> pl.LightningDataModule: 94 | """ 95 | Build the lightning datamodule from the config. 96 | :param cfg: Omegaconf dictionary. 97 | :param model: NN model used for training. 98 | :param worker: Worker to submit tasks which can be executed in parallel. 99 | :return: Instantiated datamodule object. 
100 | """ 101 | # Build features and targets 102 | feature_builders = model.get_list_of_required_feature() 103 | target_builders = model.get_list_of_computed_target() 104 | 105 | # Build splitter 106 | splitter = build_splitter(cfg.splitter) 107 | 108 | # Create feature preprocessor 109 | feature_preprocessor = FeaturePreprocessor( 110 | cache_path=cfg.cache.cache_path, 111 | force_feature_computation=cfg.cache.force_feature_computation, 112 | feature_builders=feature_builders, 113 | target_builders=target_builders, 114 | ) 115 | 116 | # Create data augmentation 117 | augmentors = ( 118 | build_agent_augmentor(cfg.data_augmentation) 119 | if "data_augmentation" in cfg 120 | else None 121 | ) 122 | 123 | # Build dataset scenarios 124 | scenarios = build_scenarios(cfg, worker, model) 125 | 126 | # Create datamodule 127 | datamodule: pl.LightningDataModule = CustomDataModule( 128 | feature_preprocessor=feature_preprocessor, 129 | splitter=splitter, 130 | all_scenarios=scenarios, 131 | dataloader_params=cfg.data_loader.params, 132 | augmentors=augmentors, 133 | worker=worker, 134 | scenario_type_sampling_weights=cfg.scenario_type_weights.scenario_type_sampling_weights, 135 | **cfg.data_loader.datamodule, 136 | ) 137 | 138 | return datamodule 139 | 140 | 141 | def build_lightning_module( 142 | cfg: DictConfig, torch_module_wrapper: TorchModuleWrapper 143 | ) -> pl.LightningModule: 144 | """ 145 | Builds the lightning module from the config. 146 | :param cfg: omegaconf dictionary 147 | :param torch_module_wrapper: NN model used for training 148 | :return: built object. 149 | """ 150 | # Create the complete Module 151 | if "custom_trainer" in cfg: 152 | model = instantiate( 153 | cfg.custom_trainer, 154 | model=torch_module_wrapper, 155 | lr=cfg.lr, 156 | weight_decay=cfg.weight_decay, 157 | epochs=cfg.epochs, 158 | warmup_epochs=cfg.warmup_epochs, 159 | ) 160 | else: 161 | objectives = build_objectives(cfg) 162 | metrics = build_training_metrics(cfg) 163 | model = LightningModuleWrapper( 164 | model=torch_module_wrapper, 165 | objectives=objectives, 166 | metrics=metrics, 167 | batch_size=cfg.data_loader.params.batch_size, 168 | optimizer=cfg.optimizer, 169 | lr_scheduler=cfg.lr_scheduler if "lr_scheduler" in cfg else None, 170 | warm_up_lr_scheduler=cfg.warm_up_lr_scheduler 171 | if "warm_up_lr_scheduler" in cfg 172 | else None, 173 | objective_aggregate_mode=cfg.objective_aggregate_mode, 174 | ) 175 | 176 | return cast(pl.LightningModule, model) 177 | 178 | 179 | def build_custom_trainer(cfg: DictConfig) -> pl.Trainer: 180 | """ 181 | Builds the lightning trainer from the config. 182 | :param cfg: omegaconf dictionary 183 | :return: built object. 
184 | """ 185 | params = cfg.lightning.trainer.params 186 | 187 | # callbacks = build_callbacks(cfg) 188 | callbacks = [ 189 | ModelCheckpoint( 190 | dirpath=os.path.join(os.getcwd(), "checkpoints"), 191 | filename="{epoch}-{val_minFDE:.3f}", 192 | monitor=cfg.lightning.trainer.checkpoint.monitor, 193 | mode=cfg.lightning.trainer.checkpoint.mode, 194 | save_top_k=cfg.lightning.trainer.checkpoint.save_top_k, 195 | save_last=True, 196 | ), 197 | RichModelSummary(max_depth=1), 198 | RichProgressBar(), 199 | LearningRateMonitor(logging_interval="epoch"), 200 | ] 201 | 202 | if cfg.wandb.mode == "disable": 203 | training_logger = TensorBoardLogger( 204 | save_dir=cfg.group, 205 | name=cfg.experiment, 206 | log_graph=False, 207 | version="", 208 | prefix="", 209 | ) 210 | else: 211 | if cfg.wandb.artifact is not None: 212 | os.system(f"wandb artifact get {cfg.wandb.artifact}") 213 | _, _, artifact = cfg.wandb.artifact.split("/") 214 | checkpoint = os.path.join(os.getcwd(), f"artifacts/{artifact}/model.ckpt") 215 | run_id = artifact.split(":")[0][-8:] 216 | cfg.checkpoint = checkpoint 217 | cfg.wandb.run_id = run_id 218 | 219 | training_logger = WandbLogger( 220 | save_dir=cfg.group, 221 | project=cfg.wandb.project, 222 | name=cfg.wandb.name, 223 | mode=cfg.wandb.mode, 224 | log_model=cfg.wandb.log_model, 225 | resume=cfg.checkpoint is not None, 226 | id=cfg.wandb.run_id, 227 | ) 228 | 229 | trainer = pl.Trainer( 230 | callbacks=callbacks, 231 | logger=training_logger, 232 | **params, 233 | ) 234 | 235 | return trainer 236 | 237 | 238 | def build_training_engine(cfg: DictConfig, worker: WorkerPool) -> TrainingEngine: 239 | """ 240 | Build the three core lightning modules: LightningDataModule, LightningModule and Trainer 241 | :param cfg: omegaconf dictionary 242 | :param worker: Worker to submit tasks which can be executed in parallel 243 | :return: TrainingEngine 244 | """ 245 | logger.info("Building training engine...") 246 | 247 | trainer = build_custom_trainer(cfg) 248 | 249 | # Create model 250 | torch_module_wrapper = build_torch_module_wrapper(cfg.model) 251 | 252 | # Build the datamodule 253 | datamodule = build_lightning_datamodule(cfg, worker, torch_module_wrapper) 254 | 255 | # Build lightning module 256 | model = build_lightning_module(cfg, torch_module_wrapper) 257 | 258 | engine = TrainingEngine(trainer=trainer, datamodule=datamodule, model=model) 259 | 260 | return engine -------------------------------------------------------------------------------- /src/feature_builders/common.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from enum import IntEnum 4 | from typing import Any, Dict, List, Set, Tuple, Type, Union 5 | 6 | import cv2 7 | import numba 8 | import numpy as np 9 | from nuplan.common.actor_state.ego_state import EgoState 10 | from nuplan.common.actor_state.state_representation import Point2D 11 | from nuplan.common.actor_state.tracked_objects import TrackedObjects 12 | from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType 13 | from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters 14 | from nuplan.common.maps.abstract_map import AbstractMap, MapObject 15 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer, TrafficLightStatusType 16 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario 17 | 18 | EGO_PARAMS = get_pacifica_parameters() 19 | LANE_LAYERS = [SemanticMapLayer.LANE, 
SemanticMapLayer.LANE_CONNECTOR] 20 | 21 | HALF_WIDTH = EGO_PARAMS.half_width 22 | FRONT_LENGTH = EGO_PARAMS.front_length 23 | REAR_LENGTH = EGO_PARAMS.rear_length 24 | 25 | 26 | class PolylineElements(IntEnum): 27 | """ 28 | Enum for PolylineElements. 29 | """ 30 | 31 | LANE = 0 32 | BOUNDARY = 1 33 | STOP_LINE = 2 34 | CROSSWALK = 3 35 | 36 | @classmethod 37 | def deserialize(cls, layer: str) -> PolylineElements: 38 | """Deserialize the type when loading from a string.""" 39 | return PolylineElements.__members__[layer] 40 | 41 | 42 | GlobalTypeMapping = { 43 | "AV": 0, 44 | TrackedObjectType.VEHICLE: 1, 45 | TrackedObjectType.PEDESTRIAN: 2, 46 | TrackedObjectType.BICYCLE: 3, 47 | PolylineElements.LANE: 4, 48 | PolylineElements.BOUNDARY: 5, 49 | } 50 | 51 | 52 | def interpolate_polyline(points: np.ndarray, t: int) -> np.ndarray: 53 | """Uniformly resample a polyline to t points, equally spaced in arclength (adapted from av2-api).""" 54 | 55 | if points.ndim != 2: 56 | raise ValueError("Input array must be (N,2) or (N,3) in shape.") 57 | 58 | # the number of points on the curve itself 59 | n, _ = points.shape 60 | 61 | # equally spaced in arclength -- the number of points that will be uniformly interpolated 62 | eq_spaced_points = np.linspace(0, 1, t) 63 | 64 | # Compute the chordal arclength of each segment. 65 | # Compute differences between each x coord, to get the dx's 66 | # Do the same to get dy's. Then the hypotenuse length is computed as a norm. 67 | chordlen: np.ndarray = np.linalg.norm(np.diff(points, axis=0), axis=1) # type: ignore 68 | # Normalize the arclengths to a unit total 69 | chordlen = chordlen / np.sum(chordlen) 70 | # cumulative arclength 71 | 72 | cumarc: np.ndarray = np.zeros(len(chordlen) + 1) 73 | cumarc[1:] = np.cumsum(chordlen) 74 | 75 | # which interval did each point fall in, in terms of eq_spaced_points? (bin index) 76 | tbins: np.ndarray = np.digitize(eq_spaced_points, bins=cumarc).astype(int) # type: ignore 77 | 78 | # catch any problems at the ends 79 | tbins[np.where((tbins <= 0) | (eq_spaced_points <= 0))] = 1 # type: ignore 80 | tbins[np.where((tbins >= n) | (eq_spaced_points >= 1))] = n - 1 81 | 82 | chordlen[tbins - 1] = np.where( 83 | chordlen[tbins - 1] == 0, chordlen[tbins - 1] + 1e-6, chordlen[tbins - 1] 84 | ) 85 | 86 | s = np.divide((eq_spaced_points - cumarc[tbins - 1]), chordlen[tbins - 1]) 87 | anchors = points[tbins - 1, :] 88 | # broadcast to scale each row of `points` by a different row of s 89 | offsets = (points[tbins, :] - points[tbins - 1, :]) * s.reshape(-1, 1) 90 | points_interp: np.ndarray = anchors + offsets 91 | 92 | return points_interp 93 | 94 | 95 | def get_ego_corners(rear_axle_xy: np.ndarray, heading: np.ndarray): 96 | """ 97 | rear_axle_xy: [T, 2] rear-axle (x, y) positions; heading: [T] yaw angles 98 | """ 99 | ego_corners_offset = np.array( 100 | [ 101 | [-REAR_LENGTH, -HALF_WIDTH], 102 | [-REAR_LENGTH, HALF_WIDTH], 103 | [FRONT_LENGTH, HALF_WIDTH], 104 | [FRONT_LENGTH, -HALF_WIDTH], 105 | ], 106 | dtype=np.float64, 107 | ) 108 | ego_corners = rear_axle_xy[..., None, :] + ego_corners_offset[None, ...]
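# annotation (not part of the original source): the block below stacks one 2x2
# rotation matrix per timestep; for row vectors, p @ [[cos h, sin h], [-sin h, cos h]]
# is a counter-clockwise rotation by the heading h, applied to all four corner
# points at once by the batched matmul that follows.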
109 | rotate_mat = np.zeros((len(heading), 2, 2), dtype=np.float64) 110 | rotate_mat[:, 0, 0] = np.cos(heading) 111 | rotate_mat[:, 0, 1] = np.sin(heading) 112 | rotate_mat[:, 1, 0] = -np.sin(heading) 113 | rotate_mat[:, 1, 1] = np.cos(heading) 114 | 115 | ego_corners = ego_corners @ rotate_mat 116 | return ego_corners 117 | 118 | 119 | @numba.njit 120 | def rotate_round_z_axis(points: np.ndarray, angle: float): 121 | rotate_mat = np.array( 122 | [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], 123 | # dtype=np.float64, 124 | ) 125 | # return np.matmul(points, rotate_mat) 126 | return points @ rotate_mat 127 | -------------------------------------------------------------------------------- /src/features/pluto_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List 5 | 6 | import numpy as np 7 | import torch 8 | from nuplan.planning.training.preprocessing.features.abstract_model_feature import ( 9 | AbstractModelFeature, 10 | ) 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | from src.utils.utils import to_device, to_numpy, to_tensor 14 | 15 | 16 | @dataclass 17 | class PlutoFeature(AbstractModelFeature): 18 | data: Dict[str, Any] # anchor sample 19 | data_p: Dict[str, Any] = None # positive sample 20 | data_n: Dict[str, Any] = None # negative sample 21 | data_n_info: Dict[str, Any] = None # negative sample info 22 | 23 | @classmethod 24 | def collate(cls, feature_list: List[PlutoFeature]) -> PlutoFeature: 25 | batch_data = {} 26 | 27 | pad_keys = ["agent", "map"] 28 | stack_keys = ["current_state", "origin", "angle"] 29 | 30 | if "reference_line" in feature_list[0].data: 31 | pad_keys.append("reference_line") 32 | if "static_objects" in feature_list[0].data: 33 | pad_keys.append("static_objects") 34 | if "cost_maps" in feature_list[0].data: 35 | stack_keys.append("cost_maps") 36 | 37 | if feature_list[0].data_n is not None: 38 | for key in pad_keys: 39 | batch_data[key] = { 40 | k: pad_sequence( 41 | [f.data[key][k] for f in feature_list] 42 | + [f.data_p[key][k] for f in feature_list] 43 | + [f.data_n[key][k] for f in feature_list], 44 | batch_first=True, 45 | ) 46 | for k in feature_list[0].data[key].keys() 47 | } 48 | 49 | batch_data["data_n_valid_mask"] = torch.Tensor( 50 | [f.data_n_info["valid_mask"] for f in feature_list] 51 | ).bool() 52 | batch_data["data_n_type"] = torch.Tensor( 53 | [f.data_n_info["type"] for f in feature_list] 54 | ).long() 55 | 56 | for key in stack_keys: 57 | batch_data[key] = torch.stack( 58 | [f.data[key] for f in feature_list] 59 | + [f.data_p[key] for f in feature_list] 60 | + [f.data_n[key] for f in feature_list], 61 | dim=0, 62 | ) 63 | elif feature_list[0].data_p is not None: 64 | for key in pad_keys: 65 | batch_data[key] = { 66 | k: pad_sequence( 67 | [f.data[key][k] for f in feature_list] 68 | + [f.data_p[key][k] for f in feature_list], 69 | batch_first=True, 70 | ) 71 | for k in feature_list[0].data[key].keys() 72 | } 73 | 74 | for key in stack_keys: 75 | batch_data[key] = torch.stack( 76 | [f.data[key] for f in feature_list] 77 | + [f.data_p[key] for f in feature_list], 78 | dim=0, 79 | ) 80 | else: 81 | for key in pad_keys: 82 | batch_data[key] = { 83 | k: pad_sequence( 84 | [f.data[key][k] for f in feature_list], batch_first=True 85 | ) 86 | for k in feature_list[0].data[key].keys() 87 | } 88 | 89 | for key in stack_keys: 90 | batch_data[key] = torch.stack( 91 | 
[f.data[key] for f in feature_list], dim=0 92 | ) 93 | 94 | return PlutoFeature(data=batch_data) 95 | 96 | def to_feature_tensor(self) -> PlutoFeature: 97 | new_data = {} 98 | for k, v in self.data.items(): 99 | new_data[k] = to_tensor(v) 100 | 101 | if self.data_p is not None: 102 | new_data_p = {} 103 | for k, v in self.data_p.items(): 104 | new_data_p[k] = to_tensor(v) 105 | else: 106 | new_data_p = None 107 | 108 | if self.data_n is not None: 109 | new_data_n = {} 110 | new_data_n_info = {} 111 | for k, v in self.data_n.items(): 112 | new_data_n[k] = to_tensor(v) 113 | for k, v in self.data_n_info.items(): 114 | new_data_n_info[k] = to_tensor(v) 115 | else: 116 | new_data_n = None 117 | new_data_n_info = None 118 | 119 | return PlutoFeature( 120 | data=new_data, 121 | data_p=new_data_p, 122 | data_n=new_data_n, 123 | data_n_info=new_data_n_info, 124 | ) 125 | 126 | def to_numpy(self) -> PlutoFeature: 127 | new_data = {} 128 | for k, v in self.data.items(): 129 | new_data[k] = to_numpy(v) 130 | if self.data_p is not None: 131 | new_data_p = {} 132 | for k, v in self.data_p.items(): 133 | new_data_p[k] = to_numpy(v) 134 | else: 135 | new_data_p = None 136 | if self.data_n is not None: 137 | new_data_n = {} 138 | for k, v in self.data_n.items(): 139 | new_data_n[k] = to_numpy(v) 140 | else: 141 | new_data_n = None 142 | return PlutoFeature(data=new_data, data_p=new_data_p, data_n=new_data_n) 143 | 144 | def to_device(self, device: torch.device) -> PlutoFeature: 145 | new_data = {} 146 | for k, v in self.data.items(): 147 | new_data[k] = to_device(v, device) 148 | return PlutoFeature(data=new_data) 149 | 150 | def serialize(self) -> Dict[str, Any]: 151 | return {"data": self.data} 152 | 153 | @classmethod 154 | def deserialize(cls, data: Dict[str, Any]) -> PlutoFeature: 155 | return PlutoFeature(data=data["data"]) 156 | 157 | def unpack(self) -> List[AbstractModelFeature]: 158 | raise NotImplementedError 159 | 160 | @property 161 | def is_valid(self) -> bool: 162 | if "reference_line" in self.data: 163 | return self.data["reference_line"]["valid_mask"].any() 164 | else: 165 | return self.data["map"]["point_position"].shape[0] > 0 166 | 167 | @classmethod 168 | def normalize( 169 | self, data, first_time=False, radius=None, hist_steps=21 170 | ) -> PlutoFeature: 171 | cur_state = data["current_state"] 172 | center_xy, center_angle = cur_state[:2].copy(), cur_state[2].copy() 173 | 174 | rotate_mat = np.array( 175 | [ 176 | [np.cos(center_angle), -np.sin(center_angle)], 177 | [np.sin(center_angle), np.cos(center_angle)], 178 | ], 179 | dtype=np.float64, 180 | ) 181 | 182 | data["current_state"][:3] = 0 183 | data["agent"]["position"] = np.matmul( 184 | data["agent"]["position"] - center_xy, rotate_mat 185 | ) 186 | data["agent"]["velocity"] = np.matmul(data["agent"]["velocity"], rotate_mat) 187 | data["agent"]["heading"] -= center_angle 188 | 189 | data["map"]["point_position"] = np.matmul( 190 | data["map"]["point_position"] - center_xy, rotate_mat 191 | ) 192 | data["map"]["point_vector"] = np.matmul(data["map"]["point_vector"], rotate_mat) 193 | data["map"]["point_orientation"] -= center_angle 194 | 195 | data["map"]["polygon_center"][..., :2] = np.matmul( 196 | data["map"]["polygon_center"][..., :2] - center_xy, rotate_mat 197 | ) 198 | data["map"]["polygon_center"][..., 2] -= center_angle 199 | data["map"]["polygon_position"] = np.matmul( 200 | data["map"]["polygon_position"] - center_xy, rotate_mat 201 | ) 202 | data["map"]["polygon_orientation"] -= center_angle 203 | 204 | if "causal" 
in data: 205 | if len(data["causal"]["free_path_points"]) > 0: 206 | data["causal"]["free_path_points"][..., :2] = np.matmul( 207 | data["causal"]["free_path_points"][..., :2] - center_xy, rotate_mat 208 | ) 209 | data["causal"]["free_path_points"][..., 2] -= center_angle 210 | if "static_objects" in data: 211 | data["static_objects"]["position"] = np.matmul( 212 | data["static_objects"]["position"] - center_xy, rotate_mat 213 | ) 214 | data["static_objects"]["heading"] -= center_angle 215 | if "route" in data: 216 | data["route"]["position"] = np.matmul( 217 | data["route"]["position"] - center_xy, rotate_mat 218 | ) 219 | if "reference_line" in data: 220 | data["reference_line"]["position"] = np.matmul( 221 | data["reference_line"]["position"] - center_xy, rotate_mat 222 | ) 223 | data["reference_line"]["vector"] = np.matmul( 224 | data["reference_line"]["vector"], rotate_mat 225 | ) 226 | data["reference_line"]["orientation"] -= center_angle 227 | 228 | target_position = ( 229 | data["agent"]["position"][:, hist_steps:] 230 | - data["agent"]["position"][:, hist_steps - 1][:, None] 231 | ) 232 | target_heading = ( 233 | data["agent"]["heading"][:, hist_steps:] 234 | - data["agent"]["heading"][:, hist_steps - 1][:, None] 235 | ) 236 | target = np.concatenate([target_position, target_heading[..., None]], -1) 237 | target[~data["agent"]["valid_mask"][:, hist_steps:]] = 0 238 | data["agent"]["target"] = target 239 | 240 | if first_time: 241 | point_position = data["map"]["point_position"] 242 | x_max, x_min = radius, -radius 243 | y_max, y_min = radius, -radius 244 | valid_mask = ( 245 | (point_position[:, 0, :, 0] < x_max) 246 | & (point_position[:, 0, :, 0] > x_min) 247 | & (point_position[:, 0, :, 1] < y_max) 248 | & (point_position[:, 0, :, 1] > y_min) 249 | ) 250 | valid_polygon = valid_mask.any(-1) 251 | data["map"]["valid_mask"] = valid_mask 252 | 253 | for k, v in data["map"].items(): 254 | data["map"][k] = v[valid_polygon] 255 | 256 | if "causal" in data: 257 | data["causal"]["ego_care_red_light_mask"] = data["causal"][ 258 | "ego_care_red_light_mask" 259 | ][valid_polygon] 260 | 261 | data["origin"] = center_xy 262 | data["angle"] = center_angle 263 | 264 | return PlutoFeature(data=data) 265 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .min_ade import minADE 2 | from .min_fde import minFDE 3 | from .mr import MR -------------------------------------------------------------------------------- /src/metrics/min_ade.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from .utils import sort_predictions 7 | 8 | 9 | class minADE(Metric): 10 | """Minimum Average Displacement Error 11 | minADE: The average L2 distance between the best forecasted trajectory and the ground truth. 12 | The best here refers to the trajectory with the minimum average displacement error among the k most probable modes.
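As implemented in update below (modes first sorted by predicted probability, top k kept): minADE_k = min_{i=1..k} (1/T) * sum_{t=1..T} || y_t^(i) - y_t ||_2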
13 | """ 14 | 15 | full_state_update: Optional[bool] = False 16 | higher_is_better: Optional[bool] = False 17 | 18 | def __init__( 19 | self, 20 | k=6, 21 | compute_on_step: bool = True, 22 | dist_sync_on_step: bool = False, 23 | process_group: Optional[Any] = None, 24 | dist_sync_fn: Callable = None, 25 | ) -> None: 26 | super(minADE, self).__init__( 27 | compute_on_step=compute_on_step, 28 | dist_sync_on_step=dist_sync_on_step, 29 | process_group=process_group, 30 | dist_sync_fn=dist_sync_fn, 31 | ) 32 | self.k = k 33 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 34 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 35 | 36 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None: 37 | with torch.no_grad(): 38 | pred, _ = sort_predictions( 39 | outputs["trajectory"], outputs["probability"], k=self.k 40 | ) 41 | ade = torch.norm( 42 | pred[..., :2] - target.unsqueeze(1)[..., :2], p=2, dim=-1 43 | ).mean(-1) 44 | min_ade = ade.min(-1)[0] 45 | self.sum += min_ade.sum() 46 | self.count += pred.size(0) 47 | 48 | def compute(self) -> torch.Tensor: 49 | return self.sum / self.count 50 | -------------------------------------------------------------------------------- /src/metrics/min_fde.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from .utils import sort_predictions 7 | 8 | 9 | class minFDE(Metric): 10 | full_state_update: Optional[bool] = False 11 | higher_is_better: Optional[bool] = False 12 | 13 | def __init__( 14 | self, 15 | k=6, 16 | compute_on_step: bool = True, 17 | dist_sync_on_step: bool = False, 18 | process_group: Optional[Any] = None, 19 | dist_sync_fn: Callable = None, 20 | ) -> None: 21 | super(minFDE, self).__init__( 22 | compute_on_step=compute_on_step, 23 | dist_sync_on_step=dist_sync_on_step, 24 | process_group=process_group, 25 | dist_sync_fn=dist_sync_fn, 26 | ) 27 | self.k = k 28 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 29 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 30 | 31 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None: 32 | with torch.no_grad(): 33 | pred, _ = sort_predictions( 34 | outputs["trajectory"], outputs["probability"], k=self.k 35 | ) 36 | fde = torch.norm( 37 | pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1 38 | ) 39 | min_fde = fde.min(-1)[0] 40 | self.sum += min_fde.sum() 41 | self.count += pred.shape[0] 42 | 43 | def compute(self) -> torch.Tensor: 44 | return self.sum / self.count 45 | -------------------------------------------------------------------------------- /src/metrics/mr.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class MR(Metric): 8 | full_state_update: Optional[bool] = False 9 | higher_is_better: Optional[bool] = False 10 | 11 | def __init__( 12 | self, 13 | miss_threshold: float = 2.0, 14 | compute_on_step: bool = True, 15 | dist_sync_on_step: bool = False, 16 | process_group: Optional[Any] = None, 17 | dist_sync_fn: Callable = None, 18 | ) -> None: 19 | super(MR, self).__init__( 20 | compute_on_step=compute_on_step, 21 | dist_sync_on_step=dist_sync_on_step, 22 | process_group=process_group, 23 | dist_sync_fn=dist_sync_fn, 24 | ) 25 | 
self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 27 | self.miss_threshold = miss_threshold 28 | 29 | def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None: 30 | with torch.no_grad(): 31 | pred = outputs["trajectory"] 32 | missed_pred = ( 33 | torch.norm( 34 | pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1 35 | ) 36 | > self.miss_threshold 37 | ) 38 | self.sum += missed_pred.all(-1).sum() 39 | self.count += pred.shape[0] 40 | 41 | def compute(self) -> torch.Tensor: 42 | return self.sum / self.count 43 | -------------------------------------------------------------------------------- /src/metrics/prediction_avg_ade.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Dict 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | from torchmetrics.classification.accuracy import Accuracy 6 | 7 | 8 | class PredAvgADE(Metric): 9 | full_state_update: Optional[bool] = False 10 | higher_is_better: Optional[bool] = False 11 | 12 | def __init__( 13 | self, 14 | compute_on_step: bool = True, 15 | dist_sync_on_step: bool = False, 16 | process_group: Optional[Any] = None, 17 | dist_sync_fn: Callable = None, 18 | ) -> None: 19 | super(PredAvgADE, self).__init__( 20 | compute_on_step=compute_on_step, 21 | dist_sync_on_step=dist_sync_on_step, 22 | process_group=process_group, 23 | dist_sync_fn=dist_sync_fn, 24 | ) 25 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 27 | 28 | def update( 29 | self, 30 | outputs: Dict[str, torch.Tensor], 31 | target: torch.Tensor, 32 | ) -> None: 33 | """ 34 | outputs: [A, T, 2] 35 | target: [A, T, 2] 36 | """ 37 | with torch.no_grad(): 38 | prediction, valid_mask = outputs["prediction"], outputs["valid_mask"] 39 | target = outputs["prediction_target"] 40 | ade = ( 41 | torch.norm(prediction - target[..., :2], p=2, dim=-1) * valid_mask 42 | ).sum(-1) / (valid_mask.sum(-1) + 1e-6) 43 | 44 | self.sum += ade.sum() 45 | self.count += valid_mask.any(-1).sum().item() 46 | 47 | def compute(self) -> torch.Tensor: 48 | return self.sum / self.count 49 | -------------------------------------------------------------------------------- /src/metrics/prediction_avg_fde.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Dict 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | from torchmetrics.classification.accuracy import Accuracy 6 | 7 | 8 | class PredAvgFDE(Metric): 9 | full_state_update: Optional[bool] = False 10 | higher_is_better: Optional[bool] = False 11 | 12 | def __init__( 13 | self, 14 | compute_on_step: bool = True, 15 | dist_sync_on_step: bool = False, 16 | process_group: Optional[Any] = None, 17 | dist_sync_fn: Callable = None, 18 | ) -> None: 19 | super(PredAvgFDE, self).__init__( 20 | compute_on_step=compute_on_step, 21 | dist_sync_on_step=dist_sync_on_step, 22 | process_group=process_group, 23 | dist_sync_fn=dist_sync_fn, 24 | ) 25 | self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum") 26 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 27 | 28 | def update( 29 | self, 30 | outputs: Dict[str, torch.Tensor], 31 | target: torch.Tensor, 32 | ) -> None: 33 | """ 34 | outputs: [A, T, 2] 35 | target: [A, T, 2] 36 | """ 37 | 
with torch.no_grad(): 38 | prediction, valid_mask = outputs["prediction"], outputs["valid_mask"] 39 | target = outputs["prediction_target"] 40 | endpoint_mask = valid_mask[..., -1].float() 41 | fde = ( 42 | torch.norm(prediction[..., -1, :2] - target[..., -1, :2], p=2, dim=-1) 43 | * endpoint_mask 44 | ).sum(-1) 45 | 46 | self.sum += fde.sum() 47 | self.count += endpoint_mask.sum().long() 48 | 49 | def compute(self) -> torch.Tensor: 50 | return self.sum / self.count 51 | -------------------------------------------------------------------------------- /src/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def sort_predictions(predictions, probability, k=6): 5 | """Sort the predictions based on the probability of each mode. 6 | Args: 7 | predictions (torch.Tensor): The predicted trajectories [b, k, t, 2]. 8 | probability (torch.Tensor): The probability of each mode [b, k]. 9 | Returns: 10 | torch.Tensor: The sorted predictions [b, k', t, 2]. 11 | """ 12 | indices = torch.argsort(probability, dim=-1, descending=True) 13 | sorted_prob = probability[torch.arange(probability.size(0))[:, None], indices] 14 | sorted_predictions = predictions[ 15 | torch.arange(predictions.size(0))[:, None], indices 16 | ] 17 | return sorted_predictions[:, :k], sorted_prob[:, :k] 18 | -------------------------------------------------------------------------------- /src/models/pluto/layers/common_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def build_mlp(c_in, channels, norm=None, activation="relu"): 5 | layers = [] 6 | num_layers = len(channels) 7 | 8 | if norm is not None: 9 | norm = get_norm(norm) 10 | 11 | activation = get_activation(activation) 12 | 13 | for k in range(num_layers): 14 | if k == num_layers - 1: 15 | layers.append(nn.Linear(c_in, channels[k], bias=True)) 16 | else: 17 | if norm is None: 18 | layers.extend([nn.Linear(c_in, channels[k], bias=True), activation()]) 19 | else: 20 | layers.extend( 21 | [ 22 | nn.Linear(c_in, channels[k], bias=False), 23 | norm(channels[k]), 24 | activation(), 25 | ] 26 | ) 27 | c_in = channels[k] 28 | 29 | return nn.Sequential(*layers) 30 | 31 | 32 | def get_norm(norm: str): 33 | if norm == "bn": 34 | return nn.BatchNorm1d 35 | elif norm == "ln": 36 | return nn.LayerNorm 37 | else: 38 | raise NotImplementedError 39 | 40 | 41 | def get_activation(activation: str): 42 | if activation == "relu": 43 | return nn.ReLU 44 | elif activation == "gelu": 45 | return nn.GELU 46 | else: 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /src/models/pluto/layers/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from natten import NeighborhoodAttention1D 5 | from timm.models.layers import DropPath 6 | 7 | 8 | class NATSequenceEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | in_chans=3, 12 | embed_dim=32, 13 | mlp_ratio=3, 14 | kernel_size=[3, 3, 5], 15 | depths=[2, 2, 2], 16 | num_heads=[2, 4, 8], 17 | out_indices=[0, 1, 2], 18 | drop_rate=0.0, 19 | attn_drop_rate=0.0, 20 | drop_path_rate=0.2, 21 | norm_layer=nn.LayerNorm, 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.embed = ConvTokenizer(in_chans, embed_dim) 26 | self.num_levels = len(depths) 27 | self.num_features = [int(embed_dim * 2**i) for i in 
range(self.num_levels)] 28 | self.out_indices = out_indices 29 | 30 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 31 | self.levels = nn.ModuleList() 32 | for i in range(self.num_levels): 33 | level = NATBlock( 34 | dim=int(embed_dim * 2**i), 35 | depth=depths[i], 36 | num_heads=num_heads[i], 37 | kernel_size=kernel_size[i], 38 | dilations=None, 39 | mlp_ratio=mlp_ratio, 40 | drop=drop_rate, 41 | attn_drop=attn_drop_rate, 42 | drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], 43 | norm_layer=norm_layer, 44 | downsample=(i < self.num_levels - 1), 45 | ) 46 | self.levels.append(level) 47 | 48 | for i_layer in self.out_indices: 49 | layer = norm_layer(self.num_features[i_layer]) 50 | layer_name = f"norm{i_layer}" 51 | self.add_module(layer_name, layer) 52 | 53 | n = self.num_features[-1] 54 | self.lateral_convs = nn.ModuleList() 55 | for i_layer in self.out_indices: 56 | self.lateral_convs.append( 57 | nn.Conv1d(self.num_features[i_layer], n, 3, padding=1) 58 | ) 59 | 60 | self.fpn_conv = nn.Conv1d(n, n, 3, padding=1) 61 | 62 | def forward(self, x): 63 | """x: [B, C, T]""" 64 | x = self.embed(x) 65 | 66 | out = [] 67 | for idx, level in enumerate(self.levels): 68 | x, xo = level(x) 69 | if idx in self.out_indices: 70 | norm_layer = getattr(self, f"norm{idx}") 71 | x_out = norm_layer(xo) 72 | out.append(x_out.permute(0, 2, 1).contiguous()) 73 | 74 | laterals = [ 75 | lateral_conv(out[i]) for i, lateral_conv in enumerate(self.lateral_convs) 76 | ] 77 | for i in range(len(out) - 1, 0, -1): 78 | laterals[i - 1] = laterals[i - 1] + F.interpolate( 79 | laterals[i], 80 | scale_factor=(laterals[i - 1].shape[-1] / laterals[i].shape[-1]), 81 | mode="linear", 82 | align_corners=False, 83 | ) 84 | 85 | out = self.fpn_conv(laterals[0]) 86 | 87 | return out[:, :, -1] 88 | 89 | 90 | class ConvTokenizer(nn.Module): 91 | def __init__(self, in_chans=3, embed_dim=32, norm_layer=None): 92 | super().__init__() 93 | self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=3, stride=1, padding=1) 94 | 95 | if norm_layer is not None: 96 | self.norm = norm_layer(embed_dim) 97 | else: 98 | self.norm = None 99 | 100 | def forward(self, x): 101 | x = self.proj(x).permute(0, 2, 1) # B, C, L -> B, L, C 102 | if self.norm is not None: 103 | x = self.norm(x) 104 | return x 105 | 106 | 107 | class ConvDownsampler(nn.Module): 108 | def __init__(self, dim, norm_layer=nn.LayerNorm): 109 | super().__init__() 110 | self.reduction = nn.Conv1d( 111 | dim, 2 * dim, kernel_size=3, stride=2, padding=1, bias=False 112 | ) 113 | self.norm = norm_layer(2 * dim) 114 | 115 | def forward(self, x): 116 | x = self.reduction(x.permute(0, 2, 1)).permute(0, 2, 1) 117 | x = self.norm(x) 118 | return x 119 | 120 | 121 | class Mlp(nn.Module): 122 | def __init__( 123 | self, 124 | in_features, 125 | hidden_features=None, 126 | out_features=None, 127 | act_layer=nn.GELU, 128 | drop=0.0, 129 | ): 130 | super().__init__() 131 | out_features = out_features or in_features 132 | hidden_features = hidden_features or in_features 133 | self.fc1 = nn.Linear(in_features, hidden_features) 134 | self.act = act_layer() 135 | self.fc2 = nn.Linear(hidden_features, out_features) 136 | self.drop = nn.Dropout(drop) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) 140 | x = self.act(x) 141 | x = self.drop(x) 142 | x = self.fc2(x) 143 | x = self.drop(x) 144 | return x 145 | 146 | 147 | class NATLayer(nn.Module): 148 | def __init__( 149 | self, 150 | dim, 151 | num_heads, 152 | kernel_size=7, 153 | dilation=None, 154 | mlp_ratio=4.0, 
155 | qkv_bias=True, 156 | qk_scale=None, 157 | drop=0.0, 158 | attn_drop=0.0, 159 | drop_path=0.0, 160 | act_layer=nn.GELU, 161 | norm_layer=nn.LayerNorm, 162 | ): 163 | super().__init__() 164 | self.dim = dim 165 | self.num_heads = num_heads 166 | self.mlp_ratio = mlp_ratio 167 | 168 | self.norm1 = norm_layer(dim) 169 | self.attn = NeighborhoodAttention1D( 170 | dim, 171 | kernel_size=kernel_size, 172 | dilation=dilation, 173 | num_heads=num_heads, 174 | qkv_bias=qkv_bias, 175 | qk_scale=qk_scale, 176 | attn_drop=attn_drop, 177 | proj_drop=drop, 178 | ) 179 | 180 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 181 | self.norm2 = norm_layer(dim) 182 | self.mlp = Mlp( 183 | in_features=dim, 184 | hidden_features=int(dim * mlp_ratio), 185 | act_layer=act_layer, 186 | drop=drop, 187 | ) 188 | 189 | def forward(self, x): 190 | shortcut = x 191 | x = self.norm1(x) 192 | x = self.attn(x) 193 | x = shortcut + self.drop_path(x) 194 | x = x + self.drop_path(self.mlp(self.norm2(x))) 195 | return x 196 | 197 | 198 | class NATBlock(nn.Module): 199 | def __init__( 200 | self, 201 | dim, 202 | depth, 203 | num_heads, 204 | kernel_size, 205 | dilations=None, 206 | downsample=True, 207 | mlp_ratio=4.0, 208 | qkv_bias=True, 209 | qk_scale=None, 210 | drop=0.0, 211 | attn_drop=0.0, 212 | drop_path=0.0, 213 | norm_layer=nn.LayerNorm, 214 | act_layer=nn.GELU, 215 | ): 216 | super().__init__() 217 | self.dim = dim 218 | self.depth = depth 219 | 220 | self.blocks = nn.ModuleList( 221 | [ 222 | NATLayer( 223 | dim=dim, 224 | num_heads=num_heads, 225 | kernel_size=kernel_size, 226 | dilation=None if dilations is None else dilations[i], 227 | mlp_ratio=mlp_ratio, 228 | qkv_bias=qkv_bias, 229 | qk_scale=qk_scale, 230 | drop=drop, 231 | attn_drop=attn_drop, 232 | drop_path=drop_path[i] 233 | if isinstance(drop_path, list) 234 | else drop_path, 235 | norm_layer=norm_layer, 236 | act_layer=act_layer, 237 | ) 238 | for i in range(depth) 239 | ] 240 | ) 241 | 242 | self.downsample = ( 243 | None if not downsample else ConvDownsampler(dim=dim, norm_layer=norm_layer) 244 | ) 245 | 246 | def forward(self, x): 247 | for blk in self.blocks: 248 | x = blk(x) 249 | if self.downsample is None: 250 | return x, x 251 | return self.downsample(x), x 252 | 253 | 254 | class PointsEncoder(nn.Module): 255 | def __init__(self, feat_channel, encoder_channel): 256 | super().__init__() 257 | self.encoder_channel = encoder_channel 258 | self.first_mlp = nn.Sequential( 259 | nn.Linear(feat_channel, 128), 260 | nn.BatchNorm1d(128), 261 | nn.ReLU(inplace=True), 262 | nn.Linear(128, 256), 263 | ) 264 | self.second_mlp = nn.Sequential( 265 | nn.Linear(512, 256), 266 | nn.BatchNorm1d(256), 267 | nn.ReLU(inplace=True), 268 | nn.Linear(256, self.encoder_channel), 269 | ) 270 | 271 | def forward(self, x, mask=None): 272 | """ 273 | x : B M 3 274 | mask: B M 275 | ----------------- 276 | feature_global : B C 277 | """ 278 | 279 | bs, n, _ = x.shape 280 | device = x.device 281 | 282 | x_valid = self.first_mlp(x[mask]) # B n 256 283 | x_features = torch.zeros(bs, n, 256, device=device) 284 | x_features[mask] = x_valid 285 | 286 | pooled_feature = x_features.max(dim=1)[0] 287 | x_features = torch.cat( 288 | [x_features, pooled_feature.unsqueeze(1).repeat(1, n, 1)], dim=-1 289 | ) 290 | 291 | x_features_valid = self.second_mlp(x_features[mask]) 292 | res = torch.zeros(bs, n, self.encoder_channel, device=device) 293 | res[mask] = x_features_valid 294 | 295 | res = res.max(dim=1)[0] 296 | return res 297 | 
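A minimal smoke test for the masked pooling in PointsEncoder above (an illustrative sketch, not a file in this repo; the channel sizes are arbitrary, and .eval() is used so the BatchNorm1d layers do not require training-mode batch statistics):

import torch

from src.models.pluto.layers.embedding import PointsEncoder

# hypothetical usage: per-point MLP -> max-pool -> concat -> second MLP -> max-pool
encoder = PointsEncoder(feat_channel=6, encoder_channel=128).eval()
x = torch.randn(2, 20, 6)        # [B, M, C] per-point features
mask = torch.rand(2, 20) > 0.3   # [B, M] boolean validity mask
out = encoder(x, mask)           # invalid points contribute zeros before pooling
assert out.shape == (2, 128)     # one encoder_channel vector per sample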
-------------------------------------------------------------------------------- /src/models/pluto/layers/fourier_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | from typing import List, Optional 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class FourierEmbedding(nn.Module): 22 | def __init__(self, input_dim: int, hidden_dim: int, num_freq_bands: int) -> None: 23 | super(FourierEmbedding, self).__init__() 24 | self.input_dim = input_dim 25 | self.hidden_dim = hidden_dim 26 | 27 | self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None 28 | self.mlps = nn.ModuleList( 29 | [ 30 | nn.Sequential( 31 | nn.Linear(num_freq_bands * 2 + 1, hidden_dim), 32 | nn.LayerNorm(hidden_dim), 33 | nn.ReLU(inplace=True), 34 | nn.Linear(hidden_dim, hidden_dim), 35 | ) 36 | for _ in range(input_dim) 37 | ] 38 | ) 39 | self.to_out = nn.Sequential( 40 | nn.LayerNorm(hidden_dim), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(hidden_dim, hidden_dim), 43 | ) 44 | 45 | def forward( 46 | self, 47 | continuous_inputs: Optional[torch.Tensor], 48 | ) -> torch.Tensor: 49 | x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi 50 | x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) 51 | continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim 52 | for i in range(self.input_dim): 53 | continuous_embs[i] = self.mlps[i](x[..., i, :]) 54 | x = torch.stack(continuous_embs).sum(dim=0) 55 | return self.to_out(x) 56 | -------------------------------------------------------------------------------- /src/models/pluto/layers/mlp_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLPLayer(nn.Module): 5 | def __init__(self, channel_in, hidden, channel_out) -> None: 6 | super().__init__() 7 | 8 | self.mlp = nn.Sequential( 9 | nn.Linear(channel_in, hidden), 10 | nn.LayerNorm(hidden), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(hidden, channel_out), 13 | ) 14 | 15 | def forward(self, x): 16 | return self.mlp(x) 17 | -------------------------------------------------------------------------------- /src/models/pluto/layers/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from timm.models.layers import DropPath 7 | from torch import Tensor 8 | 9 | 10 | class Mlp(nn.Module): 11 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 12 | 13 | def __init__( 14 | self, 15 | in_features, 16 | hidden_features=None, 17 | out_features=None, 18 | act_layer=nn.GELU, 19 | drop=0.0, 20 | ): 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = 
hidden_features or in_features 24 | 25 | self.fc1 = nn.Linear(in_features, hidden_features) 26 | self.act = act_layer() 27 | self.drop1 = nn.Dropout(drop) 28 | self.fc2 = nn.Linear(hidden_features, out_features) 29 | self.drop2 = nn.Dropout(drop) 30 | 31 | def forward(self, x): 32 | x = self.fc1(x) 33 | x = self.act(x) 34 | x = self.drop1(x) 35 | x = self.fc2(x) 36 | x = self.drop2(x) 37 | return x 38 | 39 | 40 | class TransformerEncoderLayer(nn.Module): 41 | def __init__( 42 | self, 43 | dim, 44 | num_heads, 45 | mlp_ratio=4.0, 46 | qkv_bias=False, 47 | drop=0.0, 48 | attn_drop=0.0, 49 | drop_path=0.0, 50 | act_layer=nn.GELU, 51 | norm_layer=nn.LayerNorm, 52 | ): 53 | super().__init__() 54 | self.norm1 = norm_layer(dim) 55 | self.attn = torch.nn.MultiheadAttention( 56 | dim, 57 | num_heads=num_heads, 58 | add_bias_kv=qkv_bias, 59 | dropout=attn_drop, 60 | batch_first=True, 61 | ) 62 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 63 | 64 | self.norm2 = norm_layer(dim) 65 | self.mlp = Mlp( 66 | in_features=dim, 67 | hidden_features=int(dim * mlp_ratio), 68 | act_layer=act_layer, 69 | drop=drop, 70 | ) 71 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 72 | 73 | def forward( 74 | self, 75 | src, 76 | mask: Optional[Tensor] = None, 77 | key_padding_mask: Optional[Tensor] = None, 78 | return_attn_weights=False, 79 | ): 80 | src2 = self.norm1(src) 81 | src2, attn = self.attn( 82 | query=src2, 83 | key=src2, 84 | value=src2, 85 | attn_mask=mask, 86 | key_padding_mask=key_padding_mask, 87 | ) 88 | src = src + self.drop_path1(src2) 89 | src = src + self.drop_path2(self.mlp(self.norm2(src))) 90 | 91 | if return_attn_weights: 92 | return src, attn 93 | 94 | return src 95 | 96 | 97 | class CrossAttentionLayer(nn.Module): 98 | def __init__( 99 | self, 100 | dim, 101 | num_heads, 102 | mlp_ratio=4, 103 | qkv_bias=False, 104 | dropout=0.1, 105 | attn_drop=0.0, 106 | act_layer=nn.GELU, 107 | norm_layer=nn.LayerNorm, 108 | norm_first=True, 109 | ): 110 | super().__init__() 111 | 112 | self.norm_first = norm_first 113 | self.attn = torch.nn.MultiheadAttention( 114 | dim, 115 | num_heads=num_heads, 116 | add_bias_kv=qkv_bias, 117 | dropout=attn_drop, 118 | batch_first=True, 119 | ) 120 | 121 | self.linear1 = nn.Linear(dim, int(mlp_ratio * dim)) 122 | self.activation = act_layer() 123 | self.linear2 = nn.Linear(int(mlp_ratio * dim), dim) 124 | 125 | self.norm1 = norm_layer(dim) 126 | self.norm2 = norm_layer(dim) 127 | self.dropout1 = nn.Dropout(dropout) 128 | self.dropout2 = nn.Dropout(dropout) 129 | self.dropout3 = nn.Dropout(dropout) 130 | 131 | def forward( 132 | self, 133 | x, 134 | memory, 135 | mask: Optional[Tensor] = None, 136 | key_padding_mask: Optional[Tensor] = None, 137 | ): 138 | if self.norm_first: 139 | x = x + self._mha_block(self.norm1(x), memory, mask, key_padding_mask) 140 | x = x + self._ff_block(self.norm2(x)) 141 | else: 142 | x = self.norm1(x + self._mha_block(x, memory, mask, key_padding_mask)) 143 | x = self.norm2(x + self._ff_block(x)) 144 | 145 | return x 146 | 147 | def _mha_block( 148 | self, 149 | x: Tensor, 150 | mem: Tensor, 151 | attn_mask: Optional[Tensor], 152 | key_padding_mask: Optional[Tensor], 153 | ) -> Tensor: 154 | x = self.attn( 155 | x, 156 | mem, 157 | mem, 158 | attn_mask=attn_mask, 159 | key_padding_mask=key_padding_mask, 160 | need_weights=False, 161 | )[0] 162 | return self.dropout1(x) 163 | 164 | def _ff_block(self, x: Tensor) -> Tensor: 165 | x = 
self.linear2(self.dropout2(self.activation(self.linear1(x)))) 166 | return self.dropout3(x) 167 | 168 | 169 | class TransformerDecoderLayer(nn.Module): 170 | def __init__( 171 | self, 172 | d_model, 173 | nhead, 174 | dim_feedforward=2048, 175 | dropout=0.1, 176 | activation="relu", 177 | normalize_before=False, 178 | ): 179 | super().__init__() 180 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 181 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 182 | # Implementation of Feedforward model 183 | self.linear1 = nn.Linear(d_model, dim_feedforward) 184 | self.dropout = nn.Dropout(dropout) 185 | self.linear2 = nn.Linear(dim_feedforward, d_model) 186 | 187 | self.norm1 = nn.LayerNorm(d_model) 188 | self.norm2 = nn.LayerNorm(d_model) 189 | self.norm3 = nn.LayerNorm(d_model) 190 | self.dropout1 = nn.Dropout(dropout) 191 | self.dropout2 = nn.Dropout(dropout) 192 | self.dropout3 = nn.Dropout(dropout) 193 | 194 | self.activation = _get_activation_fn(activation) 195 | self.normalize_before = normalize_before 196 | 197 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 198 | return tensor if pos is None else tensor + pos 199 | 200 | def forward_post( 201 | self, 202 | tgt, 203 | memory, 204 | tgt_mask: Optional[Tensor] = None, 205 | memory_mask: Optional[Tensor] = None, 206 | tgt_key_padding_mask: Optional[Tensor] = None, 207 | memory_key_padding_mask: Optional[Tensor] = None, 208 | pos: Optional[Tensor] = None, 209 | query_pos: Optional[Tensor] = None, 210 | ): 211 | q = k = self.with_pos_embed(tgt, query_pos) 212 | tgt2 = self.self_attn( 213 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 214 | )[0] 215 | tgt = tgt + self.dropout1(tgt2) 216 | tgt = self.norm1(tgt) 217 | tgt2 = self.multihead_attn( 218 | query=self.with_pos_embed(tgt, query_pos), 219 | key=self.with_pos_embed(memory, pos), 220 | value=memory, 221 | attn_mask=memory_mask, 222 | key_padding_mask=memory_key_padding_mask, 223 | )[0] 224 | tgt = tgt + self.dropout2(tgt2) 225 | tgt = self.norm2(tgt) 226 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 227 | tgt = tgt + self.dropout3(tgt2) 228 | tgt = self.norm3(tgt) 229 | return tgt 230 | 231 | def forward_pre( 232 | self, 233 | tgt, 234 | memory, 235 | tgt_mask: Optional[Tensor] = None, 236 | memory_mask: Optional[Tensor] = None, 237 | tgt_key_padding_mask: Optional[Tensor] = None, 238 | memory_key_padding_mask: Optional[Tensor] = None, 239 | pos: Optional[Tensor] = None, 240 | query_pos: Optional[Tensor] = None, 241 | ): 242 | tgt2 = self.norm1(tgt) 243 | q = k = self.with_pos_embed(tgt2, query_pos) 244 | tgt2 = self.self_attn( 245 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 246 | )[0] 247 | tgt = tgt + self.dropout1(tgt2) 248 | tgt2 = self.norm2(tgt) 249 | tgt2 = self.multihead_attn( 250 | query=self.with_pos_embed(tgt2, query_pos), 251 | key=self.with_pos_embed(memory, pos), 252 | value=memory, 253 | attn_mask=memory_mask, 254 | key_padding_mask=memory_key_padding_mask, 255 | )[0] 256 | tgt = tgt + self.dropout2(tgt2) 257 | tgt2 = self.norm3(tgt) 258 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 259 | tgt = tgt + self.dropout3(tgt2) 260 | return tgt 261 | 262 | def forward( 263 | self, 264 | tgt, 265 | memory, 266 | tgt_mask: Optional[Tensor] = None, 267 | memory_mask: Optional[Tensor] = None, 268 | tgt_key_padding_mask: Optional[Tensor] = None, 269 | memory_key_padding_mask: Optional[Tensor] = None, 270 | 
pos: Optional[Tensor] = None, 271 | query_pos: Optional[Tensor] = None, 272 | ): 273 | if self.normalize_before: 274 | return self.forward_pre( 275 | tgt, 276 | memory, 277 | tgt_mask, 278 | memory_mask, 279 | tgt_key_padding_mask, 280 | memory_key_padding_mask, 281 | pos, 282 | query_pos, 283 | ) 284 | return self.forward_post( 285 | tgt, 286 | memory, 287 | tgt_mask, 288 | memory_mask, 289 | tgt_key_padding_mask, 290 | memory_key_padding_mask, 291 | pos, 292 | query_pos, 293 | ) 294 | 295 | 296 | def _get_activation_fn(activation): 297 | """Return an activation function given a string""" 298 | if activation == "relu": 299 | return F.relu 300 | if activation == "gelu": 301 | return F.gelu 302 | if activation == "glu": 303 | return F.glu 304 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") 305 | -------------------------------------------------------------------------------- /src/models/pluto/loss/esdf_collision_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | class ESDFCollisionLoss(nn.Module): 10 | def __init__( 11 | self, 12 | num_circles=3, 13 | ego_width=2.297, 14 | ego_front_length=4.049, 15 | ego_rear_length=1.127, 16 | resolution=0.2, 17 | ) -> None: 18 | super().__init__() 19 | 20 | ego_length = ego_front_length + ego_rear_length 21 | interval = ego_length / num_circles 22 | 23 | self.N = num_circles 24 | self.width = ego_width 25 | self.length = ego_length 26 | self.rear_length = ego_rear_length 27 | self.resolution = resolution 28 | 29 | self.radius = math.sqrt(ego_width**2 + interval**2) / 2 - resolution 30 | self.offset = torch.Tensor( 31 | [-ego_rear_length + interval / 2 * (2 * i + 1) for i in range(num_circles)] 32 | ) 33 | 34 | def forward(self, trajectory: Tensor, sdf: Tensor): 35 | """ 36 | trajectory: (bs, T, 4) - [x, y, cosθ, sinθ] 37 | sdf: (bs, H, W) 38 | """ 39 | bs, H, W = sdf.shape 40 | 41 | origin_offset = torch.tensor([W // 2, H // 2], device=sdf.device) 42 | offset = self.offset.to(sdf.device).view(1, 1, self.N, 1) 43 | # (bs, T, N, 2) 44 | centers = trajectory[..., None, :2] + offset * trajectory[..., None, 2:4] 45 | 46 | pixel_coord = torch.stack( 47 | [centers[..., 0] / self.resolution, -centers[..., 1] / self.resolution], 48 | dim=-1, 49 | ) 50 | grid_xy = pixel_coord / origin_offset 51 | valid_mask = (grid_xy < 0.95).all(-1) & (grid_xy > -0.95).all(-1) 52 | on_road_mask = sdf[:, H // 2, W // 2] > 0 53 | 54 | # (bs, T, N) 55 | distance = F.grid_sample( 56 | sdf.unsqueeze(1), grid_xy, mode="bilinear", padding_mode="zeros" 57 | ).squeeze(1) 58 | 59 | cost = self.radius - distance 60 | valid_mask = valid_mask & (cost > 0) & on_road_mask[:, None, None] 61 | cost.masked_fill_(~valid_mask, 0) 62 | 63 | loss = F.l1_loss(cost, torch.zeros_like(cost), reduction="none").sum(-1) 64 | loss = loss.sum() / (valid_mask.sum() + 1e-6) 65 | 66 | return loss 67 | -------------------------------------------------------------------------------- /src/models/pluto/modules/agent_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..layers.common_layers import build_mlp 5 | from ..layers.embedding import NATSequenceEncoder 6 | 7 | 8 | class AgentEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | state_channel=6, 12 | history_channel=9, 13 | dim=128, 14 | hist_steps=21, 15 | 
use_ego_history=False, 16 | drop_path=0.2, 17 | state_attn_encoder=True, 18 | state_dropout=0.75, 19 | ) -> None: 20 | super().__init__() 21 | self.dim = dim 22 | self.state_channel = state_channel 23 | self.use_ego_history = use_ego_history 24 | self.hist_steps = hist_steps 25 | self.state_attn_encoder = state_attn_encoder 26 | 27 | self.history_encoder = NATSequenceEncoder( 28 | in_chans=history_channel, embed_dim=dim // 4, drop_path_rate=drop_path 29 | ) 30 | 31 | if not use_ego_history: 32 | if not self.state_attn_encoder: 33 | self.ego_state_emb = build_mlp(state_channel, [dim] * 2, norm="bn") 34 | else: 35 | self.ego_state_emb = StateAttentionEncoder( 36 | state_channel, dim, state_dropout 37 | ) 38 | 39 | self.type_emb = nn.Embedding(4, dim) 40 | 41 | @staticmethod 42 | def to_vector(feat, valid_mask): 43 | vec_mask = valid_mask[..., :-1] & valid_mask[..., 1:] 44 | 45 | while len(vec_mask.shape) < len(feat.shape): 46 | vec_mask = vec_mask.unsqueeze(-1) 47 | 48 | return torch.where( 49 | vec_mask, 50 | feat[:, :, 1:, ...] - feat[:, :, :-1, ...], 51 | torch.zeros_like(feat[:, :, 1:, ...]), 52 | ) 53 | 54 | def forward(self, data): 55 | T = self.hist_steps 56 | 57 | position = data["agent"]["position"][:, :, :T] 58 | heading = data["agent"]["heading"][:, :, :T] 59 | velocity = data["agent"]["velocity"][:, :, :T] 60 | shape = data["agent"]["shape"][:, :, :T] 61 | category = data["agent"]["category"].long() 62 | valid_mask = data["agent"]["valid_mask"][:, :, :T] 63 | 64 | heading_vec = self.to_vector(heading, valid_mask) 65 | valid_mask_vec = valid_mask[..., 1:] & valid_mask[..., :-1] 66 | agent_feature = torch.cat( 67 | [ 68 | self.to_vector(position, valid_mask), 69 | self.to_vector(velocity, valid_mask), 70 | torch.stack([heading_vec.cos(), heading_vec.sin()], dim=-1), 71 | shape[:, :, 1:], 72 | valid_mask_vec.float().unsqueeze(-1), 73 | ], 74 | dim=-1, 75 | ) 76 | bs, A, T, _ = agent_feature.shape 77 | agent_feature = agent_feature.view(bs * A, T, -1) 78 | valid_agent_mask = valid_mask.any(-1).flatten() 79 | 80 | x_agent_tmp = self.history_encoder( 81 | agent_feature[valid_agent_mask].permute(0, 2, 1).contiguous() 82 | ) 83 | x_agent = torch.zeros(bs * A, self.dim, device=position.device) 84 | x_agent[valid_agent_mask] = x_agent_tmp 85 | x_agent = x_agent.view(bs, A, self.dim) 86 | 87 | if not self.use_ego_history: 88 | ego_feature = data["current_state"][:, : self.state_channel] 89 | x_ego = self.ego_state_emb(ego_feature) 90 | x_agent[:, 0] = x_ego 91 | 92 | x_type = self.type_emb(category) 93 | 94 | return x_agent + x_type 95 | 96 | 97 | class StateAttentionEncoder(nn.Module): 98 | def __init__(self, state_channel, dim, state_dropout=0.5) -> None: 99 | super().__init__() 100 | 101 | self.state_channel = state_channel 102 | self.state_dropout = state_dropout 103 | self.linears = nn.ModuleList([nn.Linear(1, dim) for _ in range(state_channel)]) 104 | self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True) 105 | self.pos_embed = nn.Parameter(torch.Tensor(1, state_channel, dim)) 106 | self.query = nn.Parameter(torch.Tensor(1, 1, dim)) 107 | 108 | nn.init.normal_(self.pos_embed, std=0.02) 109 | nn.init.normal_(self.query, std=0.02) 110 | 111 | def forward(self, x): 112 | x_embed = [] 113 | for i, linear in enumerate(self.linears): 114 | x_embed.append(linear(x[:, i, None])) 115 | x_embed = torch.stack(x_embed, dim=1) 116 | pos_embed = self.pos_embed.repeat(x_embed.shape[0], 1, 1) 117 | x_embed += pos_embed 118 | 119 | if self.training and self.state_dropout > 0: 
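# annotation (not part of the original source): ego-state dropout. During training,
# the first 3 state tokens stay visible (x, y, heading, which normalization zeroes
# out), while each remaining token is hidden independently with probability
# state_dropout; True entries in key_padding_mask are ignored by the attention below.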
120 | visible_tokens = torch.zeros( 121 | (x_embed.shape[0], 3), device=x.device, dtype=torch.bool 122 | ) 123 | dropout_tokens = ( 124 | torch.rand((x_embed.shape[0], self.state_channel - 3), device=x.device) 125 | < self.state_dropout 126 | ) 127 | key_padding_mask = torch.concat([visible_tokens, dropout_tokens], dim=1) 128 | else: 129 | key_padding_mask = None 130 | 131 | query = self.query.repeat(x_embed.shape[0], 1, 1) 132 | 133 | x_state = self.attn( 134 | query=query, 135 | key=x_embed, 136 | value=x_embed, 137 | key_padding_mask=key_padding_mask, 138 | )[0] 139 | 140 | return x_state[:, 0] 141 | -------------------------------------------------------------------------------- /src/models/pluto/modules/agent_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..layers.mlp_layer import MLPLayer 5 | 6 | 7 | class AgentPredictor(nn.Module): 8 | def __init__(self, dim, future_steps) -> None: 9 | super().__init__() 10 | 11 | self.future_steps = future_steps 12 | 13 | self.loc_predictor = MLPLayer(dim, 2 * dim, future_steps * 2) 14 | self.yaw_predictor = MLPLayer(dim, 2 * dim, future_steps * 2) 15 | self.vel_predictor = MLPLayer(dim, 2 * dim, future_steps * 2) 16 | 17 | def forward(self, x): 18 | """ 19 | x: (bs, N, dim) 20 | """ 21 | 22 | bs, N, _ = x.shape 23 | 24 | loc = self.loc_predictor(x).view(bs, N, self.future_steps, 2) 25 | yaw = self.yaw_predictor(x).view(bs, N, self.future_steps, 2) 26 | vel = self.vel_predictor(x).view(bs, N, self.future_steps, 2) 27 | 28 | prediction = torch.cat([loc, yaw, vel], dim=-1) 29 | return prediction 30 | -------------------------------------------------------------------------------- /src/models/pluto/modules/map_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..layers.embedding import PointsEncoder 5 | from ..layers.fourier_embedding import FourierEmbedding 6 | 7 | 8 | class MapEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | polygon_channel=6, 12 | dim=128, 13 | use_lane_boundary=False, 14 | ) -> None: 15 | super().__init__() 16 | 17 | self.dim = dim 18 | self.use_lane_boundary = use_lane_boundary 19 | self.polygon_channel = ( 20 | polygon_channel + 4 if use_lane_boundary else polygon_channel 21 | ) 22 | 23 | self.polygon_encoder = PointsEncoder(self.polygon_channel, dim) 24 | self.speed_limit_emb = FourierEmbedding(1, dim, 64) 25 | 26 | self.type_emb = nn.Embedding(3, dim) 27 | self.on_route_emb = nn.Embedding(2, dim) 28 | self.traffic_light_emb = nn.Embedding(4, dim) 29 | self.unknown_speed_emb = nn.Embedding(1, dim) 30 | 31 | def forward(self, data) -> torch.Tensor: 32 | polygon_center = data["map"]["polygon_center"] 33 | polygon_type = data["map"]["polygon_type"].long() 34 | polygon_on_route = data["map"]["polygon_on_route"].long() 35 | polygon_tl_status = data["map"]["polygon_tl_status"].long() 36 | polygon_has_speed_limit = data["map"]["polygon_has_speed_limit"] 37 | polygon_speed_limit = data["map"]["polygon_speed_limit"] 38 | point_position = data["map"]["point_position"] 39 | point_vector = data["map"]["point_vector"] 40 | point_orientation = data["map"]["point_orientation"] 41 | valid_mask = data["map"]["valid_mask"] 42 | 43 | if self.use_lane_boundary: 44 | polygon_feature = torch.cat( 45 | [ 46 | point_position[:, :, 0] - polygon_center[..., None, :2], 47 | point_vector[:, :, 0], 48 | torch.stack( 49 | [ 50 | point_orientation[:, :, 
0].cos(), 51 | point_orientation[:, :, 0].sin(), 52 | ], 53 | dim=-1, 54 | ), 55 | point_position[:, :, 1] - point_position[:, :, 0], 56 | point_position[:, :, 2] - point_position[:, :, 0], 57 | ], 58 | dim=-1, 59 | ) 60 | else: 61 | polygon_feature = torch.cat( 62 | [ 63 | point_position[:, :, 0] - polygon_center[..., None, :2], 64 | point_vector[:, :, 0], 65 | torch.stack( 66 | [ 67 | point_orientation[:, :, 0].cos(), 68 | point_orientation[:, :, 0].sin(), 69 | ], 70 | dim=-1, 71 | ), 72 | ], 73 | dim=-1, 74 | ) 75 | 76 | bs, M, P, C = polygon_feature.shape 77 | valid_mask = valid_mask.view(bs * M, P) 78 | polygon_feature = polygon_feature.reshape(bs * M, P, C) 79 | 80 | x_polygon = self.polygon_encoder(polygon_feature, valid_mask).view(bs, M, -1) 81 | 82 | x_type = self.type_emb(polygon_type) 83 | x_on_route = self.on_route_emb(polygon_on_route) 84 | x_tl_status = self.traffic_light_emb(polygon_tl_status) 85 | x_speed_limit = torch.zeros(bs, M, self.dim, device=x_polygon.device) 86 | x_speed_limit[polygon_has_speed_limit] = self.speed_limit_emb( 87 | polygon_speed_limit[polygon_has_speed_limit].unsqueeze(-1) 88 | ) 89 | x_speed_limit[~polygon_has_speed_limit] = self.unknown_speed_emb.weight 90 | 91 | x_polygon += x_type + x_on_route + x_tl_status + x_speed_limit 92 | 93 | return x_polygon 94 | -------------------------------------------------------------------------------- /src/models/pluto/modules/planning_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from ..layers.embedding import PointsEncoder 8 | from ..layers.fourier_embedding import FourierEmbedding 9 | from ..layers.mlp_layer import MLPLayer 10 | 11 | 12 | class DecoderLayer(nn.Module): 13 | def __init__(self, dim, num_heads, mlp_ratio, dropout) -> None: 14 | super().__init__() 15 | self.dim = dim 16 | 17 | self.r2r_attn = nn.MultiheadAttention( 18 | dim, num_heads, dropout=dropout, batch_first=True 19 | ) 20 | self.m2m_attn = nn.MultiheadAttention( 21 | dim, num_heads, dropout=dropout, batch_first=True 22 | ) 23 | self.cross_attn = nn.MultiheadAttention( 24 | dim, num_heads, dropout=dropout, batch_first=True 25 | ) 26 | 27 | self.ffn = nn.Sequential( 28 | nn.Linear(dim, dim * mlp_ratio), 29 | nn.ReLU(inplace=True), 30 | nn.Dropout(dropout), 31 | nn.Linear(dim * mlp_ratio, dim), 32 | ) 33 | 34 | self.norm1 = nn.LayerNorm(dim) 35 | self.norm2 = nn.LayerNorm(dim) 36 | self.norm3 = nn.LayerNorm(dim) 37 | self.norm4 = nn.LayerNorm(dim) 38 | self.dropout1 = nn.Dropout(dropout) 39 | self.dropout2 = nn.Dropout(dropout) 40 | self.dropout3 = nn.Dropout(dropout) 41 | 42 | def forward( 43 | self, 44 | tgt, 45 | memory, 46 | tgt_key_padding_mask: Optional[Tensor] = None, 47 | memory_key_padding_mask: Optional[Tensor] = None, 48 | m_pos: Optional[Tensor] = None, 49 | ): 50 | """ 51 | tgt: (bs, R, M, dim) 52 | tgt_key_padding_mask: (bs, R) 53 | """ 54 | bs, R, M, D = tgt.shape 55 | 56 | tgt = tgt.transpose(1, 2).reshape(bs * M, R, D) 57 | tgt2 = self.norm1(tgt) 58 | tgt2 = self.r2r_attn( 59 | tgt2, tgt2, tgt2, key_padding_mask=tgt_key_padding_mask.repeat(M, 1) 60 | )[0] 61 | tgt = tgt + self.dropout1(tgt2) 62 | 63 | tgt_tmp = tgt.reshape(bs, M, R, D).transpose(1, 2).reshape(bs * R, M, D) 64 | tgt_valid_mask = ~tgt_key_padding_mask.reshape(-1) 65 | tgt_valid = tgt_tmp[tgt_valid_mask] 66 | tgt2_valid = self.norm2(tgt_valid) 67 | tgt2_valid, _ = self.m2m_attn( 68 | tgt2_valid + m_pos, 
tgt2_valid + m_pos, tgt2_valid 69 | ) 70 | tgt_valid = tgt_valid + self.dropout2(tgt2_valid) 71 | tgt = torch.zeros_like(tgt_tmp) 72 | tgt[tgt_valid_mask] = tgt_valid 73 | 74 | tgt = tgt.reshape(bs, R, M, D).view(bs, R * M, D) 75 | tgt2 = self.norm3(tgt) 76 | tgt2 = self.cross_attn( 77 | tgt2, memory, memory, key_padding_mask=memory_key_padding_mask 78 | )[0] 79 | 80 | tgt = tgt + self.dropout2(tgt2) 81 | tgt2 = self.norm4(tgt) 82 | tgt2 = self.ffn(tgt2) 83 | tgt = tgt + self.dropout3(tgt2) 84 | tgt = tgt.reshape(bs, R, M, D) 85 | 86 | return tgt 87 | 88 | 89 | class PlanningDecoder(nn.Module): 90 | def __init__( 91 | self, 92 | num_mode, 93 | decoder_depth, 94 | dim, 95 | num_heads, 96 | mlp_ratio, 97 | dropout, 98 | future_steps, 99 | yaw_constraint=False, 100 | cat_x=False, 101 | ) -> None: 102 | super().__init__() 103 | 104 | self.num_mode = num_mode 105 | self.future_steps = future_steps 106 | self.yaw_constraint = yaw_constraint 107 | self.cat_x = cat_x 108 | 109 | self.decoder_blocks = nn.ModuleList( 110 | [ 111 | DecoderLayer(dim, num_heads, mlp_ratio, dropout) 112 | for _ in range(decoder_depth) 113 | ] 114 | ) 115 | 116 | self.r_pos_emb = FourierEmbedding(3, dim, 64) 117 | self.r_encoder = PointsEncoder(6, dim) 118 | 119 | self.q_proj = nn.Linear(2 * dim, dim) 120 | 121 | self.m_emb = nn.Parameter(torch.Tensor(1, 1, num_mode, dim)) 122 | self.m_pos = nn.Parameter(torch.Tensor(1, num_mode, dim)) 123 | 124 | if self.cat_x: 125 | self.cat_x_proj = nn.Linear(2 * dim, dim) 126 | 127 | self.loc_head = MLPLayer(dim, 2 * dim, self.future_steps * 2) 128 | self.yaw_head = MLPLayer(dim, 2 * dim, self.future_steps * 2) 129 | self.vel_head = MLPLayer(dim, 2 * dim, self.future_steps * 2) 130 | self.pi_head = MLPLayer(dim, dim, 1) 131 | 132 | nn.init.normal_(self.m_emb, mean=0.0, std=0.01) 133 | nn.init.normal_(self.m_pos, mean=0.0, std=0.01) 134 | 135 | def forward(self, data, enc_data): 136 | enc_emb = enc_data["enc_emb"] 137 | enc_key_padding_mask = enc_data["enc_key_padding_mask"] 138 | 139 | r_position = data["reference_line"]["position"] 140 | r_vector = data["reference_line"]["vector"] 141 | r_orientation = data["reference_line"]["orientation"] 142 | r_valid_mask = data["reference_line"]["valid_mask"] 143 | r_key_padding_mask = ~r_valid_mask.any(-1) 144 | 145 | r_feature = torch.cat( 146 | [ 147 | r_position - r_position[..., 0:1, :2], 148 | r_vector, 149 | torch.stack([r_orientation.cos(), r_orientation.sin()], dim=-1), 150 | ], 151 | dim=-1, 152 | ) 153 | 154 | bs, R, P, C = r_feature.shape 155 | r_valid_mask = r_valid_mask.view(bs * R, P) 156 | r_feature = r_feature.reshape(bs * R, P, C) 157 | r_emb = self.r_encoder(r_feature, r_valid_mask).view(bs, R, -1) 158 | 159 | r_pos = torch.cat([r_position[:, :, 0], r_orientation[:, :, 0, None]], dim=-1) 160 | r_emb = r_emb + self.r_pos_emb(r_pos) 161 | 162 | r_emb = r_emb.unsqueeze(2).repeat(1, 1, self.num_mode, 1) 163 | m_emb = self.m_emb.repeat(bs, R, 1, 1) 164 | 165 | q = self.q_proj(torch.cat([r_emb, m_emb], dim=-1)) 166 | 167 | for blk in self.decoder_blocks: 168 | q = blk( 169 | q, 170 | enc_emb, 171 | tgt_key_padding_mask=r_key_padding_mask, 172 | memory_key_padding_mask=enc_key_padding_mask, 173 | m_pos=self.m_pos, 174 | ) 175 | assert torch.isfinite(q).all() 176 | 177 | if self.cat_x: 178 | x = enc_emb[:, 0].unsqueeze(1).unsqueeze(2).repeat(1, R, self.num_mode, 1) 179 | q = self.cat_x_proj(torch.cat([q, x], dim=-1)) 180 | 181 | loc = self.loc_head(q).view(bs, R, self.num_mode, self.future_steps, 2) 182 | yaw = 
self.yaw_head(q).view(bs, R, self.num_mode, self.future_steps, 2) 183 | vel = self.vel_head(q).view(bs, R, self.num_mode, self.future_steps, 2) 184 | pi = self.pi_head(q).squeeze(-1) 185 | 186 | traj = torch.cat([loc, yaw, vel], dim=-1) 187 | 188 | return traj, pi 189 | -------------------------------------------------------------------------------- /src/models/pluto/modules/static_objects_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..layers.fourier_embedding import FourierEmbedding 6 | 7 | 8 | class StaticObjectsEncoder(nn.Module): 9 | def __init__(self, dim) -> None: 10 | super().__init__() 11 | 12 | self.obj_encoder = FourierEmbedding(2, dim, 64) 13 | self.type_emb = nn.Embedding(4, dim) 14 | 15 | nn.init.normal_(self.type_emb.weight, mean=0.0, std=0.01) 16 | 17 | def forward(self, data): 18 | pos = data["static_objects"]["position"] 19 | heading = data["static_objects"]["heading"] 20 | shape = data["static_objects"]["shape"] 21 | category = data["static_objects"]["category"].long() 22 | valid_mask = data["static_objects"]["valid_mask"] # [bs, N] 23 | 24 | obj_emb_tmp = self.obj_encoder(shape) + self.type_emb(category.long()) 25 | obj_emb = torch.zeros_like(obj_emb_tmp) 26 | obj_emb[valid_mask] = obj_emb_tmp[valid_mask] 27 | 28 | heading = (heading + math.pi) % (2 * math.pi) - math.pi 29 | obj_pos = torch.cat([pos, heading.unsqueeze(-1)], dim=-1) 30 | 31 | return obj_emb, obj_pos, ~valid_mask 32 | -------------------------------------------------------------------------------- /src/models/pluto/pluto_model.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling 7 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper 8 | from nuplan.planning.training.preprocessing.target_builders.ego_trajectory_target_builder import ( 9 | EgoTrajectoryTargetBuilder, 10 | ) 11 | 12 | from src.feature_builders.pluto_feature_builder import PlutoFeatureBuilder 13 | 14 | from .layers.fourier_embedding import FourierEmbedding 15 | from .layers.transformer import TransformerEncoderLayer 16 | from .modules.agent_encoder import AgentEncoder 17 | from .modules.agent_predictor import AgentPredictor 18 | from .modules.map_encoder import MapEncoder 19 | from .modules.static_objects_encoder import StaticObjectsEncoder 20 | from .modules.planning_decoder import PlanningDecoder 21 | from .layers.mlp_layer import MLPLayer 22 | 23 | # no meaning, required by nuplan 24 | trajectory_sampling = TrajectorySampling(num_poses=8, time_horizon=8, interval_length=1) 25 | 26 | 27 | class PlanningModel(TorchModuleWrapper): 28 | def __init__( 29 | self, 30 | dim=128, 31 | state_channel=6, 32 | polygon_channel=6, 33 | history_channel=9, 34 | history_steps=21, 35 | future_steps=80, 36 | encoder_depth=4, 37 | decoder_depth=4, 38 | drop_path=0.2, 39 | dropout=0.1, 40 | num_heads=8, 41 | num_modes=6, 42 | use_ego_history=False, 43 | state_attn_encoder=True, 44 | state_dropout=0.75, 45 | use_hidden_proj=False, 46 | cat_x=False, 47 | ref_free_traj=False, 48 | feature_builder: PlutoFeatureBuilder = PlutoFeatureBuilder(), 49 | ) -> None: 50 | super().__init__( 51 | feature_builders=[feature_builder], 52 | target_builders=[EgoTrajectoryTargetBuilder(trajectory_sampling)], 53 | 
future_trajectory_sampling=trajectory_sampling, 54 | ) 55 | 56 | self.dim = dim 57 | self.history_steps = history_steps 58 | self.future_steps = future_steps 59 | self.use_hidden_proj = use_hidden_proj 60 | self.num_modes = num_modes 61 | self.radius = feature_builder.radius 62 | self.ref_free_traj = ref_free_traj 63 | 64 | self.pos_emb = FourierEmbedding(3, dim, 64) 65 | 66 | self.agent_encoder = AgentEncoder( 67 | state_channel=state_channel, 68 | history_channel=history_channel, 69 | dim=dim, 70 | hist_steps=history_steps, 71 | drop_path=drop_path, 72 | use_ego_history=use_ego_history, 73 | state_attn_encoder=state_attn_encoder, 74 | state_dropout=state_dropout, 75 | ) 76 | 77 | self.map_encoder = MapEncoder( 78 | dim=dim, 79 | polygon_channel=polygon_channel, 80 | use_lane_boundary=True, 81 | ) 82 | 83 | self.static_objects_encoder = StaticObjectsEncoder(dim=dim) 84 | 85 | self.encoder_blocks = nn.ModuleList( 86 | TransformerEncoderLayer(dim=dim, num_heads=num_heads, drop_path=dp) 87 | for dp in [x.item() for x in torch.linspace(0, drop_path, encoder_depth)] 88 | ) 89 | self.norm = nn.LayerNorm(dim) 90 | 91 | self.agent_predictor = AgentPredictor(dim=dim, future_steps=future_steps) 92 | self.planning_decoder = PlanningDecoder( 93 | num_mode=num_modes, 94 | decoder_depth=decoder_depth, 95 | dim=dim, 96 | num_heads=num_heads, 97 | mlp_ratio=4, 98 | dropout=dropout, 99 | cat_x=cat_x, 100 | future_steps=future_steps, 101 | ) 102 | 103 | if use_hidden_proj: 104 | self.hidden_proj = nn.Sequential( 105 | nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim) 106 | ) 107 | 108 | if self.ref_free_traj: 109 | self.ref_free_decoder = MLPLayer(dim, 2 * dim, future_steps * 4) 110 | 111 | self.apply(self._init_weights) 112 | 113 | def _init_weights(self, m): 114 | if isinstance(m, nn.Linear): 115 | torch.nn.init.xavier_uniform_(m.weight) 116 | if isinstance(m, nn.Linear) and m.bias is not None: 117 | nn.init.constant_(m.bias, 0) 118 | elif isinstance(m, nn.LayerNorm): 119 | nn.init.constant_(m.bias, 0) 120 | nn.init.constant_(m.weight, 1.0) 121 | elif isinstance(m, nn.BatchNorm1d): 122 | nn.init.ones_(m.weight) 123 | nn.init.zeros_(m.bias) 124 | elif isinstance(m, nn.Embedding): 125 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 126 | 127 | def forward(self, data): 128 | agent_pos = data["agent"]["position"][:, :, self.history_steps - 1] 129 | agent_heading = data["agent"]["heading"][:, :, self.history_steps - 1] 130 | agent_mask = data["agent"]["valid_mask"][:, :, : self.history_steps] 131 | polygon_center = data["map"]["polygon_center"] 132 | polygon_mask = data["map"]["valid_mask"] 133 | 134 | bs, A = agent_pos.shape[0:2] 135 | 136 | position = torch.cat([agent_pos, polygon_center[..., :2]], dim=1) 137 | angle = torch.cat([agent_heading, polygon_center[..., 2]], dim=1) 138 | angle = (angle + math.pi) % (2 * math.pi) - math.pi 139 | pos = torch.cat([position, angle.unsqueeze(-1)], dim=-1) 140 | 141 | agent_key_padding = ~(agent_mask.any(-1)) 142 | polygon_key_padding = ~(polygon_mask.any(-1)) 143 | key_padding_mask = torch.cat([agent_key_padding, polygon_key_padding], dim=-1) 144 | 145 | x_agent = self.agent_encoder(data) 146 | x_polygon = self.map_encoder(data) 147 | x_static, static_pos, static_key_padding = self.static_objects_encoder(data) 148 | 149 | x = torch.cat([x_agent, x_polygon, x_static], dim=1) 150 | 151 | pos = torch.cat([pos, static_pos], dim=1) 152 | pos_embed = self.pos_emb(pos) 153 | 154 | key_padding_mask = torch.cat([key_padding_mask, static_key_padding], dim=-1) 155 | x = x + 
pos_embed 156 | 157 | for blk in self.encoder_blocks: 158 | x = blk(x, key_padding_mask=key_padding_mask, return_attn_weights=False) 159 | x = self.norm(x) 160 | 161 | prediction = self.agent_predictor(x[:, 1:A]) 162 | 163 | ref_line_available = data["reference_line"]["position"].shape[1] > 0 164 | 165 | if ref_line_available: 166 | trajectory, probability = self.planning_decoder( 167 | data, {"enc_emb": x, "enc_key_padding_mask": key_padding_mask} 168 | ) 169 | else: 170 | trajectory, probability = None, None 171 | 172 | out = { 173 | "trajectory": trajectory, 174 | "probability": probability, # (bs, R, M) 175 | "prediction": prediction, # (bs, A-1, T, 2) 176 | } 177 | 178 | if self.use_hidden_proj: 179 | out["hidden"] = self.hidden_proj(x[:, 0]) 180 | 181 | if self.ref_free_traj: 182 | ref_free_traj = self.ref_free_decoder(x[:, 0]).reshape( 183 | bs, self.future_steps, 4 184 | ) 185 | out["ref_free_trajectory"] = ref_free_traj 186 | 187 | if not self.training: 188 | if self.ref_free_traj: 189 | ref_free_traj_angle = torch.arctan2( 190 | ref_free_traj[..., 3], ref_free_traj[..., 2] 191 | ) 192 | ref_free_traj = torch.cat( 193 | [ref_free_traj[..., :2], ref_free_traj_angle.unsqueeze(-1)], dim=-1 194 | ) 195 | out["output_ref_free_trajectory"] = ref_free_traj 196 | 197 | output_prediction = torch.cat( 198 | [ 199 | prediction[..., :2] + agent_pos[:, 1:A, None], 200 | torch.atan2(prediction[..., 3], prediction[..., 2]).unsqueeze(-1) 201 | + agent_heading[:, 1:A, None, None], 202 | prediction[..., 4:6], 203 | ], 204 | dim=-1, 205 | ) 206 | out["output_prediction"] = output_prediction 207 | 208 | if trajectory is not None: 209 | r_padding_mask = ~data["reference_line"]["valid_mask"].any(-1) 210 | probability.masked_fill_(r_padding_mask.unsqueeze(-1), -1e6) 211 | 212 | angle = torch.atan2(trajectory[..., 3], trajectory[..., 2]) 213 | out_trajectory = torch.cat( 214 | [trajectory[..., :2], angle.unsqueeze(-1)], dim=-1 215 | ) 216 | 217 | bs, R, M, T, _ = out_trajectory.shape 218 | flattened_probability = probability.reshape(bs, R * M) 219 | best_trajectory = out_trajectory.reshape(bs, R * M, T, -1)[ 220 | torch.arange(bs), flattened_probability.argmax(-1) 221 | ] 222 | 223 | out["output_trajectory"] = best_trajectory 224 | out["candidate_trajectories"] = out_trajectory 225 | else: 226 | out["output_trajectory"] = out["output_ref_free_trajectory"] 227 | out["probability"] = torch.zeros(1, 0, 0) 228 | out["candidate_trajectories"] = torch.zeros( 229 | 1, 0, 0, self.future_steps, 3 230 | ) 231 | 232 | return out 233 | -------------------------------------------------------------------------------- /src/optim/warmup_cos_lr.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class WarmupCosLR(_LRScheduler): 7 | def __init__( 8 | self, optimizer, min_lr, lr, warmup_epochs, epochs, last_epoch=-1, verbose=False 9 | ) -> None: 10 | self.min_lr = min_lr 11 | self.lr = lr 12 | self.epochs = epochs 13 | self.warmup_epochs = warmup_epochs 14 | super(WarmupCosLR, self).__init__(optimizer, last_epoch, verbose) 15 | 16 | def state_dict(self): 17 | """Returns the state of the scheduler as a :class:`dict`. 18 | 19 | It contains an entry for every variable in self.__dict__ which 20 | is not the optimizer. 21 | """ 22 | return { 23 | key: value for key, value in self.__dict__.items() if key != "optimizer" 24 | } 25 | 26 | def load_state_dict(self, state_dict): 27 | """Loads the schedulers state. 
28 | 29 | Args: 30 | state_dict (dict): scheduler state. Should be an object returned 31 | from a call to :meth:`state_dict`. 32 | """ 33 | self.__dict__.update(state_dict) 34 | 35 | def get_init_lr(self): 36 | lr = self.lr / self.warmup_epochs 37 | return lr 38 | 39 | def get_lr(self): 40 | if self.last_epoch < self.warmup_epochs: 41 | lr = self.lr * (self.last_epoch + 1) / self.warmup_epochs 42 | else: 43 | lr = self.min_lr + 0.5 * (self.lr - self.min_lr) * ( 44 | 1 45 | + math.cos( 46 | math.pi 47 | * (self.last_epoch - self.warmup_epochs) 48 | / (self.epochs - self.warmup_epochs) 49 | ) 50 | ) 51 | if "lr_scale" in self.optimizer.param_groups[0]: 52 | return [lr * group["lr_scale"] for group in self.optimizer.param_groups] 53 | 54 | return [lr for _ in self.optimizer.param_groups] 55 | -------------------------------------------------------------------------------- /src/planners/ml_planner_utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Deque, List, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import numpy.typing as npt 7 | import torch 8 | from nuplan.common.actor_state.ego_state import EgoState 9 | from nuplan.common.actor_state.state_representation import StateSE2 10 | from nuplan.planning.simulation.planner.ml_planner.transform_utils import ( 11 | _get_fixed_timesteps, 12 | _get_velocity_and_acceleration, 13 | _se2_vel_acc_to_ego_state, 14 | ) 15 | 16 | 17 | def normalize_angle(angle): 18 | return (angle + np.pi) % (2 * np.pi) - np.pi 19 | 20 | 21 | def global_trajectory_to_states( 22 | global_trajectory: npt.NDArray[np.float32], 23 | ego_history: Deque[EgoState], 24 | future_horizon: float, 25 | step_interval: float, 26 | include_ego_state: bool = True, 27 | ): 28 | ego_state = ego_history[-1] 29 | timesteps = _get_fixed_timesteps(ego_state, future_horizon, step_interval) 30 | global_states = [StateSE2.deserialize(pose) for pose in global_trajectory] 31 | 32 | velocities, accelerations = _get_velocity_and_acceleration( 33 | global_states, ego_history, timesteps 34 | ) 35 | agent_states = [ 36 | _se2_vel_acc_to_ego_state( 37 | state, 38 | velocity, 39 | acceleration, 40 | timestep, 41 | ego_state.car_footprint.vehicle_parameters, 42 | ) 43 | for state, velocity, acceleration, timestep in zip( 44 | global_states, velocities, accelerations, timesteps 45 | ) 46 | ] 47 | 48 | if include_ego_state: 49 | agent_states.insert(0, ego_state) 50 | else: 51 | init_state = deepcopy(agent_states[0]) 52 | init_state._time_point = ego_state.time_point 53 | agent_states.insert(0, init_state) 54 | 55 | return agent_states 56 | 57 | 58 | def load_checkpoint(checkpoint: str): 59 | ckpt = torch.load(checkpoint, map_location=torch.device("cpu")) 60 | state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()} 61 | return state_dict 62 | -------------------------------------------------------------------------------- /src/post_processing/common/enum.py: -------------------------------------------------------------------------------- 1 | # Heavily borrowed from: 2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0) 3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0) 4 | 5 | from enum import IntEnum 6 | 7 | import numpy as np 8 | from nuplan.common.actor_state.ego_state import EgoState 9 | 10 | 11 | class StateIndex: 12 | _X = 0 13 | _Y = 1 14 | _HEADING = 2 15 | _VELOCITY_X = 3 16 | _VELOCITY_Y = 4 17 | _ACCELERATION_X = 
5 18 | _ACCELERATION_Y = 6 19 | _STEERING_ANGLE = 7 20 | _STEERING_RATE = 8 21 | _ANGULAR_VELOCITY = 9 22 | _ANGULAR_ACCELERATION = 10 23 | 24 | @classmethod 25 | def size(cls): 26 | valid_attributes = [ 27 | attribute 28 | for attribute in dir(cls) 29 | if attribute.startswith("_") 30 | and not attribute.startswith("__") 31 | and not callable(getattr(cls, attribute)) 32 | ] 33 | return len(valid_attributes) 34 | 35 | @classmethod 36 | @property 37 | def X(cls): 38 | return cls._X 39 | 40 | @classmethod 41 | @property 42 | def Y(cls): 43 | return cls._Y 44 | 45 | @classmethod 46 | @property 47 | def HEADING(cls): 48 | return cls._HEADING 49 | 50 | @classmethod 51 | @property 52 | def VELOCITY_X(cls): 53 | return cls._VELOCITY_X 54 | 55 | @classmethod 56 | @property 57 | def VELOCITY_Y(cls): 58 | return cls._VELOCITY_Y 59 | 60 | @classmethod 61 | @property 62 | def ACCELERATION_X(cls): 63 | return cls._ACCELERATION_X 64 | 65 | @classmethod 66 | @property 67 | def ACCELERATION_Y(cls): 68 | return cls._ACCELERATION_Y 69 | 70 | @classmethod 71 | @property 72 | def STEERING_ANGLE(cls): 73 | return cls._STEERING_ANGLE 74 | 75 | @classmethod 76 | @property 77 | def STEERING_RATE(cls): 78 | return cls._STEERING_RATE 79 | 80 | @classmethod 81 | @property 82 | def ANGULAR_VELOCITY(cls): 83 | return cls._ANGULAR_VELOCITY 84 | 85 | @classmethod 86 | @property 87 | def ANGULAR_ACCELERATION(cls): 88 | return cls._ANGULAR_ACCELERATION 89 | 90 | @classmethod 91 | @property 92 | def POINT(cls): 93 | # assumes X, Y have subsequent indices 94 | return slice(cls._X, cls._Y + 1) 95 | 96 | @classmethod 97 | @property 98 | def STATE_SE2(cls): 99 | # assumes X, Y, HEADING have subsequent indices 100 | return slice(cls._X, cls._HEADING + 1) 101 | 102 | @classmethod 103 | @property 104 | def VELOCITY_2D(cls): 105 | # assumes velocity X, Y have subsequent indices 106 | return slice(cls._VELOCITY_X, cls._VELOCITY_Y + 1) 107 | 108 | @classmethod 109 | @property 110 | def ACCELERATION_2D(cls): 111 | # assumes acceleration X, Y have subsequent indices 112 | return slice(cls._ACCELERATION_X, cls._ACCELERATION_Y + 1) 113 | 114 | 115 | class DynamicStateIndex(IntEnum): 116 | ACCELERATION_X = 0 117 | STEERING_RATE = 1 118 | 119 | 120 | class BBCoordsIndex(IntEnum): 121 | FRONT_LEFT = 0 122 | REAR_LEFT = 1 123 | REAR_RIGHT = 2 124 | FRONT_RIGHT = 3 125 | CENTER = 4 126 | 127 | 128 | class EgoAreaIndex(IntEnum): 129 | MULTIPLE_LANES = 0 130 | NON_DRIVABLE_AREA = 1 131 | ONCOMING_TRAFFIC = 2 132 | 133 | 134 | class MultiMetricIndex(IntEnum): 135 | NO_COLLISION = 0 136 | DRIVABLE_AREA = 1 137 | DRIVING_DIRECTION = 2 138 | 139 | 140 | class WeightedMetricIndex(IntEnum): 141 | PROGRESS = 0 142 | SPEED_LIMIT = 1 143 | COMFORTABLE = 2 144 | TTC = 3 145 | 146 | 147 | class CollisionType(IntEnum): 148 | """Enum for the types of collisions of interest.""" 149 | 150 | STOPPED_EGO_COLLISION = 0 151 | STOPPED_TRACK_COLLISION = 1 152 | ACTIVE_FRONT_COLLISION = 2 153 | ACTIVE_REAR_COLLISION = 3 154 | ACTIVE_LATERAL_COLLISION = 4 155 | 156 | 157 | def ego_state_to_state_array(ego_state: EgoState): 158 | """ 159 | Converts an ego state into an array representation (drops time-stamps and vehicle parameters) 160 | :param ego_state: ego state class 161 | :return: array containing ego state values 162 | """ 163 | state_array = np.zeros(StateIndex.size(), dtype=np.float64) 164 | 165 | state_array[StateIndex.STATE_SE2] = ego_state.rear_axle.serialize() 166 | state_array[ 167 | StateIndex.VELOCITY_2D 168 | ] = 
ego_state.dynamic_car_state.rear_axle_velocity_2d.array
169 | state_array[
170 | StateIndex.ACCELERATION_2D
171 | ] = ego_state.dynamic_car_state.rear_axle_acceleration_2d.array
172 | 
173 | state_array[StateIndex.STEERING_ANGLE] = ego_state.tire_steering_angle
174 | state_array[
175 | StateIndex.STEERING_RATE
176 | ] = ego_state.dynamic_car_state.tire_steering_rate
177 | 
178 | state_array[
179 | StateIndex.ANGULAR_VELOCITY
180 | ] = ego_state.dynamic_car_state.angular_velocity
181 | state_array[
182 | StateIndex.ANGULAR_ACCELERATION
183 | ] = ego_state.dynamic_car_state.angular_acceleration
184 | 
185 | return state_array
186 | 
--------------------------------------------------------------------------------
/src/post_processing/common/geometry.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 | 
5 | from typing import Dict, Any
6 | 
7 | import numpy as np
8 | from nuplan.common.actor_state.state_representation import StateSE2
9 | from shapely import Polygon, LineString
10 | 
11 | from .enum import CollisionType, StateIndex
12 | from nuplan.planning.simulation.observation.idm.utils import is_agent_behind
13 | 
14 | 
15 | def compute_agents_vertices(
16 | center: np.ndarray,
17 | angle: np.ndarray,
18 | shape: np.ndarray,
19 | ) -> np.ndarray:
20 | """
21 | Args:
22 | center: (N, T, 2)
23 | angle: (N, T)
24 | shape: (N, 2) [width, length]
25 | Returns:
26 | 4 corners of oriented box (FL, RL, RR, FR)
27 | vertices: (N, T, 4, 2)
28 | """
29 | # Extract batch and time dimensions
30 | N, T = center.shape[0], center.shape[1]
31 | 
32 | # Flatten batch and time for vectorized computation
33 | center = center.reshape(N * T, 2)
34 | angle = angle.reshape(N * T)
35 | 
36 | if shape.ndim == 2:
37 | shape = (shape / 2).repeat(T, axis=0)
38 | else:
39 | shape = (shape / 2).reshape(N * T, 2)
40 | 
41 | # Half width and half length
42 | half_w = shape[:, 0]
43 | half_l = shape[:, 1]
44 | 
45 | # Rotation matrices built from cos/sin of the headings
46 | cos_angle = np.cos(angle)[:, None]
47 | sin_angle = np.sin(angle)[:, None]
48 | rot_mat = np.stack([cos_angle, sin_angle, -sin_angle, cos_angle], axis=-1).reshape(
49 | N * T, 2, 2
50 | )
51 | 
52 | offset_width = np.stack([half_w, half_w, -half_w, -half_w], axis=-1)
53 | offset_length = np.stack([half_l, -half_l, -half_l, half_l], axis=-1)
54 | 
55 | vertices = np.stack([offset_length, offset_width], axis=-1)
56 | vertices = np.matmul(vertices, rot_mat) + center[:, None]
57 | 
58 | # Restore the (N, T, 4, 2) layout
59 | vertices = vertices.reshape(N, T, 4, 2)
60 | 
61 | return vertices
62 | 
63 | 
64 | def ego_rear_to_center(rear_xy, heading, rear_to_center=1.461):
65 | direction = np.stack([np.cos(heading), np.sin(heading)], axis=-1)
66 | center = rear_xy + direction * rear_to_center
67 | return center
68 | 
69 | 
70 | def get_sub_polygon(polygon: Polygon, ratio=0.4):
71 | vertices = np.array(polygon.exterior.coords)
72 | return Polygon(
73 | [
74 | vertices[0],
75 | vertices[0] * (1 - ratio) + vertices[1] * ratio,
76 | vertices[3] * (1 - ratio) + vertices[2] * ratio,
77 | vertices[3],
78 | ]
79 | )
80 | 
81 | 
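# A minimal usage sketch for compute_agents_vertices (shapes follow the
# docstring above; all numbers below are illustrative only):
#
#   import numpy as np
#   centers = np.zeros((2, 3, 2))                # (N, T, 2) box centers
#   angles = np.full((2, 3), np.pi / 2)          # (N, T) headings in rad
#   shapes = np.array([[2.0, 5.0], [1.0, 2.5]])  # (N, 2) [width, length]
#   verts = compute_agents_vertices(centers, angles, shapes)
#   assert verts.shape == (2, 3, 4, 2)           # FL, RL, RR, FR corners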
82 | def get_collision_type(
83 | state: np.ndarray,
84 | ego_polygon: Polygon,
85 | object_info: Dict[str, Any],
86 | stopped_speed_threshold: float = 5e-02,
87 | ) -> CollisionType:
88 | """
89 | Classify collision between ego and the track.
90 | :param state: ego state array at the current timestamp.
91 | :param object_info: tracked object info (pose, velocity, polygon).
92 | :param stopped_speed_threshold: Threshold for 0 speed due to noise.
93 | :return: Collision type.
94 | """
95 | 
96 | ego_speed = np.hypot(state[StateIndex.VELOCITY_X], state[StateIndex.VELOCITY_Y])
97 | 
98 | is_ego_stopped = float(ego_speed) <= stopped_speed_threshold
99 | 
100 | object_pos: np.ndarray = object_info["pose"]
101 | object_velocity: np.ndarray = object_info["velocity"]
102 | object_polygon: Polygon = object_info["polygon"]
103 | 
104 | tracked_object_center = StateSE2(*object_pos)
105 | 
106 | ego_rear_axle_pose: StateSE2 = StateSE2(*state[StateIndex.STATE_SE2])
107 | 
108 | # Collisions at (close-to) zero ego speed
109 | if is_ego_stopped:
110 | collision_type = CollisionType.STOPPED_EGO_COLLISION
111 | 
112 | # Collisions at (close-to) zero track speed
113 | elif np.linalg.norm(object_velocity) <= stopped_speed_threshold:
114 | collision_type = CollisionType.STOPPED_TRACK_COLLISION
115 | 
116 | # Rear collision when both ego and track are not stopped
117 | elif is_agent_behind(ego_rear_axle_pose, tracked_object_center):
118 | collision_type = CollisionType.ACTIVE_REAR_COLLISION
119 | 
120 | # Front bumper collision when both ego and track are not stopped
121 | # elif get_sub_polygon(ego_polygon).intersects(object_polygon):
122 | # collision_type = CollisionType.ACTIVE_FRONT_COLLISION
123 | elif LineString(
124 | [
125 | ego_polygon.exterior.coords[0],
126 | ego_polygon.exterior.coords[3],
127 | ]
128 | ).intersects(object_polygon):
129 | collision_type = CollisionType.ACTIVE_FRONT_COLLISION
130 | 
131 | # Lateral collision when both ego and track are not stopped
132 | else:
133 | collision_type = CollisionType.ACTIVE_LATERAL_COLLISION
134 | 
135 | return collision_type
136 | 
--------------------------------------------------------------------------------
/src/post_processing/emergency_brake.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 | 
5 | from typing import Optional
6 | 
7 | import numpy as np
8 | import numpy.typing as npt
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.actor_state.state_representation import (
11 | StateSE2,
12 | StateVector2D,
13 | TimePoint,
14 | )
15 | from nuplan.planning.simulation.trajectory.interpolated_trajectory import (
16 | InterpolatedTrajectory,
17 | )
18 | from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
19 | 
20 | from .forward_simulation.forward_simulator import ForwardSimulator
21 | from .common.enum import StateIndex
22 | 
23 | 
24 | class EmergencyBrake:
25 | def __init__(
26 | self,
27 | trajectory_sampling: TrajectorySampling = TrajectorySampling(
28 | num_poses=80, interval_length=0.1
29 | ),
30 | time_to_infraction_threshold: float = 2.0,
31 | max_ego_speed: float = 8.0,
32 | max_long_accel: float = 2.40,
33 | min_long_accel: float = -2.40,
34 | emergency_decel: float = -4.05,
35 | ):
36 | # trajectory parameters
37 | self._trajectory_sampling = trajectory_sampling
38 | 
39 | # braking parameters
40 | self._max_ego_speed: float = max_ego_speed # [m/s]
41 | self._max_long_accel: float = max_long_accel # [m/s^2]
42 | self._min_long_accel: float = min_long_accel # [m/s^2]
43 | self._emergency_decel: float = emergency_decel # [m/s^2]
44 | 
45 | # 
braking condition parameters 46 | self._time_to_infraction_threshold: float = time_to_infraction_threshold 47 | 48 | def brake_if_emergency( 49 | self, 50 | ego_state: EgoState, 51 | time_to_at_fault_collision: float, 52 | ego_trajectory: npt.NDArray[np.float64], 53 | ) -> Optional[InterpolatedTrajectory]: 54 | trajectory = None 55 | ego_speed: float = ego_state.dynamic_car_state.speed 56 | 57 | time_to_infraction = time_to_at_fault_collision 58 | min_brake_time = max(ego_speed / abs(self._min_long_accel) + 0.5, 3.0) 59 | 60 | if time_to_infraction <= min_brake_time and ego_speed <= self._max_ego_speed: 61 | print("Emergency Brake") 62 | min_reaction_time = ego_speed / abs(self._emergency_decel) + 0.5 63 | is_soft_brake_possible = time_to_infraction > min_reaction_time 64 | trajectory = self._generate_ebrake_trajectory( 65 | ego_trajectory, ego_state, soft_brake=is_soft_brake_possible 66 | ) 67 | 68 | return trajectory 69 | 70 | def _generate_ebrake_trajectory( 71 | self, origin_trajectory: np.ndarray, ego_state: EgoState, soft_brake=False 72 | ): 73 | simulator = ForwardSimulator( 74 | dt=self._trajectory_sampling.interval_length, 75 | num_frames=self._trajectory_sampling.num_poses, 76 | estop=True, 77 | soft_brake=soft_brake, 78 | ) 79 | rollout = simulator.forward(origin_trajectory[None, ...], ego_state)[0] 80 | 81 | ego_states, current_time_point = [], ego_state.time_point 82 | delta_t = TimePoint(int(self._trajectory_sampling.interval_length * 1e6)) 83 | 84 | for state in rollout: 85 | ego_states.append( 86 | EgoState.build_from_rear_axle( 87 | rear_axle_pose=StateSE2(*state[:3]), 88 | rear_axle_velocity_2d=StateVector2D(*state[StateIndex.VELOCITY_2D]), 89 | rear_axle_acceleration_2d=StateVector2D( 90 | *state[StateIndex.ACCELERATION_2D] 91 | ), 92 | tire_steering_angle=state[StateIndex.STEERING_ANGLE], 93 | time_point=current_time_point, 94 | vehicle_parameters=ego_state.car_footprint.vehicle_parameters, 95 | ) 96 | ) 97 | current_time_point += delta_t 98 | 99 | return InterpolatedTrajectory(ego_states) 100 | -------------------------------------------------------------------------------- /src/post_processing/forward_simulation/batch_kinematic_bicycle.py: -------------------------------------------------------------------------------- 1 | # Heavily borrowed from: 2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0) 3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0) 4 | 5 | import copy 6 | 7 | import numpy as np 8 | import numpy.typing as npt 9 | from nuplan.common.actor_state.ego_state import EgoState 10 | from nuplan.common.actor_state.state_representation import TimePoint 11 | from nuplan.common.actor_state.vehicle_parameters import ( 12 | VehicleParameters, 13 | get_pacifica_parameters, 14 | ) 15 | from nuplan.common.geometry.compute import principal_value 16 | from ..common.enum import DynamicStateIndex, StateIndex 17 | 18 | 19 | def forward_integrate( 20 | init: npt.NDArray[np.float64], 21 | delta: npt.NDArray[np.float64], 22 | sampling_time: TimePoint, 23 | ) -> npt.NDArray[np.float64]: 24 | """ 25 | Performs a simple euler integration. 26 | :param init: Initial state 27 | :param delta: The rate of change of the state. 28 | :param sampling_time: The time duration to propagate for. 
29 | :return: The result of integration
30 | """
31 | return init + delta * sampling_time.time_s
32 | 
33 | 
34 | class BatchKinematicBicycleModel:
35 | """
36 | A batch-wise kinematic bicycle model with the rear axle as the point of reference.
37 | """
38 | 
39 | def __init__(
40 | self,
41 | vehicle: VehicleParameters = get_pacifica_parameters(),
42 | max_steering_angle: float = np.pi / 3,
43 | accel_time_constant: float = 0.2,
44 | steering_angle_time_constant: float = 0.05,
45 | ):
46 | """
47 | Construct BatchKinematicBicycleModel.
48 | :param vehicle: Vehicle parameters.
49 | :param max_steering_angle: [rad] Maximum absolute value steering angle allowed by model.
50 | :param accel_time_constant: low pass filter time constant for acceleration in s
51 | :param steering_angle_time_constant: low pass filter time constant for steering angle in s
52 | """
53 | self._vehicle = vehicle
54 | self._max_steering_angle = max_steering_angle
55 | self._accel_time_constant = accel_time_constant
56 | self._steering_angle_time_constant = steering_angle_time_constant
57 | 
58 | def get_state_dot(self, states: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
59 | """
60 | Calculates the rate of change of the state array representation.
61 | :param states: array describing the state of the ego-vehicle
62 | :return: change rate across several state values
63 | """
64 | state_dots = np.zeros(states.shape, dtype=np.float64)
65 | 
66 | longitudinal_speeds = states[:, StateIndex.VELOCITY_X]
67 | 
68 | state_dots[:, StateIndex.X] = longitudinal_speeds * np.cos(
69 | states[:, StateIndex.HEADING]
70 | )
71 | state_dots[:, StateIndex.Y] = longitudinal_speeds * np.sin(
72 | states[:, StateIndex.HEADING]
73 | )
74 | state_dots[:, StateIndex.HEADING] = (
75 | longitudinal_speeds
76 | * np.tan(states[:, StateIndex.STEERING_ANGLE])
77 | / self._vehicle.wheel_base
78 | )
79 | 
80 | state_dots[:, StateIndex.VELOCITY_2D] = states[:, StateIndex.ACCELERATION_2D]
81 | state_dots[:, StateIndex.ACCELERATION_2D] = 0.0
82 | 
83 | state_dots[:, StateIndex.STEERING_ANGLE] = states[:, StateIndex.STEERING_RATE]
84 | 
85 | return state_dots
86 | 
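# For reference, get_state_dot above is the batched form of the continuous
# rear-axle kinematic bicycle ODE (a sketch, not part of the public API):
#
#   x_dot       = v_x * cos(heading)
#   y_dot       = v_x * sin(heading)
#   heading_dot = v_x * tan(steering_angle) / wheel_base
#   v_dot       = acceleration,  steering_angle_dot = steering_rate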
87 | def _update_commands(
88 | self,
89 | states: npt.NDArray[np.float64],
90 | command_states: npt.NDArray[np.float64],
91 | sampling_time: TimePoint,
92 | ) -> npt.NDArray[np.float64]:
93 | """
94 | Applies a first-order control delay (low-pass filter) to the acceleration and steering commands.
95 | 
96 | :param states: Ego state array
97 | :param command_states: The desired dynamic state for propagation
98 | :param sampling_time: The time duration to propagate for
99 | :return: state array with updated acceleration and steering rate
100 | """
101 | 
102 | propagating_state: npt.NDArray[np.float64] = copy.deepcopy(states)
103 | 
104 | dt_control = sampling_time.time_s
105 | 
106 | accel = states[:, StateIndex.ACCELERATION_X]
107 | steering_angle = states[:, StateIndex.STEERING_ANGLE]
108 | 
109 | ideal_accel_x = command_states[:, DynamicStateIndex.ACCELERATION_X]
110 | ideal_steering_angle = (
111 | dt_control * command_states[:, DynamicStateIndex.STEERING_RATE]
112 | + steering_angle
113 | )
114 | 
115 | updated_accel_x = (
116 | dt_control
117 | / (dt_control + self._accel_time_constant)
118 | * (ideal_accel_x - accel)
119 | + accel
120 | )
121 | updated_steering_angle = (
122 | dt_control
123 | / (dt_control + self._steering_angle_time_constant)
124 | * (ideal_steering_angle - steering_angle)
125 | + steering_angle
126 | )
127 | updated_steering_rate = (updated_steering_angle - steering_angle) / dt_control
128 | 
129 | propagating_state[:, StateIndex.ACCELERATION_X] = updated_accel_x
130 | propagating_state[:, StateIndex.ACCELERATION_Y] = 0.0
131 | propagating_state[:, StateIndex.STEERING_RATE] = updated_steering_rate
132 | 
133 | return propagating_state
134 | 
135 | def propagate_state(
136 | self,
137 | states: npt.NDArray[np.float64],
138 | command_states: npt.NDArray[np.float64],
139 | sampling_time: TimePoint,
140 | ) -> npt.NDArray[np.float64]:
141 | """
142 | Propagates ego state array forward with motion model.
143 | :param states: state array representation of the ego-vehicle
144 | :param command_states: command array representation of controller
145 | :param sampling_time: time to propagate [s]
146 | :return: updated state array representation of the ego-vehicle
147 | """
148 | 
149 | assert len(states) == len(
150 | command_states
151 | ), "Batch size of states and command_states does not match!"
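# The propagation below is a three-step pipeline:
#   1) _update_commands applies a first-order lag to the commands,
#        u_new = u + dt / (dt + tau) * (u_ideal - u)
#      with tau the acceleration/steering time constant;
#   2) get_state_dot evaluates the kinematic bicycle derivatives;
#   3) forward_integrate performs one explicit Euler step per state,
#      clipping the steering angle to +/- max_steering_angle.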
152 | 153 | propagating_state = self._update_commands(states, command_states, sampling_time) 154 | output_state = copy.deepcopy(states) 155 | 156 | # Compute state derivatives 157 | state_dot = self.get_state_dot(propagating_state) 158 | 159 | output_state[:, StateIndex.X] = forward_integrate( 160 | states[:, StateIndex.X], state_dot[:, StateIndex.X], sampling_time 161 | ) 162 | output_state[:, StateIndex.Y] = forward_integrate( 163 | states[:, StateIndex.Y], state_dot[:, StateIndex.Y], sampling_time 164 | ) 165 | 166 | output_state[:, StateIndex.HEADING] = principal_value( 167 | forward_integrate( 168 | states[:, StateIndex.HEADING], 169 | state_dot[:, StateIndex.HEADING], 170 | sampling_time, 171 | ) 172 | ) 173 | 174 | output_state[:, StateIndex.VELOCITY_X] = forward_integrate( 175 | states[:, StateIndex.VELOCITY_X], 176 | state_dot[:, StateIndex.VELOCITY_X], 177 | sampling_time, 178 | ) 179 | 180 | # Lateral velocity is always zero in kinematic bicycle model 181 | output_state[:, StateIndex.VELOCITY_Y] = 0.0 182 | 183 | # Integrate steering angle and clip to bounds 184 | output_state[:, StateIndex.STEERING_ANGLE] = np.clip( 185 | forward_integrate( 186 | propagating_state[:, StateIndex.STEERING_ANGLE], 187 | state_dot[:, StateIndex.STEERING_ANGLE], 188 | sampling_time, 189 | ), 190 | -self._max_steering_angle, 191 | self._max_steering_angle, 192 | ) 193 | 194 | output_state[:, StateIndex.ANGULAR_VELOCITY] = ( 195 | output_state[:, StateIndex.VELOCITY_X] 196 | * np.tan(output_state[:, StateIndex.STEERING_ANGLE]) 197 | / self._vehicle.wheel_base 198 | ) 199 | 200 | output_state[:, StateIndex.ACCELERATION_2D] = state_dot[ 201 | :, StateIndex.VELOCITY_2D 202 | ] 203 | 204 | output_state[:, StateIndex.ANGULAR_ACCELERATION] = ( 205 | output_state[:, StateIndex.ANGULAR_VELOCITY] 206 | - states[:, StateIndex.ANGULAR_VELOCITY] 207 | ) / sampling_time.time_s 208 | 209 | output_state[:, StateIndex.STEERING_RATE] = state_dot[ 210 | :, StateIndex.STEERING_ANGLE 211 | ] 212 | 213 | return output_state 214 | -------------------------------------------------------------------------------- /src/post_processing/forward_simulation/forward_simulator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from nuplan.common.actor_state.ego_state import EgoState 3 | from nuplan.common.actor_state.state_representation import TimeDuration, TimePoint 4 | from nuplan.planning.simulation.simulation_time_controller.simulation_iteration import ( 5 | SimulationIteration, 6 | ) 7 | 8 | from .batch_kinematic_bicycle import BatchKinematicBicycleModel 9 | from .batch_lqr import BatchLQRTracker 10 | from ..common.enum import StateIndex, ego_state_to_state_array 11 | 12 | 13 | class ForwardSimulator: 14 | def __init__( 15 | self, 16 | dt: float = 0.1, 17 | num_frames: int = 40, 18 | estop: bool = False, 19 | soft_brake: bool = False, 20 | ) -> None: 21 | self.dt = dt 22 | self.interval = int(dt * 10) 23 | self.num_frames = num_frames 24 | self.motion_model = BatchKinematicBicycleModel() 25 | self.tracker = BatchLQRTracker( 26 | discretization_time=dt, 27 | tracking_horizon=int(1 / dt), 28 | estop=estop, 29 | soft_brake=soft_brake, 30 | ) 31 | 32 | def forward(self, candidate_trajectories: np.ndarray, init_ego_state: EgoState): 33 | """ 34 | Args: 35 | candidate_trajectories: (N, 80+1, S), sampled at 10 Hz 36 | """ 37 | N = candidate_trajectories.shape[0] 38 | rollout_states = np.zeros( 39 | (N, self.num_frames + 1, StateIndex.size()), dtype=np.float64 40 | ) 41 | 
rollout_states[:, 0] = ego_state_to_state_array(init_ego_state)
42 | 
43 | t_now = init_ego_state.time_point
44 | delta_t = TimeDuration.from_s(self.dt)
45 | 
46 | current_iteration = SimulationIteration(t_now, 0)
47 | next_iteration = SimulationIteration(t_now + delta_t, 1)
48 | 
49 | self.tracker.update(candidate_trajectories[:, :: self.interval])
50 | 
51 | for t in range(1, self.num_frames + 1):
52 | sampling_time: TimePoint = (
53 | next_iteration.time_point - current_iteration.time_point
54 | )
55 | command_states = self.tracker.track_trajectory(
56 | current_iteration,
57 | next_iteration,
58 | rollout_states[:, t - 1],
59 | )
60 | 
61 | rollout_states[:, t] = self.motion_model.propagate_state(
62 | states=rollout_states[:, t - 1],
63 | command_states=command_states,
64 | sampling_time=sampling_time,
65 | )
66 | 
67 | current_iteration = next_iteration
68 | next_iteration = SimulationIteration(
69 | current_iteration.time_point + delta_t, 1 + t
70 | )
71 | 
72 | return rollout_states
73 | 
--------------------------------------------------------------------------------
/src/post_processing/observation/world_from_prediction.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 | 
5 | from typing import Dict, List, Optional
6 | 
7 | import numpy as np
8 | import shapely.creation
9 | from nuplan.common.actor_state.ego_state import EgoState
10 | from nuplan.common.maps.abstract_map_objects import LaneGraphEdgeMapObject
11 | from nuplan.common.maps.maps_datatypes import (
12 | TrafficLightStatusData,
13 | TrafficLightStatusType,
14 | )
15 | from nuplan.planning.simulation.observation.observation_type import DetectionsTracks
16 | 
17 | from src.scenario_manager.occupancy_map import OccupancyMap
18 | 
19 | from ..common.geometry import compute_agents_vertices
20 | 
21 | 
22 | class WorldFromPrediction:
23 | def __init__(self, dt=0.1, num_frames=40, base_radius=50) -> None:
24 | self.dt = dt
25 | self.num_frames = num_frames
26 | self.interval = round(dt / 0.1)  # guard against float floor division (e.g. 0.3 // 0.1 == 2.0)
27 | 
28 | # TODO: radius should be determined by ego velocity
29 | self.radius = max(base_radius * dt * num_frames / 4, base_radius)
30 | 
31 | self.occupancy_map: Optional[List[OccupancyMap]] = None
32 | self.drivable_area: Optional[OccupancyMap] = None
33 | self.objects_info: Optional[Dict[str, np.ndarray]] = None
34 | 
35 | self.red_light_prefix = "red_light"
36 | 
37 | self.collided_tokens = None
38 | self._static_object_tokens = None
39 | self._agent_tokens = None
40 | self._ego_state = None
41 | 
42 | def __getitem__(self, idx: int) -> OccupancyMap:
43 | assert 0 <= idx < len(self.occupancy_map), "index out of range"
44 | return self.occupancy_map[idx]
45 | 
46 | def __len__(self) -> int:
47 | return len(self.occupancy_map)
48 | 
49 | def update(
50 | self,
51 | ego_state: EgoState,
52 | detections: DetectionsTracks,
53 | traffic_light_data: List[TrafficLightStatusData],
54 | agents_info: Dict[str, np.ndarray],
55 | route_lane_dict: Dict[str, LaneGraphEdgeMapObject],
56 | ):
57 | self._ego_state = ego_state
58 | self.occupancy_map: List[OccupancyMap] = []
59 | 
60 | tl_tokens, tl_polygons = self._get_route_red_traffic_lights(
61 | traffic_light_data, route_lane_dict
62 | )
63 | statics_tokens, statics_polygon = self._get_static_obstacles(
64 | ego_state, detections
65 | )
66 | agents_tokens, agents_vertices = self._get_dynamic_agents_from_prediction(
67 | 
agents_info 68 | ) 69 | has_agents = len(agents_tokens) > 0 70 | agents_polygon = np.array([], dtype=np.object_) 71 | 72 | for i in range(self.num_frames + 1): 73 | if has_agents: 74 | agents_polygon = agents_vertices[:, i] 75 | agents_polygon = shapely.creation.polygons(agents_polygon) 76 | 77 | frame_tokens = statics_tokens + agents_tokens + tl_tokens 78 | frame_polygons = np.concatenate( 79 | [statics_polygon, agents_polygon, tl_polygons], axis=0 80 | ) 81 | frame_occupancy_map = OccupancyMap( 82 | tokens=frame_tokens, geometries=frame_polygons 83 | ) 84 | self.occupancy_map.append(frame_occupancy_map) 85 | 86 | # update initial ego collision status 87 | self.collided_tokens = [] 88 | ego_polygon = ego_state.car_footprint.geometry 89 | intersect_tokens = self.occupancy_map[0].intersects(ego_polygon) 90 | for token in intersect_tokens: 91 | if token.startswith(self.red_light_prefix): 92 | if not ego_polygon.within(self.occupancy_map[0][token]): 93 | continue 94 | self.collided_tokens.append(token) 95 | 96 | self._static_object_tokens = set(statics_tokens) 97 | self._agent_tokens = set(agents_tokens) 98 | 99 | def _get_static_obstacles(self, ego_state: EgoState, detections: DetectionsTracks): 100 | self._static_objects = {} 101 | tokens, polygons = [], [] 102 | 103 | for static_obstacle in detections.tracked_objects.get_static_objects(): 104 | if ( 105 | np.linalg.norm(static_obstacle.center.array - ego_state.center.array) 106 | > self.radius 107 | ): 108 | continue 109 | if len(self.drivable_area.intersects(static_obstacle.box.geometry)) > 0: 110 | tokens.append(static_obstacle.track_token) 111 | polygons.append(static_obstacle.box.geometry) 112 | self._static_objects[static_obstacle.track_token] = static_obstacle 113 | 114 | if len(tokens) == 0: 115 | polygons = np.array([], dtype=np.object_) 116 | 117 | return tokens, polygons 118 | 119 | def _get_dynamic_agents_from_prediction( 120 | self, 121 | agents_info: Dict[str, np.ndarray], 122 | ): 123 | tokens = agents_info["tokens"] 124 | shape = agents_info["shape"] 125 | category = agents_info["category"] 126 | velocity = agents_info["velocity"] 127 | predictions = agents_info["predictions"][:, :: self.interval] 128 | 129 | agents_vertices = compute_agents_vertices( 130 | center=predictions[..., :2], angle=predictions[..., 2], shape=shape 131 | ) 132 | 133 | self._agents_info = {} 134 | 135 | for i in range(len(tokens)): 136 | self._agents_info[tokens[i]] = { 137 | "shape": shape[i], 138 | "velocity": velocity[i], 139 | "prediction": predictions[i], 140 | } 141 | 142 | return tokens, agents_vertices 143 | 144 | def _get_route_red_traffic_lights( 145 | self, 146 | traffic_light_data: List[TrafficLightStatusData], 147 | route_lane_dict: Dict[str, LaneGraphEdgeMapObject], 148 | ): 149 | tokens, polygons = [], [] 150 | 151 | for data in traffic_light_data: 152 | if data.status != TrafficLightStatusType.RED: 153 | continue 154 | lane_connector_id = str(data.lane_connector_id) 155 | if lane_connector_id in route_lane_dict.keys(): 156 | lane_connector = route_lane_dict[lane_connector_id] 157 | tokens.append(f"{self.red_light_prefix}_{lane_connector_id}") 158 | polygons.append(lane_connector.polygon) 159 | 160 | if len(tokens) == 0: 161 | polygons = np.array([], dtype=np.object_) 162 | 163 | return tokens, polygons 164 | 165 | def get_object_at_frame(self, token, frame_idx): 166 | if token in self._static_object_tokens: 167 | return { 168 | "is_agent": False, 169 | "pose": self._static_objects[token].center, 170 | "velocity": np.zeros(2), 171 | 
"polygon": self.occupancy_map[frame_idx][token], 172 | } 173 | else: 174 | return { 175 | "is_agent": True, 176 | "pose": self._agents_info[token]["prediction"][frame_idx], 177 | "velocity": self._agents_info[token]["velocity"][frame_idx], 178 | "polygon": self.occupancy_map[frame_idx][token], 179 | } 180 | -------------------------------------------------------------------------------- /src/scenario_manager/cost_map_manager.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Dict, List, Set 3 | 4 | import cv2 5 | import numpy as np 6 | from nuplan.common.actor_state.state_representation import Point2D 7 | from nuplan.common.actor_state.static_object import StaticObject 8 | from nuplan.common.actor_state.tracked_objects import TrackedObjects 9 | from nuplan.common.maps.abstract_map import AbstractMap 10 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer 11 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario 12 | from scipy import ndimage 13 | from shapely import Polygon 14 | 15 | DA = [SemanticMapLayer.LANE, SemanticMapLayer.LANE_CONNECTOR] 16 | 17 | 18 | class CostMapManager: 19 | def __init__( 20 | self, 21 | origin: np.ndarray, 22 | angle: float, 23 | map_api: AbstractMap, 24 | height: int = 500, 25 | width: int = 500, 26 | resolution: float = 0.2, 27 | ) -> None: 28 | self.map_api = map_api 29 | self.height = height 30 | self.width = width 31 | self.resolution = resolution 32 | self.resolution_hw = np.array([resolution, -resolution], dtype=np.float32) 33 | self.origin = origin 34 | self.angle = angle 35 | self.offset = np.array([height / 2, width / 2], dtype=np.float32) 36 | self.rot_mat = np.array( 37 | [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], 38 | dtype=np.float64, 39 | ) 40 | 41 | @classmethod 42 | def from_scenario(cls, scenario: AbstractScenario): 43 | ego_state = scenario.initial_ego_state 44 | origin = ego_state.rear_axle.point.array 45 | angle = ego_state.rear_axle.heading 46 | 47 | return cls(origin=origin, angle=angle, map_api=scenario.map_api) 48 | 49 | def build_cost_maps( 50 | self, 51 | static_objects: list[StaticObject], 52 | agents: Dict[str, np.ndarray] = None, 53 | agents_polygon: List[Polygon] = None, 54 | route_roadblock_ids: Set[str] = None, 55 | ): 56 | drivable_area_mask = np.zeros((self.height, self.width), dtype=np.uint8) 57 | speed_limit_mask = np.zeros((self.height, self.width), dtype=np.float32) 58 | 59 | radius = max(self.height, self.width) * self.resolution / 2 60 | da_objects_dict = self.map_api.get_proximal_map_objects( 61 | Point2D(*self.origin), radius, DA 62 | ) 63 | da_objects = itertools.chain.from_iterable(da_objects_dict.values()) 64 | 65 | for obj in da_objects: 66 | self.fill_polygon(drivable_area_mask, obj.polygon, value=1) 67 | 68 | speed_limit_mps = obj.speed_limit_mps if obj.speed_limit_mps else 50 69 | self.fill_polygon(speed_limit_mask, obj.polygon, value=speed_limit_mps) 70 | 71 | for static_ojb in static_objects: 72 | if np.linalg.norm(static_ojb.center.array - self.origin, axis=-1) > radius: 73 | continue 74 | self.fill_convex_polygon( 75 | drivable_area_mask, static_ojb.box.geometry, value=0 76 | ) 77 | 78 | if agents is not None: 79 | # parking vehicles as static obstacles 80 | position = agents["position"] 81 | valid_mask = agents["valid_mask"] 82 | for pos, mask, polygon in zip(position, valid_mask, agents_polygon): 83 | if mask.sum() < 50: 84 | continue 85 | pos = pos[mask] 86 | 
90 | distance = ndimage.distance_transform_edt(drivable_area_mask)
91 | inv_distance = ndimage.distance_transform_edt(1 - drivable_area_mask)
92 | drivable_area_sdf = distance - inv_distance
93 | drivable_area_sdf *= self.resolution
94 | 
95 | return {
96 | "cost_maps": drivable_area_sdf[:, :, None].astype(np.float16), # (H, W, C)
97 | }
98 | 
99 | def global_to_pixel(self, coord: np.ndarray):
100 | coord = np.matmul(coord - self.origin, self.rot_mat)
101 | coord = coord / self.resolution_hw + self.offset
102 | return coord
103 | 
104 | def fill_polygon(self, mask, polygon, value=1):
105 | polygon = self.global_to_pixel(np.stack(polygon.exterior.coords.xy, axis=1))
106 | cv2.fillPoly(mask, [np.round(polygon).astype(np.int32)], value)
107 | 
108 | def fill_convex_polygon(self, mask, polygon, value=1):
109 | polygon = self.global_to_pixel(np.stack(polygon.exterior.coords.xy, axis=1))
110 | cv2.fillConvexPoly(mask, np.round(polygon).astype(np.int32), value)
111 | 
112 | def fill_polyline(self, mask, polyline, value=1):
113 | polyline = self.global_to_pixel(polyline)
114 | cv2.polylines(
115 | mask,
116 | [np.round(polyline.reshape(-1, 1, 2)).astype(np.int32)],
117 | isClosed=False,
118 | color=value,
119 | thickness=1,
120 | )
121 | 
--------------------------------------------------------------------------------
/src/scenario_manager/occupancy_map.py:
--------------------------------------------------------------------------------
1 | # Heavily borrowed from:
2 | # https://github.com/autonomousvision/tuplan_garage (Apache License 2.0)
3 | # & https://github.com/motional/nuplan-devkit (Apache License 2.0)
4 | 
5 | from enum import Enum
6 | from typing import Any, Dict, List
7 | 
8 | import numpy as np
9 | import numpy.typing as npt
10 | import shapely.vectorized
11 | from nuplan.planning.simulation.occupancy_map.abstract_occupancy_map import Geometry
12 | from shapely.strtree import STRtree
13 | 
14 | 
15 | class OccupancyType(Enum):
16 | DYNAMIC = 0, "dynamic"
17 | STATIC = 1, "static"
18 | RED_LIGHT = 2, "red_light"
19 | 
20 | 
21 | class OccupancyMap:
22 | def __init__(
23 | self,
24 | tokens: List[str],
25 | geometries: npt.NDArray[np.object_],
26 | types: List[Enum] = None,
27 | node_capacity: int = 10,
28 | attribute: Dict[str, Any] = None,
29 | ):
30 | self._tokens: List[str] = tokens
31 | self._types: List[Enum] = types
32 | self._token_to_idx: Dict[str, int] = {
33 | token: idx for idx, token in enumerate(tokens)
34 | }
35 | 
36 | self._geometries = geometries
37 | self._attribute = attribute
38 | self._node_capacity = node_capacity
39 | self._str_tree = STRtree(self._geometries, node_capacity)
40 | 
41 | def __getitem__(self, token) -> Geometry:
42 | """
43 | Retrieves geometry of token.
44 | :param token: geometry identifier
45 | :return: Geometry of token
46 | """
47 | return self._geometries[self._token_to_idx[token]]
48 | 
49 | def __len__(self) -> int:
50 | """
51 | Number of geometries in the occupancy map
52 | :return: int
53 | """
54 | return len(self._tokens)
55 | 
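# A minimal usage sketch (tokens and boxes are made up for illustration):
#
#   import numpy as np
#   from shapely.geometry import box
#   geoms = np.array([box(0, 0, 2, 2), box(5, 5, 6, 6)], dtype=np.object_)
#   omap = OccupancyMap(tokens=["agent_a", "agent_b"], geometries=geoms)
#   omap.intersects(box(1, 1, 3, 3))  # -> ["agent_a"]
#   len(omap)                         # -> 2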
56 | def get_type(self, token: str) -> Enum:
57 | """
58 | Retrieves the type of a token.
59 | :param token: geometry identifier
60 | :return: type of token
61 | """
62 | return self._types[self._token_to_idx[token]]
63 | 
64 | @property
65 | def tokens(self) -> List[str]:
66 | """
67 | Getter for track tokens in occupancy map
68 | :return: list of strings
69 | """
70 | return self._tokens
71 | 
72 | @property
73 | def token_to_idx(self) -> Dict[str, int]:
74 | """
75 | Getter for the token-to-index mapping of the occupancy map
76 | :return: dictionary of tokens and indices
77 | """
78 | return self._token_to_idx
79 | 
80 | def intersects(self, geometry: Geometry) -> List[str]:
81 | """
82 | Searches for intersecting geometries in the occupancy map
83 | :param geometry: geometries to query
84 | :return: list of tokens for intersecting geometries
85 | """
86 | indices = self.query(geometry, predicate="intersects")
87 | return [self._tokens[idx] for idx in indices]
88 | 
89 | def get_subset_by_intersection(self, geometry: Geometry) -> "OccupancyMap":
90 | indices = self.query(geometry, predicate="intersects")
91 | polygons = [self._geometries[i] for i in indices]
92 | tokens = [self._tokens[i] for i in indices]  # keep tokens aligned with the subset
93 | return OccupancyMap(tokens, polygons)
94 | 
95 | def get_subset_by_type(self, type):
96 | assert self._types is not None, "OccupancyMap: No types defined!"
97 | indices = [i for i, t in enumerate(self._types) if t == type]
98 | polygons = [self._geometries[i] for i in indices]
99 | tokens = [self._tokens[i] for i in indices]  # keep tokens aligned with the subset
100 | return OccupancyMap(tokens, polygons)
101 | 
102 | def query(self, geometry: Geometry, predicate=None):
103 | """
104 | Directly calls shapely's query function on the STR-tree
105 | :param geometry: geometries to query
106 | :param predicate: see shapely, defaults to None
107 | :return: query output
108 | """
109 | return self._str_tree.query(geometry, predicate=predicate)
110 | 
111 | def points_in_polygons(
112 | self, points: npt.NDArray[np.float64]
113 | ) -> npt.NDArray[np.bool_]:
114 | """
115 | Determines whether input-points are in polygons of the occupancy map
116 | :param points: input-points
117 | :return: boolean array of shape (polygons, input-points)
118 | """
119 | output = np.zeros((len(self._geometries), len(points)), dtype=bool)
120 | for i, polygon in enumerate(self._geometries):
121 | output[i] = shapely.vectorized.contains(polygon, points[:, 0], points[:, 1])
122 | 
123 | return output
124 | 
125 | def points_in_polygons_with_attribute(
126 | self, points: npt.NDArray[np.float64], attribute_name: str
127 | ):
128 | """
129 | Determines whether input-points are in polygons of the occupancy map
130 | :param points: input-points
131 | :return: boolean array and attribute array, each of shape (polygons, input-points)
132 | """
133 | output = np.zeros((len(self._geometries), len(points)), dtype=bool)
134 | attribute = np.zeros((len(self._geometries), len(points)))
135 | for i, polygon in enumerate(self._geometries):
136 | output[i] = shapely.vectorized.contains(polygon, points[:, 0], points[:, 1])
137 | attribute[i] = self._attribute[attribute_name][i]
138 | 
139 | return output, attribute
140 | 
--------------------------------------------------------------------------------
/src/scenario_manager/scenario_manager.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | from typing import List
4 | 
5 | import numpy as np
6 | from nuplan.common.actor_state.ego_state import EgoState
7 | from nuplan.common.actor_state.state_representation import Point2D, StateSE2
8 | from nuplan.common.actor_state.tracked_objects import TrackedObjects
9 | from nuplan.common.actor_state.tracked_objects_types import 
TrackedObjectType 10 | from nuplan.common.maps.maps_datatypes import ( 11 | SemanticMapLayer, 12 | TrafficLightStatusData, 13 | TrafficLightStatusType, 14 | ) 15 | from nuplan.planning.simulation.observation.idm.utils import ( 16 | create_path_from_se2, 17 | path_to_linestring, 18 | ) 19 | from nuplan.planning.simulation.path.utils import trim_path 20 | from shapely.geometry import Point, Polygon 21 | from shapely.geometry.base import CAP_STYLE 22 | from shapely.ops import unary_union 23 | 24 | from .occupancy_map import OccupancyMap, OccupancyType 25 | from .route_manager import RouteManager 26 | 27 | DRIVABLE_LAYERS = { 28 | SemanticMapLayer.ROADBLOCK, 29 | SemanticMapLayer.ROADBLOCK_CONNECTOR, 30 | SemanticMapLayer.CARPARK_AREA, 31 | } 32 | 33 | STATIC_OBJECT_TYPES = { 34 | TrackedObjectType.CZONE_SIGN, 35 | TrackedObjectType.BARRIER, 36 | TrackedObjectType.TRAFFIC_CONE, 37 | TrackedObjectType.GENERIC_OBJECT, 38 | } 39 | 40 | DYNAMIC_OBJECT_TYPES = { 41 | TrackedObjectType.VEHICLE, 42 | TrackedObjectType.BICYCLE, 43 | TrackedObjectType.PEDESTRIAN, 44 | } 45 | 46 | 47 | class ScenarioManager: 48 | def __init__( 49 | self, 50 | map_api, 51 | ego_state: EgoState, 52 | route_roadblocks_ids: List[str], 53 | radius=50, 54 | ) -> None: 55 | self._map_api = map_api 56 | self._radius = radius 57 | self._drivable_area_map: OccupancyMap = None # [lanes, lane connectors] 58 | self._obstacle_map: OccupancyMap = None # [agents, red lights, etc] 59 | 60 | self._ego_state = ego_state 61 | self._ego_path = None 62 | self._ego_path_linestring = None 63 | self._ego_trimmed_path = None 64 | self._ego_progress = None 65 | 66 | self._route_manager = RouteManager(map_api, route_roadblocks_ids, radius) 67 | 68 | @property 69 | def drivable_area_map(self): 70 | return self._drivable_area_map 71 | 72 | def get_route_roadblock_ids(self, process=True) -> List[str]: 73 | if not self._route_manager.initialized: 74 | self._route_manager.load_route(self._ego_state, process) 75 | return self._route_manager.route_roadblock_ids 76 | 77 | def get_route_lane_dicts(self): 78 | assert self._route_manager.initialized 79 | return self._route_manager._route_lane_dict 80 | 81 | def update_ego_state(self, ego_state: EgoState): 82 | self._ego_state = ego_state 83 | 84 | def update_drivable_area_map(self): 85 | """ 86 | Builds occupancy map of drivable area. 
87 | Uses the ego state last set via update_ego_state.
88 | """
89 | 
90 | position: Point2D = self._ego_state.center.point
91 | drivable_area = self._map_api.get_proximal_map_objects(
92 | position, self._radius, DRIVABLE_LAYERS
93 | )
94 | 
95 | drivable_polygons: List[Polygon] = []
96 | drivable_polygon_ids: List[str] = []
97 | lane_speed_limit = []
98 | 
99 | for road in [SemanticMapLayer.ROADBLOCK, SemanticMapLayer.ROADBLOCK_CONNECTOR]:
100 | for roadblock in drivable_area[road]:
101 | for lane in roadblock.interior_edges:
102 | drivable_polygons.append(lane.polygon)
103 | drivable_polygon_ids.append(lane.id)
104 | speed_limit = lane.speed_limit_mps if lane.speed_limit_mps else 50
105 | lane_speed_limit.append(speed_limit)
106 | 
107 | for carpark in drivable_area[SemanticMapLayer.CARPARK_AREA]:
108 | drivable_polygons.append(carpark.polygon)
109 | drivable_polygon_ids.append(carpark.id)
110 | speed_limit = getattr(carpark, "speed_limit_mps", None) or 50  # carparks carry no lane speed limit; use the default
111 | lane_speed_limit.append(speed_limit)
112 | 
113 | self._drivable_area_map = OccupancyMap(
114 | drivable_polygon_ids,
115 | drivable_polygons,
116 | attribute={"speed_limit": lane_speed_limit},
117 | )
118 | self._route_manager.update_drivable_area_map(self._drivable_area_map)
119 | 
120 | def update_obstacle_map(
121 | self,
122 | detections: TrackedObjects,
123 | traffic_light_status: List[TrafficLightStatusData],
124 | ):
125 | """
126 | Builds occupancy map of obstacles (agents, static objects, red lights).
127 | :param detections: tracked objects of the current frame.
128 | """
129 | tokens = []
130 | types = []
131 | polygons = []
132 | 
133 | for obj in detections.tracked_objects:
134 | if (
135 | np.linalg.norm(self._ego_state.center.array - obj.center.array)
136 | < self._radius
137 | ):
138 | obj_type = (
139 | OccupancyType.DYNAMIC
140 | if obj.tracked_object_type in DYNAMIC_OBJECT_TYPES
141 | else OccupancyType.STATIC
142 | )
143 | tokens.append(obj.track_token)
144 | types.append(obj_type)
145 | polygons.append(obj.box.geometry)
146 | 
147 | for data in traffic_light_status:
148 | if (
149 | data.status == TrafficLightStatusType.RED
150 | and str(data.lane_connector_id) in self._route_manager.route_lane_ids
151 | ):
152 | tokens.append(data.lane_connector_id)
153 | types.append(OccupancyType.RED_LIGHT)
154 | polygons.append(
155 | self._map_api.get_map_object(
156 | str(data.lane_connector_id), SemanticMapLayer.LANE_CONNECTOR
157 | ).polygon
158 | )
159 | 
160 | self._obstacle_map = OccupancyMap(tokens, polygons, types)
161 | 
162 | def update_ego_path(self, length=50):
163 | ego_path: List[StateSE2] = self._route_manager.get_ego_path(self._ego_state)
164 | self._ego_path = create_path_from_se2(ego_path)
165 | self._ego_path_linestring = path_to_linestring(ego_path)
166 | 
167 | start_progress = self._ego_path.get_start_progress()
168 | end_progress = self._ego_path.get_end_progress()
169 | 
170 | with warnings.catch_warnings():
171 | # https://github.com/shapely/shapely/issues/1796
172 | warnings.simplefilter("ignore")
173 | self._ego_progress = self._ego_path_linestring.project(
174 | Point([self._ego_state.center.x, self._ego_state.center.y])
175 | )
176 | 
177 | trimmed_path = trim_path(
178 | self._ego_path,
179 | max(start_progress, min(self._ego_progress, end_progress)),
180 | min(self._ego_progress + length, end_progress),
181 | )
182 | self._ego_trimmed_path = trimmed_path
183 | 
184 | np_path = np.array([p.point.array for p in trimmed_path])
185 | return np_path
186 | 
187 | def get_leading_objects(self):
188 | expanded_ego_path = path_to_linestring(self._ego_trimmed_path).buffer(
189 | 
187 |     def get_leading_objects(self):
188 |         expanded_ego_path = path_to_linestring(self._ego_trimmed_path).buffer(
189 |             self._ego_state.car_footprint.width / 2, cap_style=CAP_STYLE.square
190 |         )
191 |         expanded_ego_path = unary_union(
192 |             [expanded_ego_path, self._ego_state.car_footprint.geometry]
193 |         )
194 |         intersecting_objects = self._obstacle_map.intersects(expanded_ego_path)
195 | 
196 |         if len(intersecting_objects) == 0:
197 |             return []
198 | 
199 |         leading_objects = []
200 |         ego_polygon = self._ego_state.car_footprint.geometry
201 | 
202 |         for obj_token in intersecting_objects:
203 |             leading_objects.append(
204 |                 (
205 |                     obj_token,
206 |                     self._obstacle_map.get_type(obj_token),
207 |                     self._obstacle_map[obj_token].distance(ego_polygon),
208 |                 )
209 |             )
210 | 
211 |         self.leading_objects = sorted(leading_objects, key=lambda x: x[2])
212 | 
213 |         return self.leading_objects
214 | 
215 |     def get_occupancy_object(self, token: str):
216 |         return self._obstacle_map[token]
217 | 
218 |     def get_ego_path_points(self, start_progress, end_progress):
219 |         start_progress += self._ego_progress
220 |         end_progress += self._ego_progress
221 |         return np.array(
222 |             [
223 |                 np.array([p.x, p.y, p.heading], dtype=np.float64)
224 |                 for p in self._ego_trimmed_path
225 |                 if start_progress <= p.progress <= end_progress
226 |             ]
227 |         )
228 | 
229 |     def get_reference_lines(self, length=100):
230 |         return self._route_manager.get_reference_lines(self._ego_state, length=length)
231 | 
232 |     def get_cached_reference_lines(self):
233 |         if self._route_manager.reference_lines:
234 |             return self._route_manager.reference_lines
235 |         else:
236 |             raise ValueError("Reference lines not cached")
237 | 
238 |     def object_in_drivable_area(self, polygon: Polygon):
239 |         return len(self._drivable_area_map.intersects(polygon)) > 0
240 | 
--------------------------------------------------------------------------------
/src/scenario_manager/utils/bfs_roadblock.py:
--------------------------------------------------------------------------------
 1 | from collections import deque
 2 | from typing import Dict, Optional, Tuple, Union, List
 3 | 
 4 | from nuplan.common.maps.abstract_map import AbstractMap
 5 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject
 6 | 
 7 | 
 8 | class BreadthFirstSearchRoadBlock:
 9 |     """
10 |     A class that performs iterative breadth-first search. The class operates on the roadblock graph.
11 |     """
12 | 
13 |     def __init__(
14 |         self, start_roadblock_id: str, map_api: Optional[AbstractMap], forward_search: bool = True
15 |     ):
16 |         """
17 |         Constructor of BreadthFirstSearchRoadBlock class
18 |         :param start_roadblock_id: roadblock id where graph starts
19 |         :param map_api: map class in nuPlan
20 |         :param forward_search: whether to search in driving direction, defaults to True
21 |         """
22 |         self._map_api: Optional[AbstractMap] = map_api
23 |         self._queue = deque([self.id_to_roadblock(start_roadblock_id), None])
24 |         self._parent: Dict[str, Optional[RoadBlockGraphEdgeMapObject]] = dict()
25 |         self._forward_search = forward_search
26 | 
27 |         # lazy loaded
28 |         self._target_roadblock_ids: Optional[List[str]] = None
29 | 
30 |     def search(
31 |         self, target_roadblock_id: Union[str, List[str]], max_depth: int
32 |     ) -> Tuple[Tuple[List[RoadBlockGraphEdgeMapObject], List[str]], bool]:
33 |         """
34 |         Apply BFS to find route to target roadblock.
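        Example (illustrative; assumes a loaded nuPlan map and hypothetical ids)::

            search = BreadthFirstSearchRoadBlock("roadblock_0", map_api)
            (route, route_ids), found = search.search("roadblock_9", max_depth=30)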
 35 |         :param target_roadblock_id: id (or list of ids) of the target roadblock
 36 |         :param max_depth: maximum search depth
 37 |         :return: tuple of (route, route ids) and whether a path was found
 38 |         """
 39 | 
 40 |         if isinstance(target_roadblock_id, str):
 41 |             target_roadblock_id = [target_roadblock_id]
 42 |         self._target_roadblock_ids = target_roadblock_id
 43 | 
 44 |         start_edge = self._queue[0]
 45 | 
 46 |         # Initial search states
 47 |         path_found: bool = False
 48 |         end_edge: RoadBlockGraphEdgeMapObject = start_edge
 49 |         end_depth: int = 1
 50 |         depth: int = 1
 51 | 
 52 |         self._parent[start_edge.id + f"_{depth}"] = None
 53 | 
 54 |         while self._queue:
 55 |             current_edge = self._queue.popleft()
 56 | 
 57 |             # Early exit condition
 58 |             if self._check_end_condition(depth, max_depth):
 59 |                 break
 60 | 
 61 |             # Depth tracking
 62 |             if current_edge is None:
 63 |                 depth += 1
 64 |                 self._queue.append(None)
 65 |                 if self._queue[0] is None:
 66 |                     break
 67 |                 continue
 68 | 
 69 |             # Goal condition
 70 |             if self._check_goal_condition(current_edge, depth, max_depth):
 71 |                 end_edge = current_edge
 72 |                 end_depth = depth
 73 |                 path_found = True
 74 |                 break
 75 | 
 76 |             neighbors = (
 77 |                 current_edge.outgoing_edges if self._forward_search else current_edge.incoming_edges
 78 |             )
 79 | 
 80 |             # Populate queue; parents are keyed by edge id and depth
 81 |             for next_edge in neighbors:
 82 |                 self._queue.append(next_edge)
 83 |                 self._parent[next_edge.id + f"_{depth + 1}"] = current_edge
 84 |                 end_edge = next_edge
 85 |                 end_depth = depth + 1
 86 | 
 87 |         return self._construct_path(end_edge, end_depth), path_found
 88 | 
 89 |     def id_to_roadblock(self, id: str) -> RoadBlockGraphEdgeMapObject:
 90 |         """
 91 |         Retrieves roadblock from map-api based on id
 92 |         :param id: id of roadblock
 93 |         :return: roadblock class
 94 |         """
 95 |         block = self._map_api._get_roadblock(id)
 96 |         block = block or self._map_api._get_roadblock_connector(id)
 97 |         return block
 98 | 
 99 |     @staticmethod
100 |     def _check_end_condition(depth: int, max_depth: int) -> bool:
101 |         """
102 |         Check if the search should end regardless if the goal condition is met.
103 |         :param depth: The current depth to check.
104 |         :param max_depth: The maximum depth to check against.
105 |         :return: whether depth exceeds the maximum depth.
106 |         """
107 |         return depth > max_depth
108 | 
109 |     def _check_goal_condition(
110 |         self,
111 |         current_edge: RoadBlockGraphEdgeMapObject,
112 |         depth: int,
113 |         max_depth: int,
114 |     ) -> bool:
115 |         """
116 |         Check if the current edge is at the target roadblock at the given depth.
117 |         :param current_edge: edge to check.
118 |         :param depth: current depth to check.
119 |         :param max_depth: maximum depth the edge should be at.
120 |         :return: True if the edge is in the target roadblocks at a valid depth, False otherwise.
121 |         """
122 |         return current_edge.id in self._target_roadblock_ids and depth <= max_depth
123 | 
124 |     def _construct_path(
125 |         self, end_edge: RoadBlockGraphEdgeMapObject, depth: int
126 |     ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]:
127 |         """
128 |         Constructs a path once the goal is found.
129 |         :param end_edge: The end edge from which to backpropagate to the start edge.
130 |         :param depth: The depth of the target edge.
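        Parent lookups reuse the f"{edge_id}_{depth}" keys written during search.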
131 |         :return: tuple of the constructed path and the corresponding roadblock ids
132 |         """
133 |         path = [end_edge]
134 |         path_id = [end_edge.id]
135 | 
136 |         while self._parent[end_edge.id + f"_{depth}"] is not None:
137 |             path.append(self._parent[end_edge.id + f"_{depth}"])
138 |             path_id.append(path[-1].id)
139 |             end_edge = self._parent[end_edge.id + f"_{depth}"]
140 |             depth -= 1
141 | 
142 |         if self._forward_search:
143 |             path.reverse()
144 |             path_id.reverse()
145 | 
146 |         return (path, path_id)
147 | 
--------------------------------------------------------------------------------
/src/scenario_manager/utils/dijkstra.py:
--------------------------------------------------------------------------------
 1 | from typing import Dict, List, Optional, Tuple
 2 | import numpy as np
 3 | from nuplan.common.maps.abstract_map_objects import (
 4 |     LaneGraphEdgeMapObject,
 5 |     RoadBlockGraphEdgeMapObject,
 6 | )
 7 | 
 8 | 
 9 | class Dijkstra:
10 |     """
11 |     A class that performs Dijkstra's shortest-path search on the lane-level graph.
12 |     The goal condition is met when a lane inside the target roadblock or roadblock connector is found.
13 |     """
14 | 
15 |     def __init__(self, start_edge: LaneGraphEdgeMapObject, candidate_lane_edge_ids: List[str]):
16 |         """
17 |         Constructor for the Dijkstra class.
18 |         :param start_edge: The starting edge for the search
19 |         :param candidate_lane_edge_ids: The candidate lane ids that can be included in the search.
20 |         """
21 |         self._queue = [start_edge]
22 |         self._parent: Dict[str, Optional[LaneGraphEdgeMapObject]] = dict()
23 |         self._candidate_lane_edge_ids = candidate_lane_edge_ids
24 | 
25 |     def search(
26 |         self, target_roadblock: RoadBlockGraphEdgeMapObject
27 |     ) -> Tuple[List[LaneGraphEdgeMapObject], bool]:
28 |         """
29 |         Performs Dijkstra's shortest path to find a route to the target roadblock.
30 |         :param target_roadblock: The target roadblock the path should end at.
31 |         :return:
32 |             - A route starting from the given start edge
33 |             - A bool indicating if the route is successfully found. Successful means that there exists a path
34 |               from the start edge to an edge contained in the end roadblock.
35 |               If unsuccessful the shortest deepest path is returned.
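        Example (illustrative; ``start_lane``, ``lane_ids`` and ``target_block`` are hypothetical)::

            dijkstra = Dijkstra(start_lane, lane_ids)
            route, found = dijkstra.search(target_block)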
 36 |         """
 37 |         start_edge = self._queue[0]
 38 | 
 39 |         # Initial search states
 40 |         path_found: bool = False
 41 |         end_edge: LaneGraphEdgeMapObject = start_edge
 42 | 
 43 |         self._parent[start_edge.id] = None
 44 |         self._frontier = [start_edge.id]
 45 |         self._dist = [1]
 46 |         self._depth = [1]
 47 | 
 48 |         self._expanded = []
 49 |         self._expanded_id = []
 50 |         self._expanded_dist = []
 51 |         self._expanded_depth = []
 52 | 
 53 |         while self._queue:
 54 |             dist, idx = min((val, idx) for (idx, val) in enumerate(self._dist))
 55 |             current_edge = self._queue[idx]
 56 |             current_depth = self._depth[idx]
 57 | 
 58 |             del self._dist[idx], self._queue[idx], self._frontier[idx], self._depth[idx]
 59 | 
 60 |             if self._check_goal_condition(current_edge, target_roadblock):
 61 |                 end_edge = current_edge
 62 |                 path_found = True
 63 |                 break
 64 | 
 65 |             self._expanded.append(current_edge)
 66 |             self._expanded_id.append(current_edge.id)
 67 |             self._expanded_dist.append(dist)
 68 |             self._expanded_depth.append(current_depth)
 69 | 
 70 |             # Populate queue
 71 |             for next_edge in current_edge.outgoing_edges:
 72 |                 if next_edge.id not in self._candidate_lane_edge_ids:
 73 |                     continue
 74 | 
 75 |                 alt = dist + self._edge_cost(next_edge)
 76 |                 if next_edge.id not in self._expanded_id and next_edge.id not in self._frontier:
 77 |                     self._parent[next_edge.id] = current_edge
 78 |                     self._queue.append(next_edge)
 79 |                     self._frontier.append(next_edge.id)
 80 |                     self._dist.append(alt)
 81 |                     self._depth.append(current_depth + 1)
 82 |                     end_edge = next_edge
 83 | 
 84 |                 elif next_edge.id in self._frontier:
 85 |                     next_edge_idx = self._frontier.index(next_edge.id)
 86 |                     current_cost = self._dist[next_edge_idx]
 87 |                     if alt < current_cost:
 88 |                         self._parent[next_edge.id] = current_edge
 89 |                         self._dist[next_edge_idx] = alt
 90 |                         self._depth[next_edge_idx] = current_depth + 1
 91 | 
 92 |         if not path_found:
 93 |             # fall back to the deepest, then cheapest, expanded edge
 94 |             max_depth = max(self._expanded_depth)
 95 |             idx_max_depth = list(np.where(np.array(self._expanded_depth) == max_depth)[0])
 96 |             dist_at_max_depth = [self._expanded_dist[i] for i in idx_max_depth]
 97 | 
 98 |             dist, _idx = min((val, idx) for (idx, val) in enumerate(dist_at_max_depth))
 99 |             end_edge = self._expanded[idx_max_depth[_idx]]
100 | 
101 |         return self._construct_path(end_edge), path_found
102 | 
103 |     @staticmethod
104 |     def _edge_cost(lane: LaneGraphEdgeMapObject) -> float:
105 |         """
106 |         Edge cost of given lane.
107 |         :param lane: lane class
108 |         :return: length of lane
109 |         """
110 |         return lane.baseline_path.length
111 | 
112 |     @staticmethod
113 |     def _check_end_condition(depth: int, target_depth: int) -> bool:
114 |         """
115 |         Check if the search should end regardless if the goal condition is met.
116 |         :param depth: The current depth to check.
117 |         :param target_depth: The target depth to check against.
118 |         :return: True if:
119 |             - The current depth exceeds the target depth.
120 |         """
121 |         return depth > target_depth
122 | 
123 |     @staticmethod
124 |     def _check_goal_condition(
125 |         current_edge: LaneGraphEdgeMapObject,
126 |         target_roadblock: RoadBlockGraphEdgeMapObject,
127 |     ) -> bool:
128 |         """
129 |         Check if the current edge is contained in the target roadblock.
130 |         :param current_edge: The edge to check.
131 |         :param target_roadblock: The target roadblock the edge should be contained in.
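        Containment is decided by comparing current_edge.get_roadblock_id() against target_roadblock.id.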
132 |         :return: whether the current edge is in the target roadblock
133 |         """
134 |         return current_edge.get_roadblock_id() == target_roadblock.id
135 | 
136 |     def _construct_path(self, end_edge: LaneGraphEdgeMapObject) -> List[LaneGraphEdgeMapObject]:
137 |         """
138 |         Constructs the path by backpropagating from the end edge to the start edge.
139 |         :param end_edge: The end edge to start backpropagating from.
140 |         :return: The constructed path as a list of LaneGraphEdgeMapObject
141 |         """
142 |         path = [end_edge]
143 |         while self._parent[end_edge.id] is not None:
144 |             node = self._parent[end_edge.id]
145 |             path.append(node)
146 |             end_edge = node
147 |         path.reverse()
148 | 
149 |         return path
150 | 
--------------------------------------------------------------------------------
/src/scenario_manager/utils/route_utils.py:
--------------------------------------------------------------------------------
 1 | from typing import Dict, List, Tuple
 2 | 
 3 | import numpy as np
 4 | from nuplan.common.actor_state.ego_state import EgoState
 5 | from nuplan.common.actor_state.state_representation import StateSE2
 6 | from nuplan.common.maps.abstract_map import AbstractMap
 7 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject
 8 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer
 9 | from nuplan.planning.simulation.occupancy_map.strtree_occupancy_map import (
10 |     STRTreeOccupancyMapFactory,
11 | )
12 | 
13 | from .bfs_roadblock import BreadthFirstSearchRoadBlock
14 | 
15 | 
16 | def normalize_angle(angle: float) -> float:
17 |     return (angle + np.pi) % (2 * np.pi) - np.pi
18 | 
19 | 
20 | def get_current_roadblock_candidates(
21 |     ego_state: EgoState,
22 |     map_api: AbstractMap,
23 |     route_roadblocks_dict: Dict[str, RoadBlockGraphEdgeMapObject],
24 |     heading_error_thresh: float = np.pi / 4,
25 |     displacement_error_thresh: float = 3,
26 | ) -> Tuple[RoadBlockGraphEdgeMapObject, List[RoadBlockGraphEdgeMapObject]]:
27 |     """
28 |     Determines the set of roadblock candidates where ego is located.
29 |     :param ego_state: class containing ego state
30 |     :param map_api: map object
31 |     :param route_roadblocks_dict: dictionary of on-route roadblocks
32 |     :param heading_error_thresh: maximum heading error, defaults to np.pi/4
33 |     :param displacement_error_thresh: maximum displacement in meters, defaults to 3
34 |     :return: tuple of most promising roadblock and other candidates
35 |     """
36 |     ego_pose: StateSE2 = ego_state.rear_axle
37 |     roadblock_candidates = []
38 | 
39 |     layers = [SemanticMapLayer.ROADBLOCK, SemanticMapLayer.ROADBLOCK_CONNECTOR]
40 |     roadblock_dict = map_api.get_proximal_map_objects(
41 |         point=ego_pose.point, radius=2.5, layers=layers
42 |     )
43 |     roadblock_candidates = (
44 |         roadblock_dict[SemanticMapLayer.ROADBLOCK]
45 |         + roadblock_dict[SemanticMapLayer.ROADBLOCK_CONNECTOR]
46 |     )
47 | 
48 |     if not roadblock_candidates:
49 |         for layer in layers:
50 |             roadblock_id_, distance = map_api.get_distance_to_nearest_map_object(
51 |                 point=ego_pose.point, layer=layer
52 |             )
53 |             roadblock = map_api.get_map_object(roadblock_id_, layer)
54 | 
55 |             if roadblock:
56 |                 roadblock_candidates.append(roadblock)
57 | 
58 |     on_route_candidates, on_route_candidate_displacement_errors = [], []
59 |     candidates, candidate_displacement_errors = [], []
60 | 
61 |     roadblock_displacement_errors = []
62 |     roadblock_heading_errors = []
63 | 
64 |     for roadblock in roadblock_candidates:
65 |         lane_displacement_error, lane_heading_error = np.inf, np.inf
66 | 
67 |         for lane in roadblock.interior_edges:
 68 |             lane_discrete_path: List[StateSE2] = lane.baseline_path.discrete_path
 69 |             lane_discrete_points = np.array(
 70 |                 [state.point.array for state in lane_discrete_path], dtype=np.float64
 71 |             )
 72 |             lane_state_distances = (
 73 |                 (lane_discrete_points - ego_pose.point.array[None, ...]) ** 2.0
 74 |             ).sum(axis=-1) ** 0.5
 75 |             argmin = np.argmin(lane_state_distances)
 76 | 
 77 |             heading_error = np.abs(
 78 |                 normalize_angle(lane_discrete_path[argmin].heading - ego_pose.heading)
 79 |             )
 80 |             displacement_error = lane_state_distances[argmin]
 81 | 
 82 |             if displacement_error < lane_displacement_error:
 83 |                 lane_heading_error, lane_displacement_error = (
 84 |                     heading_error,
 85 |                     displacement_error,
 86 |                 )
 87 | 
 88 |             if (
 89 |                 heading_error < heading_error_thresh
 90 |                 and displacement_error < displacement_error_thresh
 91 |             ):
 92 |                 if roadblock.id in route_roadblocks_dict:
 93 |                     on_route_candidates.append(roadblock)
 94 |                     on_route_candidate_displacement_errors.append(displacement_error)
 95 |                 else:
 96 |                     candidates.append(roadblock)
 97 |                     candidate_displacement_errors.append(displacement_error)
 98 | 
 99 |         roadblock_displacement_errors.append(lane_displacement_error)
100 |         roadblock_heading_errors.append(lane_heading_error)
101 | 
102 |     if on_route_candidates:  # prefer on-route roadblocks
103 |         return (
104 |             on_route_candidates[np.argmin(on_route_candidate_displacement_errors)],
105 |             on_route_candidates,
106 |         )
107 |     elif candidates:  # fallback to most promising candidate
108 |         return candidates[np.argmin(candidate_displacement_errors)], candidates
109 | 
110 |     # otherwise, just find any close roadblock
111 |     return (
112 |         roadblock_candidates[np.argmin(roadblock_displacement_errors)],
113 |         roadblock_candidates,
114 |     )
115 | 
116 | 
117 | def route_roadblock_correction(
118 |     ego_state: EgoState,
119 |     map_api: AbstractMap,
120 |     route_roadblock_ids: List[str],
121 |     search_depth_backward: int = 15,
122 |     search_depth_forward: int = 30,
123 | ) -> List[str]:
124 |     """
125 |     Applies several methods to correct route roadblocks.
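    Fix 1 prepends a backward/forward-searched path when ego starts
    off-route, Fix 2 bridges consecutive roadblocks that are not directly
    connected, and Fix 3 cuts loops at the end of the route.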
126 |     :param ego_state: class containing ego state
127 |     :param map_api: map object
128 |     :param route_roadblock_ids: list of roadblock ids of the original route
129 |     :param search_depth_backward: depth of backward BFS search, defaults to 15
130 |     :param search_depth_forward: depth of forward BFS search, defaults to 30
131 |     :return: list of roadblock ids of the corrected route
132 |     """
133 | 
134 |     route_roadblock_dict = {}
135 |     for id_ in route_roadblock_ids:
136 |         block = map_api.get_map_object(id_, SemanticMapLayer.ROADBLOCK)
137 |         block = block or map_api.get_map_object(
138 |             id_, SemanticMapLayer.ROADBLOCK_CONNECTOR
139 |         )
140 |         route_roadblock_dict[id_] = block
141 | 
142 |     starting_block, starting_block_candidates = get_current_roadblock_candidates(
143 |         ego_state, map_api, route_roadblock_dict
144 |     )
145 |     starting_block_ids = [roadblock.id for roadblock in starting_block_candidates]
146 | 
147 |     route_roadblocks = list(route_roadblock_dict.values())
148 |     route_roadblock_ids = list(route_roadblock_dict.keys())
149 | 
150 |     # Fix 1: when agent starts off-route
151 |     if starting_block.id not in route_roadblock_ids:
152 |         # Backward search if current roadblock not in route
153 |         graph_search = BreadthFirstSearchRoadBlock(
154 |             route_roadblock_ids[0], map_api, forward_search=False
155 |         )
156 |         (path, path_id), path_found = graph_search.search(
157 |             starting_block_ids, max_depth=search_depth_backward
158 |         )
159 | 
160 |         if path_found:
161 |             route_roadblocks[:0] = path[:-1]
162 |             route_roadblock_ids[:0] = path_id[:-1]
163 | 
164 |         else:
165 |             # Forward search to any route roadblock
166 |             graph_search = BreadthFirstSearchRoadBlock(
167 |                 starting_block.id, map_api, forward_search=True
168 |             )
169 |             (path, path_id), path_found = graph_search.search(
170 |                 route_roadblock_ids[:3], max_depth=search_depth_forward
171 |             )
172 | 
173 |             if path_found:
174 |                 end_roadblock_idx = np.argmax(
175 |                     np.array(route_roadblock_ids) == path_id[-1]
176 |                 )
177 | 
178 |                 route_roadblocks = route_roadblocks[end_roadblock_idx + 1 :]
179 |                 route_roadblock_ids = route_roadblock_ids[end_roadblock_idx + 1 :]
180 | 
181 |                 route_roadblocks[:0] = path
182 |                 route_roadblock_ids[:0] = path_id
183 | 
184 |     # Fix 2: check if roadblocks are linked, search for links if not
185 |     roadblocks_to_append = {}
186 |     for i in range(len(route_roadblocks) - 1):
187 |         next_incoming_block_ids = [
188 |             _roadblock.id for _roadblock in route_roadblocks[i + 1].incoming_edges
189 |         ]
190 |         is_incoming = route_roadblock_ids[i] in next_incoming_block_ids
191 | 
192 |         if is_incoming:
193 |             continue
194 | 
195 |         graph_search = BreadthFirstSearchRoadBlock(
196 |             route_roadblock_ids[i], map_api, forward_search=True
197 |         )
198 |         (path, path_id), path_found = graph_search.search(
199 |             route_roadblock_ids[i + 1], max_depth=search_depth_forward
200 |         )
201 | 
202 |         if path_found and path and len(path) >= 3:
203 |             path, path_id = path[1:-1], path_id[1:-1]
204 |             roadblocks_to_append[i] = (path, path_id)
205 | 
206 |     # append missing intermediate roadblocks
207 |     offset = 1
208 |     for i, (path, path_id) in roadblocks_to_append.items():
209 |         route_roadblocks[i + offset : i + offset] = path
210 |         route_roadblock_ids[i + offset : i + offset] = path_id
211 |         offset += len(path)
212 | 
213 |     # Fix 3: cut route-loops
214 |     route_roadblocks, route_roadblock_ids = remove_route_loops(
215 |         route_roadblocks, route_roadblock_ids
216 |     )
217 | 
218 |     return route_roadblock_ids
219 | 
220 | 
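# Illustrative call of route_roadblock_correction at planner initialization
# (`ego_state`, `map_api` and `mission_roadblock_ids` are hypothetical names):
#
#   corrected_ids = route_roadblock_correction(
#       ego_state, map_api, mission_roadblock_ids
#   )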
221 | def remove_route_loops(
222 |     route_roadblocks: List[RoadBlockGraphEdgeMapObject],
223 |     route_roadblock_ids: List[str],
224 | ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]:
225 |     """
226 |     Removes the ending of the route if a roadblock intersects an earlier part of the route (forming a loop).
227 |     :param route_roadblocks: input route roadblocks
228 |     :param route_roadblock_ids: input route roadblock ids
229 |     :return: tuple of roadblocks and ids of the route without loops
230 |     """
231 | 
232 |     roadblock_occupancy_map = None
233 |     loop_idx = None
234 | 
235 |     for idx, roadblock in enumerate(route_roadblocks):
236 |         # loops only occur at intersections, thus search over roadblock-connectors
237 |         if str(roadblock.__class__.__name__) == "NuPlanRoadBlockConnector":
238 |             if not roadblock_occupancy_map:
239 |                 roadblock_occupancy_map = STRTreeOccupancyMapFactory.get_from_geometry(
240 |                     [roadblock.polygon], [roadblock.id]
241 |                 )
242 |                 continue
243 | 
244 |             strtree, index_by_id = roadblock_occupancy_map._build_strtree()
245 |             indices = strtree.query(roadblock.polygon)
246 |             if len(indices) > 0:
247 |                 for geom in strtree.geometries.take(indices):
248 |                     area = geom.intersection(roadblock.polygon).area
249 |                     if area > 1:
250 |                         loop_idx = idx
251 |                         break
252 |                 if loop_idx is not None:
253 |                     break
254 | 
255 |             roadblock_occupancy_map.insert(roadblock.id, roadblock.polygon)
256 | 
257 |     if loop_idx is not None:
258 |         route_roadblocks = route_roadblocks[:loop_idx]
259 |         route_roadblock_ids = route_roadblock_ids[:loop_idx]
260 | 
261 |     return route_roadblocks, route_roadblock_ids
262 | 
--------------------------------------------------------------------------------
/src/utils/collision_checker.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from nuplan.common.actor_state.vehicle_parameters import (
 3 |     VehicleParameters,
 4 |     get_pacifica_parameters,
 5 | )
 6 | 
 7 | 
 8 | class CollisionChecker:
 9 |     def __init__(
10 |         self,
11 |         vehicle: VehicleParameters = get_pacifica_parameters(),
12 |     ) -> None:
13 |         self._vehicle = vehicle
14 |         self._sdc_half_length = vehicle.length / 2
15 |         self._sdc_half_width = vehicle.width / 2
16 | 
17 |         self._sdc_normalized_corners = torch.stack(
18 |             [
19 |                 torch.tensor([vehicle.length / 2, vehicle.width / 2]),
20 |                 torch.tensor([vehicle.length / 2, -vehicle.width / 2]),
21 |                 torch.tensor([-vehicle.length / 2, -vehicle.width / 2]),
22 |                 torch.tensor([-vehicle.length / 2, vehicle.width / 2]),
23 |             ],
24 |             dim=0,
25 |         )
26 | 
27 |     def to_device(self, device):
28 |         self._sdc_normalized_corners = self._sdc_normalized_corners.to(device)
29 | 
30 |     def build_bbox_from_center(self, center, heading, width, length):
31 |         """
32 |         params:
33 |             center: [bs, N, (x, y)]
34 |             heading: [bs, N]
35 |             width: [bs, N]
36 |             length: [bs, N]
37 |         return:
38 |             corners: [bs, N, 4, (x, y)]
39 |             heading_vec, tanh_vec: [bs, N, 2]
40 |         """
41 |         cos = torch.cos(heading)
42 |         sin = torch.sin(heading)
43 | 
44 |         heading_vec = torch.stack([cos, sin], dim=-1) * length.unsqueeze(-1) / 2
45 |         tanh_vec = torch.stack([-sin, cos], dim=-1) * width.unsqueeze(-1) / 2
46 | 
47 |         corners = torch.stack(
48 |             [
49 |                 center + heading_vec + tanh_vec,
50 |                 center - heading_vec + tanh_vec,
51 |                 center - heading_vec - tanh_vec,
52 |                 center + heading_vec - tanh_vec,
53 |             ],
54 |             dim=-2,
55 |         )
56 | 
57 |         return corners, heading_vec, tanh_vec
58 | 
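    # Example (illustrative shapes; all tensors are hypothetical):
    #   checker = CollisionChecker()
    #   ego = torch.zeros(2, 3)           # [bs, (x, y, theta)]
    #   objs = torch.randn(2, 8, 3)       # [bs, N, (x, y, theta)]
    #   width = torch.full((2, 8), 2.0)
    #   length = torch.full((2, 8), 4.5)
    #   collided = checker.collision_check(ego, objs, width, length)  # [bs, N]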
 59 |     def collision_check(self, ego_state, objects, objects_width, objects_length):
 60 |         """Performs batch-wise collision checks using the Separating Axis Theorem.
 61 |         params:
 62 |             ego_state: [bs, (x, y, theta)], center of the ego
 63 |             objects: [bs, N, (x, y, theta)], center of the objects
 64 |             objects_width, objects_length: [bs, N]
 65 |         returns:
 66 |             is_collided: [bs, N]
 67 |         """
 68 | 
 69 |         bs, N = objects.shape[:2]
 70 | 
 71 |         # rotate objects into ego's local frame
 72 |         cos, sin = torch.cos(ego_state[:, 2]), torch.sin(ego_state[:, 2])
 73 |         rotate_mat = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(bs, 2, 2)
 74 | 
 75 |         rotated_objects = objects.clone()
 76 |         rotated_objects[..., :2] = torch.matmul(
 77 |             rotated_objects[..., :2] - ego_state[:, :2].unsqueeze(1), rotate_mat
 78 |         )
 79 |         rotated_objects[..., 2] -= ego_state[..., 2].unsqueeze(1)
 80 | 
 81 |         # [bs, N, 4, 2], [bs, N, 2], [bs, N, 2]
 82 |         object_corners, axis1, axis2 = self.build_bbox_from_center(
 83 |             rotated_objects[..., :2],
 84 |             rotated_objects[..., 2],
 85 |             objects_width,
 86 |             objects_length,
 87 |         )
 88 | 
 89 |         ego_corners = self._sdc_normalized_corners.reshape(1, 1, 4, 2).repeat(
 90 |             bs, N, 1, 1
 91 |         )  # [bs, N, 4, 2]
 92 | 
 93 |         all_corners = torch.concat(
 94 |             [object_corners, ego_corners], dim=-2
 95 |         )  # [bs, N, 8, 2]
 96 | 
 97 |         x_projection = object_corners[..., 0]
 98 |         y_projection = object_corners[..., 1]
 99 |         axis1_projection = torch.matmul(all_corners, axis1.unsqueeze(-1)).squeeze(-1)
100 |         axis2_projection = torch.matmul(all_corners, axis2.unsqueeze(-1)).squeeze(-1)
101 | 
102 |         x_separated = (x_projection.max(-1)[0] < -self._sdc_half_length) | (
103 |             x_projection.min(-1)[0] > self._sdc_half_length
104 |         )
105 |         y_separated = (y_projection.max(-1)[0] < -self._sdc_half_width) | (
106 |             y_projection.min(-1)[0] > self._sdc_half_width
107 |         )
108 |         axis1_separated = (
109 |             axis1_projection[..., :4].max(-1)[0] < axis1_projection[..., 4:].min(-1)[0]
110 |         ) | (
111 |             axis1_projection[..., :4].min(-1)[0] > axis1_projection[..., 4:].max(-1)[0]
112 |         )
113 |         axis2_separated = (
114 |             axis2_projection[..., :4].max(-1)[0] < axis2_projection[..., 4:].min(-1)[0]
115 |         ) | (
116 |             axis2_projection[..., :4].min(-1)[0] > axis2_projection[..., 4:].max(-1)[0]
117 |         )
118 | 
119 |         collision = ~(x_separated | y_separated | axis1_separated | axis2_separated)
120 | 
121 |         return collision
122 | 
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
 1 | import math
 2 | import os
 3 | import pprint
 4 | from pathlib import Path
 5 | import numpy
 6 | import numpy as np
 7 | import pandas as pd
 8 | import torch
 9 | import cv2
10 | 
11 | 
12 | def to_tensor(data):
13 |     if isinstance(data, dict):
14 |         return {k: to_tensor(v) for k, v in data.items()}
15 |     elif isinstance(data, numpy.ndarray):
16 |         if data.dtype == numpy.float64:
17 |             return torch.from_numpy(data).float()
18 |         else:
19 |             return torch.from_numpy(data)
20 |     elif isinstance(data, numpy.number):
21 |         return torch.tensor(data).float()
22 |     elif isinstance(data, list):
23 |         return data
24 |     elif isinstance(data, int):
25 |         return torch.tensor(data)
26 |     elif isinstance(data, tuple):
27 |         return to_tensor(data[0])
28 |     else:
29 |         print(type(data), data)
30 |         raise NotImplementedError
31 | 
32 | 
33 | def to_numpy(data):
34 |     if isinstance(data, dict):
35 |         return {k: to_numpy(v) for k, v in data.items()}
36 |     elif isinstance(data, torch.Tensor):
37 |         if data.requires_grad:
38 |             return data.detach().cpu().numpy()
39 |         else:
40 |             return data.cpu().numpy()
41 |     else:
42 |         print(type(data), data)
43 |         raise NotImplementedError
44 | 
45 | 
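# Illustrative round-trip (the feature dict is a hypothetical example):
#   feats = to_tensor({"agents": np.zeros((4, 2), dtype=np.float64)})
#   arrays = to_numpy(feats)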
46 | def enable_grad(data):
47 |     if isinstance(data, dict):
48 |         return {k: enable_grad(v) for k, v in data.items()}
49 |     elif isinstance(data, torch.Tensor):
50 |         if data.dtype != torch.float32:
51 |             raise NotImplementedError
52 |         data.requires_grad = True
53 |         return data
54 | 
55 | 
56 | def to_device(data, device):
57 |     if isinstance(data, dict):
58 |         return {k: to_device(v, device) for k, v in data.items()}
59 |     elif isinstance(data, torch.Tensor):
60 |         return data.to(device)
61 |     else:
62 |         raise NotImplementedError
63 | 
64 | 
65 | def print_dict_tensor(data, prefix=""):
66 |     for k, v in data.items():
67 |         if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray):
68 |             print(f"{prefix}{k}: {v.shape}")
69 |         elif isinstance(v, dict):
70 |             print(f"{prefix}{k}:")
71 |             print_dict_tensor(v, "  ")
72 | 
73 | 
74 | def print_simulation_results(file=None):
75 |     if file is not None:
76 |         df = pd.read_parquet(file)
77 |     else:
78 |         root = Path(os.getcwd()) / "aggregator_metric"
79 |         result = list(root.glob("*.parquet"))
80 |         result = max(result, key=lambda item: item.stat().st_ctime)
81 |         df = pd.read_parquet(result)
82 |     final_score = df[df["scenario"] == "final_score"]
83 |     final_score = final_score.to_dict(orient="records")[0]
84 |     pprint.PrettyPrinter(indent=4).pprint(final_score)
85 | 
86 | 
87 | def load_checkpoint(checkpoint):
88 |     ckpt = torch.load(checkpoint, map_location=torch.device("cpu"))
89 |     state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
90 |     return state_dict
91 | 
92 | 
93 | def safe_index(ls, value):
94 |     try:
95 |         return ls.index(value)
96 |     except ValueError:
97 |         return None
98 | 
99 | 
100 | def shift_and_rotate_img(img, shift, angle, resolution, cval=-200):
101 |     """
102 |     img: (H, W, C)
103 |     shift: (H_shift, W_shift, 0)
104 |     resolution: float
105 |     angle: float
106 |     """
107 |     rows, cols = img.shape[:2]
108 |     shift = shift / resolution
109 |     translation_matrix = np.float32([[1, 0, shift[1]], [0, 1, shift[0]]])
110 |     translated_img = cv2.warpAffine(
111 |         img, translation_matrix, (cols, rows), borderValue=cval
112 |     )
113 |     M = cv2.getRotationMatrix2D((cols / 2, rows / 2), math.degrees(angle), 1)
114 |     rotated_img = cv2.warpAffine(translated_img, M, (cols, rows), borderValue=cval)
115 |     if len(img.shape) == 3 and len(rotated_img.shape) == 2:
116 |         rotated_img = rotated_img[..., np.newaxis]
117 |     return rotated_img.astype(np.float32)
118 | 
119 | 
120 | def crop_img_from_center(img, crop_size):
121 |     h, w = img.shape[:2]
122 |     h_crop, w_crop = crop_size
123 |     h_start = (h - h_crop) // 2
124 |     w_start = (w - w_crop) // 2
125 |     return img[h_start : h_start + h_crop, w_start : w_start + w_crop].astype(
126 |         np.float32
127 |     )
128 | 
--------------------------------------------------------------------------------