├── config
│   ├── custom_trainer
│   │   └── planTF.yaml
│   ├── training
│   │   └── train_planTF.yaml
│   ├── data_augmentation
│   │   └── state_perturbation.yaml
│   ├── model
│   │   └── planTF.yaml
│   ├── planner
│   │   └── planTF.yaml
│   ├── default_nuboard.yaml
│   ├── scenario_filter
│   │   ├── single_right_turn.yaml
│   │   ├── mini.yaml
│   │   ├── training_scenarios_1M.yaml
│   │   ├── test14-random.yaml
│   │   └── test14-hard.yaml
│   ├── default_training.yaml
│   ├── lightning
│   │   └── custom_lightning.yaml
│   └── default_simulation.yaml
├── src
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── mr.py
│   │   ├── min_fde.py
│   │   └── min_ade.py
│   ├── utils
│   │   ├── conversion.py
│   │   └── collision_checker.py
│   ├── models
│   │   └── planTF
│   │       ├── modules
│   │       │   ├── trajectory_decoder.py
│   │       │   ├── map_encoder.py
│   │       │   └── agent_encoder.py
│   │       ├── layers
│   │       │   ├── common_layers.py
│   │       │   ├── transformer_encoder_layer.py
│   │       │   └── embedding.py
│   │       ├── planning_model.py
│   │       └── lightning_trainer.py
│   ├── planners
│   │   ├── planner_utils.py
│   │   └── imitation_planner.py
│   ├── optim
│   │   └── warmup_cos_lr.py
│   ├── feature_builders
│   │   ├── common
│   │   │   ├── utils.py
│   │   │   ├── bfs_roadblock.py
│   │   │   └── route_utils.py
│   │   └── nuplan_feature_builder.py
│   ├── features
│   │   └── nuplan_feature.py
│   ├── data_augmentation
│   │   └── state_perturbation.py
│   └── custom_training
│       ├── custom_training_builder.py
│       └── custom_datamodule.py
├── requirements.txt
├── script
│   ├── setup_env.sh
│   ├── plantf_single_scenarios.sh
│   ├── plantf_benchmarks.sh
│   ├── raster_model_benchmarks.sh
│   └── urbandriver_benchmarks.sh
├── docs
│   └── other_baselines.md
├── run_nuboard.py
├── .gitignore
├── run_training.py
├── run_simulation.py
└── README.md

/config/custom_trainer/planTF.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.planTF.lightning_trainer.LightningTrainer
2 | 
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .min_ade import minADE
2 | from .min_fde import minFDE
3 | from .mr import MR
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm
2 | pytorch-lightning==2.0.1
3 | torchmetrics==0.10.2
4 | tensorboard
5 | wandb==0.14.2
6 | numba
7 | rich==13.3.4
--------------------------------------------------------------------------------
/script/setup_env.sh:
--------------------------------------------------------------------------------
1 | pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu116
2 | pip3 install natten==0.14.6 -f https://shi-labs.com/natten/wheels/cu116/torch1.12/index.html
3 | pip install -r ./requirements.txt
4 | 
--------------------------------------------------------------------------------
/script/plantf_single_scenarios.sh:
--------------------------------------------------------------------------------
1 | cwd=$(pwd)
2 | CKPT_ROOT="$cwd/checkpoints"
3 | PLANNER="planTF"
4 | 
5 | python run_simulation.py \
6 |     +simulation=closed_loop_nonreactive_agents \
7 |     planner=planTF \
8 |     scenario_builder=nuplan_challenge \
9 |     scenario_filter=single_right_turn \
10 |     worker=sequential \
11 |     verbose=true \
12 |     planner.imitation_planner.planner_ckpt="$CKPT_ROOT/$PLANNER.ckpt"
--------------------------------------------------------------------------------
/config/training/train_planTF.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | job_name: planTF
3 | py_func: train
4 | objective_aggregate_mode: mean
5 | 
6 | defaults:
7 |   - override /data_augmentation:
8 |       - state_perturbation
9 |   - override /splitter: nuplan
10 |   - override /model: planTF
11 |   - override /scenario_filter: training_scenarios_1M
12 |   - override /custom_trainer: planTF
13 |   - override /lightning: custom_lightning
14 | 
--------------------------------------------------------------------------------
/config/data_augmentation/state_perturbation.yaml:
--------------------------------------------------------------------------------
1 | perturbation_nuplan:
2 |   _target_: src.data_augmentation.state_perturbation.StatePerturbation
3 |   _convert_: "all"
4 | 
5 |   dt: 0.1 # the time interval between trajectory points
6 |   hist_len: 21
7 |   low: [-1.0, -0.75, -0.35, -1, -0.5, -0.2, -0.1]
8 |   high: [1.0, 0.75, 0.35, 1, 0.5, 0.2, 0.1]
9 |   augment_prob: 0.5 # probability of applying data augmentation for training
10 |   normalize: True
11 | 
--------------------------------------------------------------------------------
/config/model/planTF.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.planTF.planning_model.PlanningModel
2 | _convert_: "all"
3 | 
4 | dim: 128
5 | state_channel: 6
6 | polygon_channel: 6
7 | history_channel: 9
8 | history_steps: 21
9 | future_steps: 80
10 | encoder_depth: 4
11 | drop_path: 0.2
12 | num_heads: 8
13 | num_modes: 6
14 | state_dropout: 0.75
15 | use_ego_history: false
16 | state_attn_encoder: true
17 | 
18 | feature_builder:
19 |   _target_: src.feature_builders.nuplan_feature_builder.NuplanFeatureBuilder
20 |   _convert_: "all"
21 |   radius: 100
22 |   history_horizon: 2
23 |   future_horizon: 8
24 |   sample_interval: 0.1
25 |   max_agents: 32
26 | 
--------------------------------------------------------------------------------
/script/plantf_benchmarks.sh:
--------------------------------------------------------------------------------
1 | cwd=$(pwd)
2 | CKPT_ROOT="$cwd/checkpoints"
3 | 
4 | PLANNER="planTF"
5 | SPLIT=$1
6 | CHALLENGES="closed_loop_nonreactive_agents closed_loop_reactive_agents open_loop_boxes"
7 | 
8 | for challenge in $CHALLENGES; do
9 |     python run_simulation.py \
10 |         +simulation=$challenge \
11 |         planner=$PLANNER \
12 |         scenario_builder=nuplan_challenge \
13 |         scenario_filter=$SPLIT \
14 |         worker.threads_per_node=20 \
15 |         experiment_uid=$SPLIT/$PLANNER \
16 |         verbose=true \
17 |         planner.imitation_planner.planner_ckpt="$CKPT_ROOT/$PLANNER.ckpt"
18 | done
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/script/raster_model_benchmarks.sh:
--------------------------------------------------------------------------------
1 | cwd=$(pwd)
2 | CKPT_ROOT="$cwd/checkpoints"
3 | 
4 | PLANNER="raster_model"
5 | SPLIT=$1
6 | CHALLENGES="closed_loop_nonreactive_agents closed_loop_reactive_agents open_loop_boxes"
7 | 
8 | for challenge in $CHALLENGES; do
9 |     python run_simulation.py \
10 |         +simulation=$challenge \
11 |         model=raster_model \
12 |         planner=ml_planner \
13 |         'planner.ml_planner.model_config=${model}' \
14 |         scenario_builder=nuplan_challenge \
15 |         scenario_filter=$SPLIT \
16 |         worker.threads_per_node=20 \
17 |         experiment_uid=$SPLIT/$PLANNER \
18 |         verbose=true \
19 |         planner.ml_planner.checkpoint_path="$CKPT_ROOT/$PLANNER.ckpt"
20 | done
21 | 
22 | 
23 | 
--------------------------------------------------------------------------------
/script/urbandriver_benchmarks.sh:
--------------------------------------------------------------------------------
1 | cwd=$(pwd)
2 | CKPT_ROOT="$cwd/checkpoints"
3 | 
4 | PLANNER="urban_driver_open_loop"
5 | SPLIT=$1
6 | CHALLENGES="closed_loop_nonreactive_agents closed_loop_reactive_agents open_loop_boxes"
7 | 
8 | for challenge in $CHALLENGES; do
9 |     python run_simulation.py \
10 |         +simulation=$challenge \
11 |         model=urban_driver_open_loop_model \
12 |         planner=ml_planner \
13 |         'planner.ml_planner.model_config=${model}' \
14 |         scenario_builder=nuplan_challenge \
15 |         scenario_filter=$SPLIT \
16 |         worker.threads_per_node=20 \
17 |         experiment_uid=$SPLIT/$PLANNER \
18 |         verbose=true \
19 |         planner.ml_planner.checkpoint_path="$CKPT_ROOT/$PLANNER.ckpt"
20 | done
--------------------------------------------------------------------------------
/src/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def sort_predictions(predictions, probability, k=6):
5 |     """Sort the predictions based on the probability of each mode.
6 |     Args:
7 |         predictions (torch.Tensor): The predicted trajectories [b, k, t, 2].
8 |         probability (torch.Tensor): The probability of each mode [b, k].
9 |     Returns:
10 |         torch.Tensor: The sorted top-k predictions [b, k', t, 2] and their probabilities [b, k'].
11 |     """
12 |     indices = torch.argsort(probability, dim=-1, descending=True)
13 |     sorted_prob = probability[torch.arange(probability.size(0))[:, None], indices]
14 |     sorted_predictions = predictions[
15 |         torch.arange(predictions.size(0))[:, None], indices
16 |     ]
17 |     return sorted_predictions[:, :k], sorted_prob[:, :k]
18 | 
--------------------------------------------------------------------------------
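A quick usage sketch for `sort_predictions` (illustrative, not a file in the repo); the tensors are random stand-ins shaped like the model's multimodal output:

```python
# Sort multimodal trajectories by their mode probabilities, keep the top 3.
import torch

from src.metrics.utils import sort_predictions

batch, modes, steps = 2, 6, 80
predictions = torch.randn(batch, modes, steps, 2)  # [b, k, t, 2]
probability = torch.rand(batch, modes)             # [b, k]

top_pred, top_prob = sort_predictions(predictions, probability, k=3)
print(top_pred.shape)  # torch.Size([2, 3, 80, 2])
print(top_prob.shape)  # torch.Size([2, 3]), descending per sample
```
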
/config/planner/planTF.yaml:
--------------------------------------------------------------------------------
1 | imitation_planner:
2 |   _target_: src.planners.imitation_planner.ImitationPlanner
3 |   _convert_: "all"
4 | 
5 |   replan_interval: 1
6 | 
7 |   planner:
8 |     _target_: src.models.planTF.planning_model.PlanningModel
9 |     _convert_: "all"
10 | 
11 |     dim: 128
12 |     state_channel: 6
13 |     polygon_channel: 6
14 |     history_channel: 9
15 |     history_steps: 21
16 |     future_steps: 80
17 |     encoder_depth: 4
18 |     drop_path: 0.2
19 |     num_heads: 8
20 |     num_modes: 6
21 |     state_dropout: 0.75
22 |     use_ego_history: false
23 |     state_attn_encoder: true
24 | 
25 |     feature_builder:
26 |       _target_: src.feature_builders.nuplan_feature_builder.NuplanFeatureBuilder
27 |       _convert_: "all"
28 |       radius: 100
29 |       history_horizon: 2
30 |       future_horizon: 8
31 |       sample_interval: 0.1
32 |       max_agents: 32
33 | 
34 |   planner_ckpt:
35 | 
--------------------------------------------------------------------------------
/config/default_nuboard.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 |   run:
3 |     dir: .
4 |   output_subdir: null # Store hydra's config breakdown here for debugging
5 |   searchpath: # Only configs in these paths are discoverable
6 |     - pkg://nuplan.planning.script.config.common
7 |     - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/
8 | 
9 | 
10 | defaults:
11 |   - default_common
12 |   - simulation_metric:
13 |       - default_metrics
14 |   - override hydra/job_logging: none # Disable hydra's logging
15 |   - override hydra/hydra_logging: none # Disable hydra's logging
16 | 
17 | log_config: False # Whether to log the final config after all overrides and interpolations
18 | port_number: 5006
19 | simulation_path: null
20 | resource_prefix: null
21 | profiler_path: null
22 | 
--------------------------------------------------------------------------------
/src/utils/conversion.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | 
4 | 
5 | def to_tensor(data):
6 |     if isinstance(data, dict):
7 |         return {k: to_tensor(v) for k, v in data.items()}
8 |     elif isinstance(data, numpy.ndarray):
9 |         if data.dtype == bool:
10 |             return torch.from_numpy(data).bool()
11 |         else:
12 |             return torch.from_numpy(data).float()
13 |     elif isinstance(data, numpy.number):
14 |         return torch.tensor(data).float()
15 |     else:
16 |         raise NotImplementedError(f"Unsupported type for to_tensor: {type(data)}")
17 | 
18 | 
19 | def to_numpy(data):
20 |     if isinstance(data, dict):
21 |         return {k: to_numpy(v) for k, v in data.items()}
22 |     elif isinstance(data, torch.Tensor):
23 |         if data.requires_grad:
24 |             return data.detach().cpu().numpy()
25 |         else:
26 |             return data.cpu().numpy()
27 |     else:
28 |         raise NotImplementedError(f"Unsupported type for to_numpy: {type(data)}")
29 | 
30 | 
31 | def to_device(data, device):
32 |     if isinstance(data, dict):
33 |         return {k: to_device(v, device) for k, v in data.items()}
34 |     elif isinstance(data, torch.Tensor):
35 |         return data.to(device)
36 |     else:
37 |         raise NotImplementedError(f"Unsupported type for to_device: {type(data)}")
38 | 
--------------------------------------------------------------------------------
/src/models/planTF/modules/trajectory_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class TrajectoryDecoder(nn.Module):
7 |     def __init__(self, embed_dim, num_modes, future_steps, out_channels) -> None:
8 |         super().__init__()
9 | 
10 |         self.embed_dim = embed_dim
11 |         self.num_modes = num_modes
12 |         self.future_steps = future_steps
13 |         self.out_channels = out_channels
14 | 
15 |         self.multimodal_proj = nn.Linear(embed_dim, num_modes * embed_dim)
16 | 
17 |         hidden = 2 * embed_dim
18 |         self.loc = nn.Sequential(
19 |             nn.Linear(embed_dim, hidden),
20 |             nn.LayerNorm(hidden),
21 |             nn.ReLU(inplace=True),
22 |             nn.Linear(hidden, future_steps * out_channels),
23 |         )
24 |         self.pi = nn.Sequential(
25 |             nn.Linear(embed_dim, hidden),
26 |             nn.LayerNorm(hidden),
27 |             nn.ReLU(inplace=True),
28 |             nn.Linear(hidden, 1),
29 |         )
30 | 
31 |     def forward(self, x):
32 |         x = self.multimodal_proj(x).view(-1, self.num_modes, self.embed_dim)
33 |         loc = self.loc(x).view(-1, self.num_modes, self.future_steps, self.out_channels)
34 |         pi = self.pi(x).squeeze(-1)
35 | 
36 |         return loc, pi
37 | 
--------------------------------------------------------------------------------
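An illustrative shape trace for `TrajectoryDecoder` (not part of the repo). The dims follow config/model/planTF.yaml (dim=128, num_modes=6, future_steps=80); `out_channels=4` is an assumption chosen only to show the layout — the real value is set by `PlanningModel`:

```python
import torch

from src.models.planTF.modules.trajectory_decoder import TrajectoryDecoder

decoder = TrajectoryDecoder(embed_dim=128, num_modes=6, future_steps=80, out_channels=4)
ego_feature = torch.randn(8, 128)  # one encoder embedding per sample [b, dim]

loc, pi = decoder(ego_feature)
print(loc.shape)  # torch.Size([8, 6, 80, 4]) -> [b, modes, steps, out_channels]
print(pi.shape)   # torch.Size([8, 6])        -> unnormalized per-mode scores
```
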
/src/models/planTF/layers/common_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | def build_mlp(c_in, channels, norm=None, activation="relu"):
5 |     layers = []
6 |     num_layers = len(channels)
7 | 
8 |     if norm is not None:
9 |         norm = get_norm(norm)
10 | 
11 |     activation = get_activation(activation)
12 | 
13 |     for k in range(num_layers):
14 |         if k == num_layers - 1:
15 |             layers.append(nn.Linear(c_in, channels[k], bias=True))
16 |         else:
17 |             if norm is None:
18 |                 layers.extend([nn.Linear(c_in, channels[k], bias=True), activation()])
19 |             else:
20 |                 layers.extend(
21 |                     [
22 |                         nn.Linear(c_in, channels[k], bias=False),
23 |                         norm(channels[k]),
24 |                         activation(),
25 |                     ]
26 |                 )
27 |         c_in = channels[k]
28 | 
29 |     return nn.Sequential(*layers)
30 | 
31 | 
32 | def get_norm(norm: str):
33 |     if norm == "bn":
34 |         return nn.BatchNorm1d
35 |     elif norm == "ln":
36 |         return nn.LayerNorm
37 |     else:
38 |         raise NotImplementedError
39 | 
40 | 
41 | def get_activation(activation: str):
42 |     if activation == "relu":
43 |         return nn.ReLU
44 |     elif activation == "gelu":
45 |         return nn.GELU
46 |     else:
47 |         raise NotImplementedError
48 | 
--------------------------------------------------------------------------------
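A minimal sketch (not in the repo) of how `build_mlp` composes layers: each entry in `channels` becomes Linear(+Norm)+Activation, except the last, which stays linear:

```python
import torch

from src.models.planTF.layers.common_layers import build_mlp

mlp = build_mlp(c_in=6, channels=[64, 64, 128], norm="ln", activation="relu")
x = torch.randn(32, 6)
print(mlp(x).shape)  # torch.Size([32, 128])
# Resulting stack: Linear(6,64,bias=False)+LN+ReLU, Linear(64,64,bias=False)+LN+ReLU, Linear(64,128)
```
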
/config/scenario_filter/single_right_turn.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: "all"
3 | 
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens:
6 |   - 8de10fd86b825304
7 | 
8 | log_names: # Filter scenarios by log names
9 |   - 2021.05.25.12.30.39_veh-25_00321_01196
10 | map_names: null # Filter scenarios by map names
11 | 
12 | num_scenarios_per_type: null # Number of scenarios per type
13 | limit_total_scenarios: null # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
14 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
15 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
16 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
17 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
18 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
19 | 
20 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios
21 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
22 | shuffle: false # Whether to shuffle the scenarios
23 | 
--------------------------------------------------------------------------------
/config/scenario_filter/mini.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: 'all'
3 | 
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens: null # List of scenario tokens to include
6 | 
7 | log_names: null # Filter scenarios by log names
8 | map_names: null # Filter scenarios by map names
9 | 
10 | num_scenarios_per_type: null # Number of scenarios per type
11 | limit_total_scenarios: 100 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
17 | 
18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios
19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
20 | shuffle: true # Whether to shuffle the scenarios
21 | 
--------------------------------------------------------------------------------
/src/metrics/mr.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 | 
3 | import torch
4 | from torchmetrics import Metric
5 | 
6 | 
7 | class MR(Metric):
8 |     full_state_update: Optional[bool] = False
9 |     higher_is_better: Optional[bool] = False
10 | 
11 |     def __init__(
12 |         self,
13 |         miss_threshold: float = 2.0,
14 |         compute_on_step: bool = True,
15 |         dist_sync_on_step: bool = False,
16 |         process_group: Optional[Any] = None,
17 |         dist_sync_fn: Callable = None,
18 |     ) -> None:
19 |         super(MR, self).__init__(
20 |             compute_on_step=compute_on_step,
21 |             dist_sync_on_step=dist_sync_on_step,
22 |             process_group=process_group,
23 |             dist_sync_fn=dist_sync_fn,
24 |         )
25 |         self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
26 |         self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
27 |         self.miss_threshold = miss_threshold
28 | 
29 |     def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
30 |         with torch.no_grad():
31 |             pred = outputs["trajectory"]
32 |             missed_pred = (
33 |                 torch.norm(
34 |                     pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1
35 |                 )
36 |                 > self.miss_threshold
37 |             )
38 |             self.sum += missed_pred.all(-1).sum()
39 |             self.count += pred.shape[0]
40 | 
41 |     def compute(self) -> torch.Tensor:
42 |         return self.sum / self.count
43 | 
--------------------------------------------------------------------------------
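A minimal sketch (not in the repo) of feeding the MR (miss-rate) metric: a sample counts as missed only when every mode ends more than `miss_threshold` meters from the ground-truth endpoint:

```python
import torch

from src.metrics import MR

metric = MR(miss_threshold=2.0)
outputs = {"trajectory": torch.randn(4, 6, 80, 2)}  # [b, modes, t, 2]
target = torch.randn(4, 80, 2)                      # [b, t, 2]

metric.update(outputs, target)
print(metric.compute())  # fraction of samples where all modes missed
```
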
/config/scenario_filter/training_scenarios_1M.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: 'all'
3 | 
4 | scenario_types: null # List of scenario types to include
5 | scenario_tokens: null # List of scenario tokens to include
6 | 
7 | log_names: null # Filter scenarios by log names
8 | map_names: null # Filter scenarios by map names
9 | 
10 | num_scenarios_per_type: null # Number of scenarios per type
11 | limit_total_scenarios: 1000000 # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
12 | timestamp_threshold_s: null # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
13 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
14 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
15 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
16 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
17 | 
18 | expand_scenarios: true # Whether to expand multi-sample scenarios to multiple single-sample scenarios
19 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
20 | shuffle: true # Whether to shuffle the scenarios
21 | 
--------------------------------------------------------------------------------
/src/metrics/min_fde.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 | 
3 | import torch
4 | from torchmetrics import Metric
5 | 
6 | from .utils import sort_predictions
7 | 
8 | 
9 | class minFDE(Metric):
10 |     full_state_update: Optional[bool] = False
11 |     higher_is_better: Optional[bool] = False
12 | 
13 |     def __init__(
14 |         self,
15 |         k=6,
16 |         compute_on_step: bool = True,
17 |         dist_sync_on_step: bool = False,
18 |         process_group: Optional[Any] = None,
19 |         dist_sync_fn: Callable = None,
20 |     ) -> None:
21 |         super(minFDE, self).__init__(
22 |             compute_on_step=compute_on_step,
23 |             dist_sync_on_step=dist_sync_on_step,
24 |             process_group=process_group,
25 |             dist_sync_fn=dist_sync_fn,
26 |         )
27 |         self.k = k
28 |         self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
29 |         self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
30 | 
31 |     def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
32 |         with torch.no_grad():
33 |             pred, _ = sort_predictions(
34 |                 outputs["trajectory"], outputs["probability"], k=self.k
35 |             )
36 |             fde = torch.norm(
37 |                 pred[..., -1, :2] - target.unsqueeze(1)[..., -1, :2], p=2, dim=-1
38 |             )
39 |             min_fde = fde.min(-1)[0]
40 |             self.sum += min_fde.sum()
41 |             self.count += pred.shape[0]
42 | 
43 |     def compute(self) -> torch.Tensor:
44 |         return self.sum / self.count
45 | 
--------------------------------------------------------------------------------
/src/planners/planner_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Deque
2 | 
3 | import numpy as np
4 | import numpy.typing as npt
5 | import torch
6 | from nuplan.common.actor_state.ego_state import EgoState
7 | from nuplan.common.actor_state.state_representation import StateSE2
8 | from nuplan.planning.simulation.planner.ml_planner.transform_utils import (
9 |     _get_fixed_timesteps,
10 |     _get_velocity_and_acceleration,
11 |     _se2_vel_acc_to_ego_state,
12 | )
13 | 
14 | 
15 | def global_trajectory_to_states(
16 |     global_trajectory: npt.NDArray[np.float32],
17 |     ego_history: Deque[EgoState],
18 |     future_horizon: float,
19 |     step_interval: float,
20 |     include_ego_state: bool = True,
21 | ):
22 |     ego_state = ego_history[-1]
23 |     timesteps = _get_fixed_timesteps(ego_state, future_horizon, step_interval)
24 |     global_states = [StateSE2.deserialize(pose) for pose in global_trajectory]
25 | 
26 |     velocities, accelerations = _get_velocity_and_acceleration(
27 |         global_states, ego_history, timesteps
28 |     )
29 |     agent_states = [
30 |         _se2_vel_acc_to_ego_state(
31 |             state,
32 |             velocity,
33 |             acceleration,
34 |             timestep,
35 |             ego_state.car_footprint.vehicle_parameters,
36 |         )
37 |         for state, velocity, acceleration, timestep in zip(
38 |             global_states, velocities, accelerations, timesteps
39 |         )
40 |     ]
41 | 
42 |     if include_ego_state:
43 |         agent_states.insert(0, ego_state)
44 | 
45 |     return agent_states
46 | 
47 | 
48 | def load_checkpoint(checkpoint: str):
49 |     ckpt = torch.load(checkpoint, map_location=torch.device("cpu"))
50 |     state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
51 |     return state_dict
52 | 
--------------------------------------------------------------------------------
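A minimal sketch (not in the repo) of restoring weights with `load_checkpoint`, which strips the `model.` prefix that the Lightning wrapper adds to parameter names. The `PlanningModel` constructor arguments are assumed to default to the values in config/model/planTF.yaml, and the checkpoint path is hypothetical:

```python
from src.models.planTF.planning_model import PlanningModel
from src.planners.planner_utils import load_checkpoint

model = PlanningModel(dim=128)  # assumption: remaining args keep their defaults
state_dict = load_checkpoint("checkpoints/planTF.ckpt")  # hypothetical path
model.load_state_dict(state_dict)
model.eval()
```
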
/src/metrics/min_ade.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, Optional
2 | 
3 | import torch
4 | from torchmetrics import Metric
5 | 
6 | from .utils import sort_predictions
7 | 
8 | 
9 | class minADE(Metric):
10 |     """Minimum Average Displacement Error
11 |     minADE: The average L2 distance between the best forecasted trajectory and the ground truth.
12 |     The best here refers to the trajectory with the minimum average displacement error.
13 |     """
14 | 
15 |     full_state_update: Optional[bool] = False
16 |     higher_is_better: Optional[bool] = False
17 | 
18 |     def __init__(
19 |         self,
20 |         k=6,
21 |         compute_on_step: bool = True,
22 |         dist_sync_on_step: bool = False,
23 |         process_group: Optional[Any] = None,
24 |         dist_sync_fn: Callable = None,
25 |     ) -> None:
26 |         super(minADE, self).__init__(
27 |             compute_on_step=compute_on_step,
28 |             dist_sync_on_step=dist_sync_on_step,
29 |             process_group=process_group,
30 |             dist_sync_fn=dist_sync_fn,
31 |         )
32 |         self.k = k
33 |         self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
34 |         self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
35 | 
36 |     def update(self, outputs: Dict[str, torch.Tensor], target: torch.Tensor) -> None:
37 |         with torch.no_grad():
38 |             pred, _ = sort_predictions(
39 |                 outputs["trajectory"], outputs["probability"], k=self.k
40 |             )
41 |             ade = torch.norm(
42 |                 pred[..., :2] - target.unsqueeze(1)[..., :2], p=2, dim=-1
43 |             ).mean(-1)
44 |             min_ade = ade.min(-1)[0]
45 |             self.sum += min_ade.sum()
46 |             self.count += pred.size(0)
47 | 
48 |     def compute(self) -> torch.Tensor:
49 |         return self.sum / self.count
50 | 
--------------------------------------------------------------------------------
/config/scenario_filter/test14-random.yaml:
--------------------------------------------------------------------------------
1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter
2 | _convert_: "all"
3 | 
4 | scenario_types: # List of scenario types to include
5 |   - starting_left_turn
6 |   - starting_right_turn
7 |   - starting_straight_traffic_light_intersection_traversal
8 |   - stopping_with_lead
9 |   - high_lateral_acceleration
10 |   - high_magnitude_speed
11 |   - low_magnitude_speed
12 |   - traversing_pickup_dropoff
13 |   - waiting_for_pedestrian_to_cross
14 |   - behind_long_vehicle
15 |   - stationary_in_traffic
16 |   - near_multiple_vehicles
17 |   - changing_lane
18 |   - following_lane_with_lead
19 | scenario_tokens: null # List of scenario tokens to include
20 | 
21 | log_names: null # Filter scenarios by log names
22 | map_names: null # Filter scenarios by map names
23 | 
24 | num_scenarios_per_type: 20 # Number of scenarios per type
25 | limit_total_scenarios: null # Limit total scenarios (float = fraction, int = num) - this filter can be applied on top of num_scenarios_per_type
26 | timestamp_threshold_s: 15 # Filter scenarios to ensure scenarios have more than `timestamp_threshold_s` seconds between their initial lidar timestamps
27 | ego_displacement_minimum_m: null # Whether to remove scenarios where the ego moves less than a certain amount
28 | ego_start_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from below
29 | ego_stop_speed_threshold: null # Limit to scenarios where the ego reaches a certain speed from above
30 | speed_noise_tolerance: null # Value at or below which a speed change between two timepoints should be ignored as noise.
31 | 
32 | expand_scenarios: false # Whether to expand multi-sample scenarios to multiple single-sample scenarios
33 | remove_invalid_goals: true # Whether to remove scenarios where the mission goal is invalid
34 | shuffle: false # Whether to shuffle the scenarios
35 | 
--------------------------------------------------------------------------------
/src/optim/warmup_cos_lr.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | 
5 | 
6 | class WarmupCosLR(_LRScheduler):
7 |     def __init__(
8 |         self, optimizer, min_lr, lr, warmup_epochs, epochs, last_epoch=-1, verbose=False
9 |     ) -> None:
10 |         self.min_lr = min_lr
11 |         self.lr = lr
12 |         self.epochs = epochs
13 |         self.warmup_epochs = warmup_epochs
14 |         super(WarmupCosLR, self).__init__(optimizer, last_epoch, verbose)
15 | 
16 |     def state_dict(self):
17 |         """Returns the state of the scheduler as a :class:`dict`.
18 | 
19 |         It contains an entry for every variable in self.__dict__ which
20 |         is not the optimizer.
21 |         """
22 |         return {
23 |             key: value for key, value in self.__dict__.items() if key != "optimizer"
24 |         }
25 | 
26 |     def load_state_dict(self, state_dict):
27 |         """Loads the schedulers state.
28 | 
29 |         Args:
30 |             state_dict (dict): scheduler state. Should be an object returned
31 |                 from a call to :meth:`state_dict`.
32 |         """
33 |         self.__dict__.update(state_dict)
34 | 
35 |     def get_init_lr(self):
36 |         lr = self.lr / self.warmup_epochs
37 |         return lr
38 | 
39 |     def get_lr(self):
40 |         if self.last_epoch < self.warmup_epochs:
41 |             lr = self.lr * (self.last_epoch + 1) / self.warmup_epochs
42 |         else:
43 |             lr = self.min_lr + 0.5 * (self.lr - self.min_lr) * (
44 |                 1
45 |                 + math.cos(
46 |                     math.pi
47 |                     * (self.last_epoch - self.warmup_epochs)
48 |                     / (self.epochs - self.warmup_epochs)
49 |                 )
50 |             )
51 |         if "lr_scale" in self.optimizer.param_groups[0]:
52 |             return [lr * group["lr_scale"] for group in self.optimizer.param_groups]
53 | 
54 |         return [lr for _ in self.optimizer.param_groups]
55 | 
--------------------------------------------------------------------------------
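A minimal sketch (not in the repo) of driving `WarmupCosLR`: the LR ramps linearly for `warmup_epochs`, then follows a cosine decay from `lr` down to `min_lr`; it is stepped once per epoch:

```python
import torch

from src.optim.warmup_cos_lr import WarmupCosLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = WarmupCosLR(optimizer, min_lr=1e-6, lr=1e-3, warmup_epochs=3, epochs=25)

for epoch in range(25):
    # ... run one training epoch, calling optimizer.step() per batch ...
    optimizer.step()
    scheduler.step()  # epoch-based schedule, matching epochs/warmup_epochs in default_training.yaml
```
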
/src/feature_builders/common/utils.py:
--------------------------------------------------------------------------------
1 | import numba
2 | import numpy as np
3 | 
4 | 
5 | def normalize_angle(angle: np.ndarray):
6 |     return (angle + np.pi) % (2 * np.pi) - np.pi
7 | 
8 | 
9 | @numba.njit
10 | def rotate_round_z_axis(points: np.ndarray, angle: float):
11 |     rotate_mat = np.array(
12 |         [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]
13 |     )
14 |     return points @ rotate_mat
15 | 
16 | 
17 | def interpolate_polyline(points: np.ndarray, t: int) -> np.ndarray:
18 |     """copy from av2-api"""
19 | 
20 |     if points.ndim != 2:
21 |         raise ValueError("Input array must be (N,2) or (N,3) in shape.")
22 | 
23 |     # the number of points on the curve itself
24 |     n, _ = points.shape
25 | 
26 |     # equally spaced in arclength -- the number of points that will be uniformly interpolated
27 |     eq_spaced_points = np.linspace(0, 1, t)
28 | 
29 |     # Compute the chordal arclength of each segment.
30 |     # Compute differences between each x coord, to get the dx's
31 |     # Do the same to get dy's. Then the hypotenuse length is computed as a norm.
32 |     chordlen: np.ndarray = np.linalg.norm(np.diff(points, axis=0), axis=1)  # type: ignore
33 |     # Normalize the arclengths to a unit total
34 |     chordlen = chordlen / np.sum(chordlen)
35 |     # cumulative arclength
36 | 
37 |     cumarc: np.ndarray = np.zeros(len(chordlen) + 1)
38 |     cumarc[1:] = np.cumsum(chordlen)
39 | 
40 |     # which interval did each point fall in, in terms of eq_spaced_points? (bin index)
41 |     tbins: np.ndarray = np.digitize(eq_spaced_points, bins=cumarc).astype(int)  # type: ignore
42 | 
43 |     # catch any problems at the ends
44 |     tbins[np.where((tbins <= 0) | (eq_spaced_points <= 0))] = 1  # type: ignore
45 |     tbins[np.where((tbins >= n) | (eq_spaced_points >= 1))] = n - 1
46 | 
47 |     chordlen[tbins - 1] = np.where(
48 |         chordlen[tbins - 1] == 0, chordlen[tbins - 1] + 1e-6, chordlen[tbins - 1]
49 |     )
50 | 
51 |     s = np.divide((eq_spaced_points - cumarc[tbins - 1]), chordlen[tbins - 1])
52 |     anchors = points[tbins - 1, :]
53 |     # broadcast to scale each row of `points` by a different row of s
54 |     offsets = (points[tbins, :] - points[tbins - 1, :]) * s.reshape(-1, 1)
55 |     points_interp: np.ndarray = anchors + offsets
56 | 
57 |     return points_interp
58 | 
--------------------------------------------------------------------------------
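A minimal sketch (not in the repo) of resampling a polyline to a fixed number of points, uniformly spaced in arclength, with `interpolate_polyline`:

```python
import numpy as np

from src.feature_builders.common.utils import interpolate_polyline

polyline = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 1.0]])
resampled = interpolate_polyline(polyline, t=5)
print(resampled.shape)           # (5, 2)
print(resampled[0], resampled[-1])  # endpoints of the original polyline are preserved
```
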
/config/default_training.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 |   run:
3 |     dir: ${output_dir}
4 |   output_subdir: ${output_dir}/code/hydra # Store hydra's config breakdown here for debugging
5 |   searchpath: # Only configs in these paths are discoverable
6 |     - pkg://nuplan.planning.script.config.common
7 |     - pkg://nuplan.planning.script.config.training
8 |     - pkg://nuplan.planning.script.experiments # Put experiments configs in script/experiments/
9 |     - config/training
10 | 
11 | defaults:
12 |   - default_experiment
13 |   - default_common
14 | 
15 |   # Trainer and callbacks
16 |   - lightning: custom_lightning
17 |   - callbacks: default_callbacks
18 | 
19 |   # Optimizer settings
20 |   - optimizer: adam # [adam, adamw] supported optimizers
21 |   - lr_scheduler: null # [one_cycle_lr] supported lr_schedulers
22 |   - warm_up_lr_scheduler: null # [linear_warm_up, constant_warm_up] supported warm up lr schedulers
23 | 
24 |   # Data Loading
25 |   - data_loader: default_data_loader
26 |   - splitter: ???
27 | 
28 |   # Objectives and metrics
29 |   - objective:
30 |   - training_metric:
31 |   - data_augmentation: null
32 |   - data_augmentation_scheduler: null # [default_augmentation_schedulers, stepwise_augmentation_probability_scheduler, stepwise_noise_parameter_scheduler] supported data augmentation schedulers
33 |   - scenario_type_weights: default_scenario_type_weights
34 |   - custom_trainer: null
35 | 
36 | nuplan_trainer: false
37 | experiment_name: "training"
38 | objective_aggregate_mode: ??? # How to aggregate multiple objectives, can be 'mean', 'max', 'sum'
39 | 
40 | # Cache parameters
41 | cache:
42 |   cache_path: # Local/remote path to store all preprocessed artifacts from the data pipeline
43 |   use_cache_without_dataset: false # Load all existing features from a local/remote cache without loading the dataset
44 |   force_feature_computation: false # Recompute features even if a cache exists
45 |   cleanup_cache: false # Cleanup cached data in the cache_path, this ensures that new data are generated if the same cache_path is passed
46 | 
47 | # Mandatory parameters
48 | py_func: ??? # Function to be run inside main (can be "train", "test", "cache")
49 | epochs: 25
50 | warmup_epochs: 3
51 | lr: 1e-3
52 | weight_decay: 0.0001
53 | checkpoint:
54 | 
55 | # wandb settings
56 | wandb:
57 |   mode: disable
58 |   project: nuplan
59 |   name: ${experiment_name}
60 |   log_model: all
61 |   artifact:
62 |   run_id:
63 | 
--------------------------------------------------------------------------------
/config/lightning/custom_lightning.yaml:
--------------------------------------------------------------------------------
1 | distributed_training:
2 |   equal_variance_scaling_strategy: true # scales lr and betas either linearly if false (multiply by num GPUs) or with equal_variance if true (multiply by square root of num GPUs)
3 | 
4 | trainer:
5 |   checkpoint:
6 |     resume_training: false # load the model from the last epoch and resume training
7 |     save_top_k: 5 # save the top K models in terms of performance
8 |     monitor: loss/val_loss # metric to monitor for performance
9 |     mode: min # minimize/maximize metric
10 | 
11 |   params:
12 |     # max_time: 00:16:00:00 # training time before the process is terminated
13 | 
14 |     max_epochs: ${epochs} # maximum number of training epochs
15 |     # check_val_every_n_epoch: 1 # run validation set every n training epochs
16 |     val_check_interval: 1.0 # [%] run validation set every X% of training set
17 | 
18 |     limit_train_batches: # how much of training dataset to check (float = fraction, int = num_batches)
19 |     limit_val_batches: # how much of validation dataset to check (float = fraction, int = num_batches)
20 |     limit_test_batches: # how much of test dataset to check (float = fraction, int = num_batches)
21 | 
22 |     devices: -1 # number of GPUs to utilize (-1 means all available GPUs)
23 |     accelerator: gpu # distribution method
24 |     precision: 32 # floating point precision
25 |     # amp_level: O2 # AMP optimization level
26 |     # num_nodes: 1 # Number of nodes used for training
27 | 
28 |     # auto_scale_batch_size: false
29 |     # auto_lr_find: false # tunes LR before beginning training
30 |     # terminate_on_nan: true # terminates training if a nan is encountered in loss/weights
31 | 
32 |     num_sanity_val_steps: 1 # number of validation steps to run before training begins
33 |     fast_dev_run: false # runs 1 batch of train/val/test for sanity
34 | 
35 |     # accumulate_grad_batches: 1 # accumulates gradients every n batches
36 |     # track_grad_norm: -1 # logs the p-norm for inspection
37 |     gradient_clip_val: 5.0 # value to clip gradients
38 |     gradient_clip_algorithm: norm # [value, norm] method to clip gradients
39 |     sync_batchnorm: true
40 |     strategy: ddp_find_unused_parameters_false
41 | 
42 |     # checkpoint_callback: true # enable default checkpoint
43 | 
44 | overfitting:
45 |   enable: false # run an overfitting test instead of training
46 | 
47 |   params:
48 |     max_epochs: 150 # number of epochs to overfit the same batches
49 |     overfit_batches: 1 # number of batches to overfit
50 | 
--------------------------------------------------------------------------------
/src/models/planTF/layers/transformer_encoder_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from timm.models.layers import DropPath
6 | from torch import Tensor
7 | 
8 | 
9 | class Mlp(nn.Module):
10 |     """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
11 | 
12 |     def __init__(
13 |         self,
14 |         in_features,
15 |         hidden_features=None,
16 |         out_features=None,
17 |         act_layer=nn.GELU,
18 |         drop=0.0,
19 |     ):
20 |         super().__init__()
21 |         out_features = out_features or in_features
22 |         hidden_features = hidden_features or in_features
23 | 
24 |         self.fc1 = nn.Linear(in_features, hidden_features)
25 |         self.act = act_layer()
26 |         self.drop1 = nn.Dropout(drop)
27 |         self.fc2 = nn.Linear(hidden_features, out_features)
28 |         self.drop2 = nn.Dropout(drop)
29 | 
30 |     def forward(self, x):
31 |         x = self.fc1(x)
32 |         x = self.act(x)
33 |         x = self.drop1(x)
34 |         x = self.fc2(x)
35 |         x = self.drop2(x)
36 |         return x
37 | 
38 | 
39 | class TransformerEncoderLayer(nn.Module):
40 |     def __init__(
41 |         self,
42 |         dim,
43 |         num_heads,
44 |         mlp_ratio=4.0,
45 |         qkv_bias=False,
46 |         drop=0.0,
47 |         attn_drop=0.0,
48 |         drop_path=0.0,
49 |         act_layer=nn.GELU,
50 |         norm_layer=nn.LayerNorm,
51 |     ):
52 |         super().__init__()
53 |         self.norm1 = norm_layer(dim)
54 |         self.attn = torch.nn.MultiheadAttention(
55 |             dim,
56 |             num_heads=num_heads,
57 |             add_bias_kv=qkv_bias,
58 |             dropout=attn_drop,
59 |             batch_first=True,
60 |         )
61 |         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
62 | 
63 |         self.norm2 = norm_layer(dim)
64 |         self.mlp = Mlp(
65 |             in_features=dim,
66 |             hidden_features=int(dim * mlp_ratio),
67 |             act_layer=act_layer,
68 |             drop=drop,
69 |         )
70 |         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
71 | 
72 |     def forward(
73 |         self,
74 |         src,
75 |         mask: Optional[Tensor] = None,
76 |         key_padding_mask: Optional[Tensor] = None,
77 |     ):
78 |         src2 = self.norm1(src)
79 |         src2 = self.attn(
80 |             query=src2,
81 |             key=src2,
82 |             value=src2,
83 |             attn_mask=mask,
84 |             key_padding_mask=key_padding_mask,
85 |         )[0]
86 |         src = src + self.drop_path1(src2)
87 |         src = src + self.drop_path2(self.mlp(self.norm2(src)))
88 |         return src
89 | 
--------------------------------------------------------------------------------
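A minimal sketch (not in the repo) of running this pre-norm encoder block over a padded token sequence; `True` entries in `key_padding_mask` are ignored by attention:

```python
import torch

from src.models.planTF.layers.transformer_encoder_layer import TransformerEncoderLayer

layer = TransformerEncoderLayer(dim=128, num_heads=8, drop_path=0.2)
tokens = torch.randn(2, 50, 128)                  # [batch, tokens, dim] (batch_first=True)
key_padding_mask = torch.zeros(2, 50, dtype=torch.bool)
key_padding_mask[:, 40:] = True                   # last 10 tokens are padding

out = layer(tokens, key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([2, 50, 128])
```
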
/src/models/planTF/modules/map_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from ..layers.embedding import PointsEncoder
5 | 
6 | 
7 | class MapEncoder(nn.Module):
8 |     def __init__(
9 |         self,
10 |         polygon_channel=6,
11 |         dim=128,
12 |     ) -> None:
13 |         super().__init__()
14 | 
15 |         self.dim = dim
16 |         self.polygon_encoder = PointsEncoder(polygon_channel, dim)
17 |         self.speed_limit_emb = nn.Sequential(
18 |             nn.Linear(1, dim), nn.ReLU(), nn.Linear(dim, dim)
19 |         )
20 | 
21 |         self.type_emb = nn.Embedding(3, dim)
22 |         self.on_route_emb = nn.Embedding(2, dim)
23 |         self.traffic_light_emb = nn.Embedding(4, dim)
24 |         self.unknown_speed_emb = nn.Embedding(1, dim)
25 | 
26 |     def forward(self, data) -> torch.Tensor:
27 |         polygon_center = data["map"]["polygon_center"]
28 |         polygon_type = data["map"]["polygon_type"].long()
29 |         polygon_on_route = data["map"]["polygon_on_route"].long()
30 |         polygon_tl_status = data["map"]["polygon_tl_status"].long()
31 |         polygon_has_speed_limit = data["map"]["polygon_has_speed_limit"]
32 |         polygon_speed_limit = data["map"]["polygon_speed_limit"]
33 |         point_position = data["map"]["point_position"]
34 |         point_vector = data["map"]["point_vector"]
35 |         point_orientation = data["map"]["point_orientation"]
36 |         valid_mask = data["map"]["valid_mask"]
37 | 
38 |         polygon_feature = torch.cat(
39 |             [
40 |                 point_position[:, :, 0] - polygon_center[..., None, :2],
41 |                 point_vector[:, :, 0],
42 |                 torch.stack(
43 |                     [
44 |                         point_orientation[:, :, 0].cos(),
45 |                         point_orientation[:, :, 0].sin(),
46 |                     ],
47 |                     dim=-1,
48 |                 ),
49 |             ],
50 |             dim=-1,
51 |         )
52 | 
53 |         bs, M, P, C = polygon_feature.shape
54 |         valid_mask = valid_mask.view(bs * M, P)
55 |         polygon_feature = polygon_feature.reshape(bs * M, P, C)
56 | 
57 |         x_polygon = self.polygon_encoder(polygon_feature, valid_mask).view(bs, M, -1)
58 | 
59 |         x_type = self.type_emb(polygon_type)
60 |         x_on_route = self.on_route_emb(polygon_on_route)
61 |         x_tl_status = self.traffic_light_emb(polygon_tl_status)
62 |         x_speed_limit = torch.zeros(bs, M, self.dim, device=x_polygon.device)
63 |         x_speed_limit[polygon_has_speed_limit] = self.speed_limit_emb(
64 |             polygon_speed_limit[polygon_has_speed_limit].unsqueeze(-1)
65 |         )
66 |         x_speed_limit[~polygon_has_speed_limit] = self.unknown_speed_emb.weight
67 | 
68 |         x_polygon += x_type + x_on_route + x_tl_status + x_speed_limit
69 | 
70 |         return x_polygon
71 | 
--------------------------------------------------------------------------------
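A minimal sketch (not in the repo) of driving `MapEncoder` with random map features. Shapes follow what the forward pass consumes (bs=batch, M=polygons, P=points per polygon); the size-3 axis indexing polygon boundaries is an assumption about the feature builder — only index 0 is read here — and the sketch further assumes the repo's `PointsEncoder` reduces each masked point set to a single `dim`-vector:

```python
import torch

from src.models.planTF.modules.map_encoder import MapEncoder

bs, M, P = 2, 5, 20
data = {
    "map": {
        "polygon_center": torch.randn(bs, M, 3),          # x, y, heading
        "polygon_type": torch.randint(0, 3, (bs, M)),
        "polygon_on_route": torch.randint(0, 2, (bs, M)),
        "polygon_tl_status": torch.randint(0, 4, (bs, M)),
        "polygon_has_speed_limit": torch.rand(bs, M) > 0.5,
        "polygon_speed_limit": torch.rand(bs, M) * 15.0,
        "point_position": torch.randn(bs, M, 3, P, 2),
        "point_vector": torch.randn(bs, M, 3, P, 2),
        "point_orientation": torch.randn(bs, M, 3, P),
        "valid_mask": torch.ones(bs, M, P, dtype=torch.bool),
    }
}

encoder = MapEncoder(polygon_channel=6, dim=128)
x_polygon = encoder(data)
print(x_polygon.shape)  # torch.Size([2, 5, 128]), one token per map polygon
```
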
/docs/other_baselines.md:
--------------------------------------------------------------------------------
1 | # Gallery
2 | 
3 | - [RasterModel](#rastermodel)
4 | - [UrbanDriver (open-loop)](#urbandriver-open-loop)
5 | 
6 | ## RasterModel
7 | 
8 | ### Feature cache
9 | 
10 | ```
11 | python ./run_training.py \
12 |     +training=training_raster_model \
13 |     py_func=cache \
14 |     scenario_builder=nuplan \
15 |     cache.cache_path=/nuplan/exp/cache_raster_model_1M \
16 |     cache.cleanup_cache=true \
17 |     scenario_filter=training_scenarios_1M \
18 |     worker.threads_per_node=40
19 | ```
20 | 
21 | ### Training
22 | 
23 | ```
24 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ./run_training.py \
25 |     +training=training_raster_model \
26 |     py_func=train \
27 |     scenario_builder=nuplan \
28 |     cache.cache_path=/nuplan/exp/cache_raster_model_1M \
29 |     data_loader.params.batch_size=32 \
30 |     data_loader.params.num_workers=32 \
31 |     cache.use_cache_without_dataset=true \
32 |     worker=single_machine_thread_pool \
33 |     worker.max_workers=32 \
34 |     optimizer=adam \
35 |     optimizer.lr=1e-4 \
36 |     lightning.trainer.params.max_epochs=60 \
37 |     lr_scheduler=multistep_lr \
38 |     lr_scheduler.milestones='[20, 40]' \
39 |     lr_scheduler.gamma=0.1 \
40 |     wandb.mode=online wandb.project=nuplan_baseline wandb.name=RasterModel_1M
41 | ```
42 | 
43 | ### Evaluation
44 | 
45 | - run **Test14-random**: `sh ./script/raster_model_benchmarks.sh test14-random`
46 | - run **Test14-hard**: `sh ./script/raster_model_benchmarks.sh test14-hard`
47 | - run **Val14** (this may take a long time): `sh ./script/raster_model_benchmarks.sh val14`
48 | 
49 | ## UrbanDriver (open-loop)
50 | 
51 | ### Feature cache
52 | 
53 | ```
54 | python ./run_training.py \
55 |     +training=training_urban_driver_open_loop_model \
56 |     py_func=cache \
57 |     scenario_builder=nuplan \
58 |     cache.cache_path=/nuplan/exp/cache_urban_driver_1M \
59 |     cache.cleanup_cache=true \
60 |     scenario_filter=training_scenarios_1M \
61 |     worker.threads_per_node=40
62 | ```
63 | 
64 | ### Training
65 | 
66 | ```
67 | CUDA_VISIBLE_DEVICES=0,1,2,3 python ./run_training.py \
68 |     +training=training_urban_driver_open_loop_model \
69 |     py_func=train \
70 |     scenario_builder=nuplan \
71 |     cache.cache_path=/nuplan/exp/cache_urban_driver_1M \
72 |     data_loader.params.batch_size=32 \
73 |     data_loader.params.num_workers=32 \
74 |     cache.use_cache_without_dataset=true \
75 |     worker=single_machine_thread_pool \
76 |     worker.max_workers=32 \
77 |     optimizer=adam \
78 |     optimizer.lr=1e-4 \
79 |     lightning.trainer.params.max_epochs=30 \
80 |     lr_scheduler=multistep_lr \
81 |     lr_scheduler.milestones='[20]' \
82 |     lr_scheduler.gamma=0.1 \
83 |     wandb.mode=online wandb.project=nuplan_baseline wandb.name=urban_driver_open_loop_1M
84 | ```
85 | 
86 | ### Evaluation
87 | 
88 | - run **Test14-random**: `sh ./script/urbandriver_benchmarks.sh test14-random`
89 | - run **Test14-hard**: `sh ./script/urbandriver_benchmarks.sh test14-hard`
90 | - run **Val14** (this may take a long time): `sh ./script/urbandriver_benchmarks.sh val14`
--------------------------------------------------------------------------------
/run_nuboard.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Motional
2 | 
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | 
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import logging
16 | import os
17 | from pathlib import Path
18 | 
19 | import hydra
20 | from hydra.utils import instantiate
21 | from omegaconf import DictConfig
22 | 
23 | from nuplan.common.actor_state.vehicle_parameters import VehicleParameters
24 | from nuplan.planning.nuboard.nuboard import NuBoard
25 | from nuplan.planning.script.builders.scenario_building_builder import (
26 |     build_scenario_builder,
27 | )
28 | from nuplan.planning.script.builders.utils.utils_config import update_config_for_nuboard
29 | from nuplan.planning.script.utils import set_default_path
30 | 
31 | logging.basicConfig(level=logging.INFO)
32 | logger = logging.getLogger(__name__)
33 | 
34 | # If set, use the env. variable to overwrite the default dataset and experiment paths
35 | set_default_path()
36 | 
37 | # If set, use the env. variable to overwrite the Hydra config
38 | CONFIG_PATH = os.getenv("NUPLAN_HYDRA_CONFIG_PATH", "config/nuboard")
39 | 
40 | if os.environ.get("NUPLAN_HYDRA_CONFIG_PATH") is not None:
41 |     CONFIG_PATH = os.path.join("../../../../", CONFIG_PATH)
42 | 
43 | if os.path.basename(CONFIG_PATH) != "nuboard":
44 |     CONFIG_PATH = os.path.join(CONFIG_PATH, "nuboard")
45 | CONFIG_NAME = "default_nuboard"
46 | 
47 | 
48 | def initialize_nuboard(cfg: DictConfig) -> NuBoard:
49 |     """
50 |     Sets up dependencies and instantiates a NuBoard object.
51 |     :param cfg: DictConfig. Configuration that is used to run the experiment.
52 |     :return: NuBoard object.
53 |     """
54 |     # Update and override configs for nuboard
55 |     update_config_for_nuboard(cfg=cfg)
56 | 
57 |     scenario_builder = build_scenario_builder(cfg)
58 | 
59 |     # Build vehicle parameters
60 |     vehicle_parameters: VehicleParameters = instantiate(
61 |         cfg.scenario_builder.vehicle_parameters
62 |     )
63 |     profiler_path = None
64 |     if cfg.profiler_path:
65 |         profiler_path = Path(cfg.profiler_path)
66 | 
67 |     nuboard = NuBoard(
68 |         profiler_path=profiler_path,
69 |         nuboard_paths=cfg.simulation_path,
70 |         scenario_builder=scenario_builder,
71 |         port_number=cfg.port_number,
72 |         resource_prefix=cfg.resource_prefix,
73 |         vehicle_parameters=vehicle_parameters,
74 |     )
75 | 
76 |     return nuboard
77 | 
78 | 
79 | @hydra.main(config_path="./config", config_name="default_nuboard")
80 | def main(cfg: DictConfig) -> None:
81 |     """
82 |     Launch the nuBoard dashboard to visualize simulation results.
83 |     :param cfg: DictConfig. Configuration that is used to run the experiment.
84 |     """
85 |     nuboard = initialize_nuboard(cfg)
86 |     nuboard.run()
87 | 
88 | 
89 | if __name__ == "__main__":
90 |     main()
91 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | parts/
18 | sdist/
19 | var/
20 | wheels/
21 | share/python-wheels/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 | MANIFEST
26 | 
27 | # PyInstaller
28 | #  Usually these files are written by a python script from a template
29 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 | 
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 | 
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .nox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | *.py,cover
48 | .hypothesis/
49 | .pytest_cache/
50 | cover/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | db.sqlite3-journal
61 | 
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 | 
66 | # Scrapy stuff:
67 | .scrapy
68 | 
69 | # Sphinx documentation
70 | docs/_build/
71 | 
72 | # PyBuilder
73 | .pybuilder/
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 | 
83 | # pyenv
84 | #   For a library or package, you might want to ignore these files since the code is
85 | #   intended to run in multiple environments; otherwise, check them in:
86 | # .python-version
87 | 
88 | # pipenv
89 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | #   install all needed dependencies.
93 | #Pipfile.lock
94 | 
95 | # poetry
96 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
97 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
98 | #   commonly ignored for libraries.
99 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
100 | #poetry.lock
101 | 
102 | # pdm
103 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
104 | #pdm.lock
105 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
106 | #   in version control.
107 | #   https://pdm.fming.dev/#use-with-ide
108 | .pdm.toml
109 | 
110 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
111 | __pypackages__/
112 | 
113 | # Celery stuff
114 | celerybeat-schedule
115 | celerybeat.pid
116 | 
117 | # SageMath parsed files
118 | *.sage.py
119 | 
120 | # Environments
121 | .venv
122 | # env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 | 
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 | 
132 | # Rope project settings
133 | .ropeproject
134 | 
135 | # mkdocs documentation
136 | /site
137 | 
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 | 
143 | # Pyre type checker
144 | .pyre/
145 | 
146 | # pytype static type analyzer
147 | .pytype/
148 | 
149 | # Cython debug symbols
150 | cython_debug/
151 | 
152 | # PyCharm
153 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
154 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
155 | #  and can be added to the global gitignore or merged into this file. For a more nuclear
156 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
157 | #.idea/
158 | 
159 | .vscode
160 | wandb/
161 | outputs/
162 | *.ckpt
163 | checkpoints/
--------------------------------------------------------------------------------
/run_training.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Optional
3 | 
4 | import hydra
5 | import pytorch_lightning as pl
6 | from nuplan.planning.script.builders.folder_builder import (
7 |     build_training_experiment_folder,
8 | )
9 | from nuplan.planning.script.builders.logging_builder import build_logger
10 | from nuplan.planning.script.builders.worker_pool_builder import build_worker
11 | from nuplan.planning.script.profiler_context_manager import ProfilerContextManager
12 | from nuplan.planning.script.utils import set_default_path
13 | from nuplan.planning.training.experiments.caching import cache_data
14 | from omegaconf import DictConfig
15 | 
16 | from src.custom_training.custom_training_builder import (
17 |     TrainingEngine,
18 |     build_training_engine,
19 |     update_config_for_training,
20 | )
21 | 
22 | logging.getLogger("numba").setLevel(logging.WARNING)
23 | logger = logging.getLogger(__name__)
24 | 
25 | # If set, use the env. variable to overwrite the default dataset and experiment paths
26 | set_default_path()
27 | 
28 | 
29 | @hydra.main(config_path="./config", config_name="default_training")
30 | def main(cfg: DictConfig) -> Optional[TrainingEngine]:
31 |     """
32 |     Main entrypoint for training/validation experiments.
33 |     :param cfg: omegaconf dictionary
34 |     """
35 |     pl.seed_everything(cfg.seed, workers=True)
36 | 
37 |     # Configure logger
38 |     build_logger(cfg)
39 | 
40 |     # Override configs based on setup, and print config
41 |     update_config_for_training(cfg)
42 | 
43 |     # Create output storage folder
44 |     build_training_experiment_folder(cfg=cfg)
45 | 
46 |     # Build worker
47 |     worker = build_worker(cfg)
48 | 
49 |     if cfg.py_func == "train":
50 |         # Build training engine
51 |         with ProfilerContextManager(
52 |             cfg.output_dir, cfg.enable_profiling, "build_training_engine"
53 |         ):
54 |             engine = build_training_engine(cfg, worker)
55 | 
56 |         # Run training
57 |         logger.info("Starting training...")
58 |         with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "training"):
59 |             engine.trainer.fit(
60 |                 model=engine.model,
61 |                 datamodule=engine.datamodule,
62 |                 ckpt_path=cfg.checkpoint,
63 |             )
64 |         return engine
65 |     if cfg.py_func == "validate":
66 |         # Build training engine
67 |         with ProfilerContextManager(
68 |             cfg.output_dir, cfg.enable_profiling, "build_training_engine"
69 |         ):
70 |             engine = build_training_engine(cfg, worker)
71 | 
72 |         # Run validation
73 |         logger.info("Starting validation...")
74 |         with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "validate"):
75 |             engine.trainer.validate(
76 |                 model=engine.model,
77 |                 datamodule=engine.datamodule,
78 |                 ckpt_path=cfg.checkpoint,
79 |             )
80 |         return engine
81 |     elif cfg.py_func == "test":
82 |         # Build training engine
83 |         with ProfilerContextManager(
84 |             cfg.output_dir, cfg.enable_profiling, "build_training_engine"
85 |         ):
86 |             engine = build_training_engine(cfg, worker)
87 | 
88 |         # Test model
89 |         logger.info("Starting testing...")
90 |         with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "testing"):
91 |             engine.trainer.test(model=engine.model, datamodule=engine.datamodule)
92 |         return engine
93 |     elif cfg.py_func == "cache":
94 |         # Precompute and cache all features
95 |         logger.info("Starting caching...")
96 |         with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "caching"):
97 |             cache_data(cfg=cfg, worker=worker)
98 |         return None
99 |     else:
100 |         raise NameError(f"Function {cfg.py_func} does not exist")
101 | 
102 | 
103 | if __name__ == "__main__":
104 |     main()
105 | 
--------------------------------------------------------------------------------
37 |   - override hydra/job_logging: none  # Disable hydra's logging
38 |   - override hydra/hydra_logging: none  # Disable hydra's logging
39 | 
40 | experiment_name: 'simulation'
41 | aggregated_metric_folder_name: 'aggregator_metric'  # Aggregated metric folder name
42 | aggregator_save_path: ${output_dir}/${aggregated_metric_folder_name}
43 | 
44 | 
45 | # Progress Visualization
46 | enable_simulation_progress_bar: true  # Show the progress of every simulation
47 | 
48 | # Simulation Setup
49 | simulation_history_buffer_duration: 2.0  # [s] The look-back duration to initialize the simulation history buffer with
50 | 
51 | # Number (or fraction, e.g., 0.25) of GPUs available for a single simulation (per scenario and planner).
52 | # This number can also be < 1 because we allow multiple models to be loaded into a single GPU.
53 | # In case this number is 0 or null, no GPU is used for simulation and all CPU cores are leveraged.
54 | # Note that the user has to make sure that if a number < 1 is chosen, the model fits into 1 / num_gpus of the GPU memory.
55 | number_of_gpus_allocated_per_simulation: 1
56 | 
57 | # This number specifies the number of CPU threads that are used for simulation.
58 | # In case this is null, each simulation will use unlimited resources.
59 | # That will typically swamp the host computer, leading to slowdowns and failures.
60 | number_of_cpus_allocated_per_simulation: 1
61 | 
62 | # Set false to disable metric computation
63 | run_metric: true
64 | 
65 | # Set to a path with existing simulation logs to rerun metrics without re-running the simulation.
66 | simulation_log_main_path: null
67 | 
68 | # If false, continue running the simulation even if a scenario has failed
69 | exit_on_failure: false
70 | 
71 | # Maximum number of workers to be used for running simulation callbacks outside the main process
72 | max_callback_workers: 4
73 | 
74 | # Disable callback parallelization when using the Sequential worker. By default, when running with the sequential worker,
75 | # on_simulation_end callbacks are not submitted to a parallel worker.
76 | disable_callback_parallelization: true
77 | 
78 | # Distributed processing mode. If multi-node simulation is enabled, this parameter selects how the scenarios are distributed
79 | # to each node. The modes are:
80 | #  - SCENARIO_BASED: Works in two stages, first getting a list of all scenarios to process, then breaking up that
81 | #    list and distributing it across the workers
82 | #  - LOG_FILE_BASED: Works in a single stage, breaking up the scenarios based on what log file they are in and
83 | #    distributing the number of log files evenly across all workers
84 | #  - SINGLE_NODE: Does no distribution, processes all scenarios in config
85 | distributed_mode: 'SINGLE_NODE'
--------------------------------------------------------------------------------
/src/utils/collision_checker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nuplan.common.actor_state.vehicle_parameters import (
3 |     VehicleParameters,
4 |     get_pacifica_parameters,
5 | )
6 | 
7 | 
8 | class CollisionChecker:
9 |     def __init__(
10 |         self,
11 |         vehicle: VehicleParameters = get_pacifica_parameters(),
12 |     ) -> None:
13 |         self._vehicle = vehicle
14 |         self._sdc_half_length = vehicle.length / 2
15 |         self._sdc_half_width = vehicle.width / 2
16 | 
17 |         self._sdc_normalized_corners = torch.stack(
18 |             [
19 |                 torch.tensor([vehicle.length / 2, vehicle.width / 2]),
20 |                 torch.tensor([vehicle.length / 2, -vehicle.width / 2]),
21 |                 torch.tensor([-vehicle.length / 2, -vehicle.width / 2]),
22 |                 torch.tensor([-vehicle.length / 2, vehicle.width / 2]),
23 |             ],
24 |             dim=0,
25 |         )
26 | 
27 |     def to_device(self, device):
28 |         self._sdc_normalized_corners = self._sdc_normalized_corners.to(device)
29 | 
30 |     def build_bbox_from_center(self, center, heading, width, length):
31 |         """
32 |         params:
33 |             center: [bs, N, (x, y)]
34 |             heading: [bs, N]
35 |             width: [bs, N]
36 |             length: [bs, N]
37 |         return:
38 |             corners: [bs, N, 4, (x, y)]
39 |             heading_vec, tanh_vec: [bs, N, 2]
40 |         """
41 |         cos = torch.cos(heading)
42 |         sin = torch.sin(heading)
43 | 
44 |         heading_vec = torch.stack([cos, sin], dim=-1) * length.unsqueeze(-1) / 2
45 |         tanh_vec = torch.stack([-sin, cos], dim=-1) * width.unsqueeze(-1) / 2
46 | 
47 |         corners = torch.stack(
48 |             [
49 |                 center + heading_vec + tanh_vec,
50 |                 center - heading_vec + tanh_vec,
51 |                 center - heading_vec - tanh_vec,
52 |                 center + heading_vec - tanh_vec,
53 |             ],
54 |             dim=-2,
55 |         )
56 | 
57 |         return corners, heading_vec, tanh_vec
58 | 
59 |     def collision_check(self, ego_state, objects, objects_width, objects_length):
60 |         """Performs a batch-wise collision check using the Separating Axis Theorem.
61 |         params:
62 |             ego_state: [bs, (x, y, theta)], center of the ego
63 |             objects: [bs, N, (x, y, theta)], center of the objects; objects_width / objects_length: [bs, N]
64 |         returns:
65 |             is_collided: [bs, N]
66 |         """
67 | 
68 |         bs, N = objects.shape[:2]
69 | 
70 |         # rotate object to ego's local frame
71 |         cos, sin = torch.cos(ego_state[:, 2]), torch.sin(ego_state[:, 2])
72 |         rotate_mat = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(bs, 2, 2)
73 | 
74 |         rotated_objects = objects.clone()
75 |         rotated_objects[..., :2] = torch.matmul(
76 |             rotated_objects[..., :2] - ego_state[:, :2].unsqueeze(1), rotate_mat
77 |         )
78 |         rotated_objects[..., 2] -= ego_state[..., 2].unsqueeze(1)
79 | 
80 |         # [bs, N, 4, 2], [bs, N, 2], [bs, N, 2]
81 |         object_corners, axis1, axis2 = self.build_bbox_from_center(
82 |             rotated_objects[..., :2],
83 |             rotated_objects[..., 2],
84 |             objects_width,
85 |             objects_length,
86 |         )
87 | 
88 |         ego_corners = self._sdc_normalized_corners.reshape(1, 1, 4, 2).repeat(
89 |             bs, N, 1, 1
90 |         )  # [bs, N, 4, 2]
91 | 
92 |         all_corners = torch.concat(
93 |             [object_corners, ego_corners], dim=-2
94 |         )  # [bs, N, 8, 2]
95 | 
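        # Separating Axis Theorem, specialized to two boxes: since the objects were
        # rotated into the ego frame, the ego box is axis-aligned and its two axes
        # reduce to plain x/y interval tests against the object corners; the object's
        # own axes (axis1/axis2) require projecting all 8 corners (object corners
        # first, ego corners last) onto each axis and comparing the intervals.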
96 | x_projection = object_corners[..., 0] 97 | y_projection = object_corners[..., 1] 98 | axis1_projection = torch.matmul(all_corners, axis1.unsqueeze(-1)).squeeze(-1) 99 | axis2_projection = torch.matmul(all_corners, axis2.unsqueeze(-1)).squeeze(-1) 100 | 101 | x_separated = (x_projection.max(-1)[0] < -self._sdc_half_length) | ( 102 | x_projection.min(-1)[0] > self._sdc_half_length 103 | ) 104 | y_separated = (y_projection.max(-1)[0] < -self._sdc_half_width) | ( 105 | y_projection.min(-1)[0] > self._sdc_half_width 106 | ) 107 | axis1_separated = ( 108 | axis1_projection[..., :4].max(-1)[0] < axis1_projection[..., 4:].min(-1)[0] 109 | ) | ( 110 | axis1_projection[..., :4].min(-1)[0] > axis1_projection[..., 4:].max(-1)[0] 111 | ) 112 | axis2_separated = ( 113 | axis2_projection[..., :4].max(-1)[0] < axis2_projection[..., 4:].min(-1)[0] 114 | ) | ( 115 | axis2_projection[..., :4].min(-1)[0] > axis2_projection[..., 4:].max(-1)[0] 116 | ) 117 | 118 | collision = ~(x_separated | y_separated | axis1_separated | axis2_separated) 119 | 120 | return collision 121 | -------------------------------------------------------------------------------- /src/features/nuplan_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List 5 | 6 | import numpy as np 7 | import torch 8 | from nuplan.planning.training.preprocessing.features.abstract_model_feature import ( 9 | AbstractModelFeature, 10 | ) 11 | from torch.nn.utils.rnn import pad_sequence 12 | 13 | from src.utils.conversion import to_device, to_numpy, to_tensor 14 | 15 | 16 | @dataclass 17 | class NuplanFeature(AbstractModelFeature): 18 | data: Dict[str, Any] 19 | 20 | @classmethod 21 | def collate(cls, feature_list: List[NuplanFeature]) -> NuplanFeature: 22 | batch_data = {} 23 | for key in ["agent", "map"]: 24 | batch_data[key] = { 25 | k: pad_sequence( 26 | [f.data[key][k] for f in feature_list], batch_first=True 27 | ) 28 | for k in feature_list[0].data[key].keys() 29 | } 30 | for key in ["current_state", "origin", "angle"]: 31 | batch_data[key] = torch.stack([f.data[key] for f in feature_list], dim=0) 32 | 33 | return NuplanFeature(data=batch_data) 34 | 35 | def to_feature_tensor(self) -> NuplanFeature: 36 | new_data = {} 37 | for k, v in self.data.items(): 38 | new_data[k] = to_tensor(v) 39 | return NuplanFeature(data=new_data) 40 | 41 | def to_numpy(self) -> NuplanFeature: 42 | new_data = {} 43 | for k, v in self.data.items(): 44 | new_data[k] = to_numpy(v) 45 | return NuplanFeature(data=new_data) 46 | 47 | def to_device(self, device: torch.device) -> NuplanFeature: 48 | new_data = {} 49 | for k, v in self.data.items(): 50 | new_data[k] = to_device(v, device) 51 | return NuplanFeature(data=new_data) 52 | 53 | def serialize(self) -> Dict[str, Any]: 54 | return self.data 55 | 56 | @classmethod 57 | def deserialize(cls, data: Dict[str, Any]) -> NuplanFeature: 58 | return NuplanFeature(data=data) 59 | 60 | def unpack(self) -> List[AbstractModelFeature]: 61 | raise NotImplementedError 62 | 63 | def is_valid(self) -> bool: 64 | return self.data["polylines"].shape[0] > 0 65 | 66 | @classmethod 67 | def normalize( 68 | self, data, first_time=False, radius=None, hist_steps=21 69 | ) -> NuplanFeature: 70 | cur_state = data["current_state"] 71 | center_xy, center_angle = cur_state[:2].copy(), cur_state[2].copy() 72 | 73 | rotate_mat = np.array( 74 | [ 75 | [np.cos(center_angle), -np.sin(center_angle)], 76 
| [np.sin(center_angle), np.cos(center_angle)], 77 | ], 78 | dtype=np.float64, 79 | ) 80 | 81 | data["current_state"][:3] = 0 82 | data["agent"]["position"] = np.matmul( 83 | data["agent"]["position"] - center_xy, rotate_mat 84 | ) 85 | data["agent"]["velocity"] = np.matmul(data["agent"]["velocity"], rotate_mat) 86 | data["agent"]["heading"] -= center_angle 87 | 88 | data["map"]["point_position"] = np.matmul( 89 | data["map"]["point_position"] - center_xy, rotate_mat 90 | ) 91 | data["map"]["point_vector"] = np.matmul(data["map"]["point_vector"], rotate_mat) 92 | data["map"]["point_orientation"] -= center_angle 93 | 94 | data["map"]["polygon_center"][..., :2] = np.matmul( 95 | data["map"]["polygon_center"][..., :2] - center_xy, rotate_mat 96 | ) 97 | data["map"]["polygon_center"][..., 2] -= center_angle 98 | data["map"]["polygon_position"] = np.matmul( 99 | data["map"]["polygon_position"] - center_xy, rotate_mat 100 | ) 101 | data["map"]["polygon_orientation"] -= center_angle 102 | 103 | target_position = ( 104 | data["agent"]["position"][:, hist_steps:] 105 | - data["agent"]["position"][:, hist_steps - 1][:, None] 106 | ) 107 | target_heading = ( 108 | data["agent"]["heading"][:, hist_steps:] 109 | - data["agent"]["heading"][:, hist_steps - 1][:, None] 110 | ) 111 | target = np.concatenate([target_position, target_heading[..., None]], -1) 112 | target[~data["agent"]["valid_mask"][:, hist_steps:]] = 0 113 | data["agent"]["target"] = target 114 | 115 | if first_time: 116 | point_position = data["map"]["point_position"] 117 | x_max, x_min = radius, -radius 118 | y_max, y_min = radius, -radius 119 | valid_mask = ( 120 | (point_position[:, 0, :, 0] < x_max) 121 | & (point_position[:, 0, :, 0] > x_min) 122 | & (point_position[:, 0, :, 1] < y_max) 123 | & (point_position[:, 0, :, 1] > y_min) 124 | ) 125 | valid_polygon = valid_mask.any(-1) 126 | data["map"]["valid_mask"] = valid_mask 127 | 128 | for k, v in data["map"].items(): 129 | data["map"][k] = v[valid_polygon] 130 | 131 | data["origin"] = center_xy 132 | data["angle"] = center_angle 133 | 134 | return NuplanFeature(data=data) 135 | -------------------------------------------------------------------------------- /src/models/planTF/modules/agent_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..layers.common_layers import build_mlp 5 | from ..layers.embedding import NATSequenceEncoder 6 | 7 | 8 | class AgentEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | state_channel=6, 12 | history_channel=9, 13 | dim=128, 14 | hist_steps=21, 15 | use_ego_history=False, 16 | drop_path=0.2, 17 | state_attn_encoder=True, 18 | state_dropout=0.75, 19 | ) -> None: 20 | super().__init__() 21 | self.dim = dim 22 | self.state_channel = state_channel 23 | self.use_ego_history = use_ego_history 24 | self.hist_steps = hist_steps 25 | self.state_attn_encoder = state_attn_encoder 26 | 27 | self.history_encoder = NATSequenceEncoder( 28 | in_chans=history_channel, embed_dim=dim // 4, drop_path_rate=drop_path 29 | ) 30 | 31 | if not use_ego_history: 32 | if not self.state_attn_encoder: 33 | self.ego_state_emb = build_mlp(state_channel, [dim] * 2, norm="bn") 34 | else: 35 | self.ego_state_emb = StateAttentionEncoder( 36 | state_channel, dim, state_dropout 37 | ) 38 | 39 | self.type_emb = nn.Embedding(4, dim) 40 | 41 | @staticmethod 42 | def to_vector(feat, valid_mask): 43 | vec_mask = valid_mask[..., :-1] & valid_mask[..., 1:] 44 | 45 | while len(vec_mask.shape) < 
len(feat.shape): 46 | vec_mask = vec_mask.unsqueeze(-1) 47 | 48 | return torch.where( 49 | vec_mask, 50 | feat[:, :, 1:, ...] - feat[:, :, :-1, ...], 51 | torch.zeros_like(feat[:, :, 1:, ...]), 52 | ) 53 | 54 | def forward(self, data): 55 | T = self.hist_steps 56 | 57 | position = data["agent"]["position"][:, :, :T] 58 | heading = data["agent"]["heading"][:, :, :T] 59 | velocity = data["agent"]["velocity"][:, :, :T] 60 | shape = data["agent"]["shape"][:, :, :T] 61 | category = data["agent"]["category"].long() 62 | valid_mask = data["agent"]["valid_mask"][:, :, :T] 63 | 64 | heading_vec = self.to_vector(heading, valid_mask) 65 | valid_mask_vec = valid_mask[..., 1:] & valid_mask[..., :-1] 66 | agent_feature = torch.cat( 67 | [ 68 | self.to_vector(position, valid_mask), 69 | self.to_vector(velocity, valid_mask), 70 | torch.stack([heading_vec.cos(), heading_vec.sin()], dim=-1), 71 | shape[:, :, 1:], 72 | valid_mask_vec.float().unsqueeze(-1), 73 | ], 74 | dim=-1, 75 | ) 76 | bs, A, T, _ = agent_feature.shape 77 | agent_feature = agent_feature.view(bs * A, T, -1) 78 | valid_agent_mask = valid_mask.any(-1).flatten() 79 | 80 | x_agent_tmp = self.history_encoder( 81 | agent_feature[valid_agent_mask].permute(0, 2, 1).contiguous() 82 | ) 83 | x_agent = torch.zeros(bs * A, self.dim, device=position.device) 84 | x_agent[valid_agent_mask] = x_agent_tmp 85 | x_agent = x_agent.view(bs, A, self.dim) 86 | 87 | if not self.use_ego_history: 88 | ego_feature = data["current_state"][:, : self.state_channel] 89 | x_ego = self.ego_state_emb(ego_feature) 90 | x_agent[:, 0] = x_ego 91 | 92 | x_type = self.type_emb(category) 93 | 94 | return x_agent + x_type 95 | 96 | 97 | class StateAttentionEncoder(nn.Module): 98 | def __init__(self, state_channel, dim, state_dropout=0.5) -> None: 99 | super().__init__() 100 | 101 | self.state_channel = state_channel 102 | self.state_dropout = state_dropout 103 | self.linears = nn.ModuleList([nn.Linear(1, dim) for _ in range(state_channel)]) 104 | self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True) 105 | self.pos_embed = nn.Parameter(torch.Tensor(1, state_channel, dim)) 106 | self.query = nn.Parameter(torch.Tensor(1, 1, dim)) 107 | 108 | nn.init.normal_(self.pos_embed, std=0.02) 109 | nn.init.normal_(self.query, std=0.02) 110 | 111 | def forward(self, x): 112 | x_embed = [] 113 | for i, linear in enumerate(self.linears): 114 | x_embed.append(linear(x[:, i, None])) 115 | x_embed = torch.stack(x_embed, dim=1) 116 | pos_embed = self.pos_embed.repeat(x_embed.shape[0], 1, 1) 117 | x_embed += pos_embed 118 | 119 | if self.training and self.state_dropout > 0: 120 | visible_tokens = torch.zeros( 121 | (x_embed.shape[0], 3), device=x.device, dtype=torch.bool 122 | ) 123 | dropout_tokens = ( 124 | torch.rand((x_embed.shape[0], self.state_channel - 3), device=x.device) 125 | < self.state_dropout 126 | ) 127 | key_padding_mask = torch.concat([visible_tokens, dropout_tokens], dim=1) 128 | else: 129 | key_padding_mask = None 130 | 131 | query = self.query.repeat(x_embed.shape[0], 1, 1) 132 | 133 | x_state = self.attn( 134 | query=query, 135 | key=x_embed, 136 | value=x_embed, 137 | key_padding_mask=key_padding_mask, 138 | )[0] 139 | 140 | return x_state[:, 0] 141 | -------------------------------------------------------------------------------- /run_simulation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from typing 
import List, Optional, Union 7 | 8 | import hydra 9 | import pandas as pd 10 | import pytorch_lightning as pl 11 | from nuplan.common.utils.s3_utils import is_s3_path 12 | from nuplan.planning.script.builders.simulation_builder import build_simulations 13 | from nuplan.planning.script.builders.simulation_callback_builder import ( 14 | build_callbacks_worker, 15 | build_simulation_callbacks, 16 | ) 17 | from nuplan.planning.script.utils import ( 18 | run_runners, 19 | set_default_path, 20 | set_up_common_builder, 21 | ) 22 | from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner 23 | from omegaconf import DictConfig, OmegaConf 24 | 25 | logging.basicConfig(level=logging.INFO) 26 | logger = logging.getLogger(__name__) 27 | 28 | # If set, use the env. variable to overwrite the default dataset and experiment paths 29 | set_default_path() 30 | 31 | # If set, use the env. variable to overwrite the Hydra config 32 | CONFIG_PATH = os.getenv("NUPLAN_HYDRA_CONFIG_PATH", "config/simulation") 33 | 34 | 35 | def print_simulation_results(file=None): 36 | if file is not None: 37 | df = pd.read_parquet(file) 38 | else: 39 | root = Path(os.getcwd()) / "aggregator_metric" 40 | result = list(root.glob("*.parquet")) 41 | result = max(result, key=lambda item: item.stat().st_ctime) 42 | df = pd.read_parquet(result) 43 | final_score = df[df["scenario"] == "final_score"] 44 | final_score = final_score.to_dict(orient="records")[0] 45 | pprint.PrettyPrinter(indent=4).pprint(final_score) 46 | 47 | 48 | def run_simulation( 49 | cfg: DictConfig, 50 | planners: Optional[Union[AbstractPlanner, List[AbstractPlanner]]] = None, 51 | ) -> None: 52 | """ 53 | Execute all available challenges simultaneously on the same scenario. Helper function for main to allow planner to 54 | be specified via config or directly passed as argument. 55 | :param cfg: Configuration that is used to run the experiment. 56 | Already contains the changes merged from the experiment's config to default config. 57 | :param planners: Pre-built planner(s) to run in simulation. Can either be a single planner or list of planners. 58 | """ 59 | # Fix random seed 60 | pl.seed_everything(cfg.seed, workers=True) 61 | 62 | profiler_name = "building_simulation" 63 | common_builder = set_up_common_builder(cfg=cfg, profiler_name=profiler_name) 64 | 65 | # Build simulation callbacks 66 | callbacks_worker_pool = build_callbacks_worker(cfg) 67 | callbacks = build_simulation_callbacks( 68 | cfg=cfg, output_dir=common_builder.output_dir, worker=callbacks_worker_pool 69 | ) 70 | 71 | # Remove planner from config to make sure run_simulation does not receive multiple planner specifications. 72 | if planners and "planner" in cfg.keys(): 73 | logger.info("Using pre-instantiated planner. 
Ignoring planner in config")
74 |         OmegaConf.set_struct(cfg, False)
75 |         cfg.pop("planner")
76 |         OmegaConf.set_struct(cfg, True)
77 | 
78 |     # Construct simulations
79 |     if isinstance(planners, AbstractPlanner):
80 |         planners = [planners]
81 | 
82 |     runners = build_simulations(
83 |         cfg=cfg,
84 |         callbacks=callbacks,
85 |         worker=common_builder.worker,
86 |         pre_built_planners=planners,
87 |         callbacks_worker=callbacks_worker_pool,
88 |     )
89 | 
90 |     if common_builder.profiler:
91 |         # Stop simulation construction profiling
92 |         common_builder.profiler.save_profiler(profiler_name)
93 | 
94 |     logger.info("Running simulation...")
95 |     run_runners(
96 |         runners=runners,
97 |         common_builder=common_builder,
98 |         cfg=cfg,
99 |         profiler_name="running_simulation",
100 |     )
101 |     logger.info("Finished running simulation!")
102 | 
103 | 
104 | def clean_up_s3_artifacts() -> None:
105 |     """
106 |     Cleanup lingering s3 artifacts that are written locally.
107 |     This happens because some minor write-to-s3 functionality isn't yet implemented.
108 |     """
109 |     # Lingering artifacts get written locally to a 's3:' directory. Hydra changes
110 |     # the working directory to a subdirectory of this, so we search the working
111 |     # path for it.
112 |     working_path = os.getcwd()
113 |     s3_dirname = "s3:"
114 |     s3_ind = working_path.find(s3_dirname)
115 |     if s3_ind != -1:
116 |         local_s3_path = working_path[: working_path.find(s3_dirname) + len(s3_dirname)]
117 |         rmtree(local_s3_path)
118 | 
119 | 
120 | @hydra.main(config_path="./config", config_name="default_simulation")
121 | def main(cfg: DictConfig) -> None:
122 |     """
123 |     Execute all available challenges simultaneously on the same scenario. Calls run_simulation to allow planner to
124 |     be specified via config or directly passed as argument.
125 |     :param cfg: Configuration that is used to run the experiment.
126 |         Already contains the changes merged from the experiment's config to default config.
127 |     """
128 |     assert (
129 |         cfg.simulation_log_main_path is None
130 |     ), "Simulation_log_main_path must not be set when running simulation."
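    # Note: besides this Hydra entry point, run_simulation() also accepts pre-built
    # planners. A minimal sketch, assuming `model` is an instantiated PlanningModel:
    #
    #     planner = ImitationPlanner(planner=model, planner_ckpt=None)
    #     run_simulation(cfg, planners=planner)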
131 | 132 | run_simulation(cfg=cfg) 133 | 134 | if is_s3_path(Path(cfg.output_dir)): 135 | clean_up_s3_artifacts() 136 | 137 | print_simulation_results() 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /src/models/planTF/planning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling 4 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper 5 | from nuplan.planning.training.preprocessing.target_builders.ego_trajectory_target_builder import ( 6 | EgoTrajectoryTargetBuilder, 7 | ) 8 | 9 | from src.feature_builders.nuplan_feature_builder import NuplanFeatureBuilder 10 | 11 | from .layers.common_layers import build_mlp 12 | from .layers.transformer_encoder_layer import TransformerEncoderLayer 13 | from .modules.agent_encoder import AgentEncoder 14 | from .modules.map_encoder import MapEncoder 15 | from .modules.trajectory_decoder import TrajectoryDecoder 16 | 17 | # no meaning, required by nuplan 18 | trajectory_sampling = TrajectorySampling(num_poses=8, time_horizon=8, interval_length=1) 19 | 20 | 21 | class PlanningModel(TorchModuleWrapper): 22 | def __init__( 23 | self, 24 | dim=128, 25 | state_channel=6, 26 | polygon_channel=6, 27 | history_channel=9, 28 | history_steps=21, 29 | future_steps=80, 30 | encoder_depth=4, 31 | drop_path=0.2, 32 | num_heads=8, 33 | num_modes=6, 34 | use_ego_history=False, 35 | state_attn_encoder=True, 36 | state_dropout=0.75, 37 | feature_builder: NuplanFeatureBuilder = NuplanFeatureBuilder(), 38 | ) -> None: 39 | super().__init__( 40 | feature_builders=[feature_builder], 41 | target_builders=[EgoTrajectoryTargetBuilder(trajectory_sampling)], 42 | future_trajectory_sampling=trajectory_sampling, 43 | ) 44 | 45 | self.dim = dim 46 | self.history_steps = history_steps 47 | self.future_steps = future_steps 48 | 49 | self.pos_emb = build_mlp(4, [dim] * 2) 50 | self.agent_encoder = AgentEncoder( 51 | state_channel=state_channel, 52 | history_channel=history_channel, 53 | dim=dim, 54 | hist_steps=history_steps, 55 | drop_path=drop_path, 56 | use_ego_history=use_ego_history, 57 | state_attn_encoder=state_attn_encoder, 58 | state_dropout=state_dropout, 59 | ) 60 | 61 | self.map_encoder = MapEncoder( 62 | dim=dim, 63 | polygon_channel=polygon_channel, 64 | ) 65 | 66 | self.encoder_blocks = nn.ModuleList( 67 | TransformerEncoderLayer(dim=dim, num_heads=num_heads, drop_path=dp) 68 | for dp in [x.item() for x in torch.linspace(0, drop_path, encoder_depth)] 69 | ) 70 | self.norm = nn.LayerNorm(dim) 71 | 72 | self.trajectory_decoder = TrajectoryDecoder( 73 | embed_dim=dim, 74 | num_modes=num_modes, 75 | future_steps=future_steps, 76 | out_channels=4, 77 | ) 78 | self.agent_predictor = build_mlp(dim, [dim * 2, future_steps * 2], norm="ln") 79 | 80 | self.apply(self._init_weights) 81 | 82 | def _init_weights(self, m): 83 | if isinstance(m, nn.Linear): 84 | torch.nn.init.xavier_uniform_(m.weight) 85 | if isinstance(m, nn.Linear) and m.bias is not None: 86 | nn.init.constant_(m.bias, 0) 87 | elif isinstance(m, nn.LayerNorm): 88 | nn.init.constant_(m.bias, 0) 89 | nn.init.constant_(m.weight, 1.0) 90 | elif isinstance(m, nn.BatchNorm1d): 91 | nn.init.ones_(m.weight) 92 | nn.init.zeros_(m.bias) 93 | elif isinstance(m, nn.Embedding): 94 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 95 | 
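    # Forward pass overview: agent tokens (ego at index 0) and map-polygon tokens
    # are concatenated into a single sequence, offset by a positional embedding
    # built from each element's (x, y, cos(heading), sin(heading)), and processed
    # by the shared Transformer encoder blocks; the ego token decodes multi-modal
    # trajectories while the remaining agent tokens decode single-mode predictions.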
96 |     def forward(self, data):
97 |         agent_pos = data["agent"]["position"][:, :, self.history_steps - 1]
98 |         agent_heading = data["agent"]["heading"][:, :, self.history_steps - 1]
99 |         agent_mask = data["agent"]["valid_mask"][:, :, : self.history_steps]
100 |         polygon_center = data["map"]["polygon_center"]
101 |         polygon_mask = data["map"]["valid_mask"]
102 | 
103 |         bs, A = agent_pos.shape[0:2]
104 | 
105 |         position = torch.cat([agent_pos, polygon_center[..., :2]], dim=1)
106 |         angle = torch.cat([agent_heading, polygon_center[..., 2]], dim=1)
107 |         pos = torch.cat(
108 |             [position, torch.stack([angle.cos(), angle.sin()], dim=-1)], dim=-1
109 |         )
110 |         pos_embed = self.pos_emb(pos)
111 | 
112 |         agent_key_padding = ~(agent_mask.any(-1))
113 |         polygon_key_padding = ~(polygon_mask.any(-1))
114 |         key_padding_mask = torch.cat([agent_key_padding, polygon_key_padding], dim=-1)
115 | 
116 |         x_agent = self.agent_encoder(data)
117 |         x_polygon = self.map_encoder(data)
118 | 
119 |         x = torch.cat([x_agent, x_polygon], dim=1) + pos_embed
120 | 
121 |         for blk in self.encoder_blocks:
122 |             x = blk(x, key_padding_mask=key_padding_mask)
123 |         x = self.norm(x)
124 | 
125 |         trajectory, probability = self.trajectory_decoder(x[:, 0])
126 |         prediction = self.agent_predictor(x[:, 1:A]).view(bs, -1, self.future_steps, 2)
127 | 
128 |         out = {
129 |             "trajectory": trajectory,
130 |             "probability": probability,
131 |             "prediction": prediction,
132 |         }
133 | 
134 |         if not self.training:
135 |             best_mode = probability.argmax(dim=-1)
136 |             output_trajectory = trajectory[torch.arange(bs), best_mode]
137 |             angle = torch.atan2(output_trajectory[..., 3], output_trajectory[..., 2])
138 |             out["output_trajectory"] = torch.cat(
139 |                 [output_trajectory[..., :2], angle.unsqueeze(-1)], dim=-1
140 |             )
141 | 
142 |         return out
143 | 
--------------------------------------------------------------------------------
/src/feature_builders/common/bfs_roadblock.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from typing import Dict, Optional, Tuple, Union, List
3 | 
4 | from nuplan.common.maps.abstract_map import AbstractMap
5 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject
6 | 
7 | 
8 | class BreadthFirstSearchRoadBlock:
9 |     """
10 |     A class that performs iterative breadth first search. The class operates on the roadblock graph.
11 |     """
12 | 
13 |     def __init__(
14 |         self, start_roadblock_id: str, map_api: Optional[AbstractMap], forward_search: bool = True
15 |     ):
16 |         """
17 |         Constructor of BreadthFirstSearchRoadBlock class
18 |         :param start_roadblock_id: roadblock id where graph starts
19 |         :param map_api: map class in nuPlan
20 |         :param forward_search: whether to search in driving direction, defaults to True
21 |         """
22 |         self._map_api: Optional[AbstractMap] = map_api
23 |         self._queue = deque([self.id_to_roadblock(start_roadblock_id), None])
24 |         self._parent: Dict[str, Optional[RoadBlockGraphEdgeMapObject]] = dict()
25 |         self._forward_search = forward_search
26 | 
27 |         # lazy loaded
28 |         self._target_roadblock_ids: Optional[List[str]] = None
29 | 
30 |     def search(
31 |         self, target_roadblock_id: Union[str, List[str]], max_depth: int
32 |     ) -> Tuple[Tuple[List[RoadBlockGraphEdgeMapObject], List[str]], bool]:
33 |         """
34 |         Apply BFS to find route to target roadblock.
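        The queue is processed level by level: a None sentinel marks the end of each
        depth level, and parent edges are stored under "{roadblock_id}_{depth}" keys
        so that the route can be reconstructed backwards once the target is reached.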
35 |         :param target_roadblock_id: id of target roadblock
36 |         :param max_depth: maximum search depth
37 |         :return: tuple of (route, route roadblock ids) and whether a path was found
38 |         """
39 | 
40 |         if isinstance(target_roadblock_id, str):
41 |             target_roadblock_id = [target_roadblock_id]
42 |         self._target_roadblock_ids = target_roadblock_id
43 | 
44 |         start_edge = self._queue[0]
45 | 
46 |         # Initial search states
47 |         path_found: bool = False
48 |         end_edge: RoadBlockGraphEdgeMapObject = start_edge
49 |         end_depth: int = 1
50 |         depth: int = 1
51 | 
52 |         self._parent[start_edge.id + f"_{depth}"] = None
53 | 
54 |         while self._queue:
55 |             current_edge = self._queue.popleft()
56 | 
57 |             # Early exit condition
58 |             if self._check_end_condition(depth, max_depth):
59 |                 break
60 | 
61 |             # Depth tracking
62 |             if current_edge is None:
63 |                 depth += 1
64 |                 self._queue.append(None)
65 |                 if self._queue[0] is None:
66 |                     break
67 |                 continue
68 | 
69 |             # Goal condition
70 |             if self._check_goal_condition(current_edge, depth, max_depth):
71 |                 end_edge = current_edge
72 |                 end_depth = depth
73 |                 path_found = True
74 |                 break
75 | 
76 |             neighbors = (
77 |                 current_edge.outgoing_edges if self._forward_search else current_edge.incoming_edges
78 |             )
79 | 
80 |             # Populate queue
81 |             for next_edge in neighbors:
82 | 
83 |                 self._queue.append(next_edge)
84 |                 self._parent[next_edge.id + f"_{depth + 1}"] = current_edge
85 |                 end_edge = next_edge
86 |                 end_depth = depth + 1
87 | 
88 |         return self._construct_path(end_edge, end_depth), path_found
89 | 
90 |     def id_to_roadblock(self, id: str) -> RoadBlockGraphEdgeMapObject:
91 |         """
92 |         Retrieves roadblock from map-api based on id
93 |         :param id: id of roadblock
94 |         :return: roadblock class
95 |         """
96 |         block = self._map_api._get_roadblock(id)
97 |         block = block or self._map_api._get_roadblock_connector(id)
98 |         return block
99 | 
100 |     @staticmethod
101 |     def _check_end_condition(depth: int, max_depth: int) -> bool:
102 |         """
103 |         Check if the search should end regardless of whether the goal condition is met.
104 |         :param depth: The current depth to check.
105 |         :param max_depth: The maximum depth to check against.
106 |         :return: whether the current depth exceeds the maximum depth.
107 |         """
108 |         return depth > max_depth
109 | 
110 |     def _check_goal_condition(
111 |         self,
112 |         current_edge: RoadBlockGraphEdgeMapObject,
113 |         depth: int,
114 |         max_depth: int,
115 |     ) -> bool:
116 |         """
117 |         Check if the current edge is at the target roadblock at the given depth.
118 |         :param current_edge: edge to check.
119 |         :param depth: current depth to check.
120 |         :param max_depth: maximum depth the edge should be at.
121 |         :return: True if the current edge is contained in the target roadblocks, False otherwise.
122 |         """
123 |         return current_edge.id in self._target_roadblock_ids and depth <= max_depth
124 | 
125 |     def _construct_path(
126 |         self, end_edge: RoadBlockGraphEdgeMapObject, depth: int
127 |     ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]:
128 |         """
129 |         Constructs a path when goal was found.
130 |         :param end_edge: The end edge from which to backtrack to the start edge.
131 |         :param depth: The depth of the target edge.
132 |         :return: tuple of the constructed path and the corresponding roadblock ids
133 |         """
134 |         path = [end_edge]
135 |         path_id = [end_edge.id]
136 | 
137 |         while self._parent[end_edge.id + f"_{depth}"] is not None:
138 |             path.append(self._parent[end_edge.id + f"_{depth}"])
139 |             path_id.append(path[-1].id)
140 |             end_edge = self._parent[end_edge.id + f"_{depth}"]
141 |             depth -= 1
142 | 
143 |         if self._forward_search:
144 |             path.reverse()
145 |             path_id.reverse()
146 | 
147 |         return (path, path_id)
--------------------------------------------------------------------------------
/src/data_augmentation/state_perturbation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional, Tuple, cast
3 | 
4 | import numpy as np
5 | import numpy.typing as npt
6 | import torch
7 | from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters
8 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
9 | from nuplan.planning.training.data_augmentation.abstract_data_augmentation import (
10 |     AbstractAugmentor,
11 | )
12 | from nuplan.planning.training.data_augmentation.data_augmentation_util import (
13 |     ParameterToScale,
14 |     ScalingDirection,
15 |     UniformNoise,
16 | )
17 | from nuplan.planning.training.modeling.types import FeaturesType, TargetsType
18 | 
19 | from src.features.nuplan_feature import NuplanFeature
20 | from src.utils.collision_checker import CollisionChecker
21 | 
22 | 
23 | logger = logging.getLogger(__name__)
24 | 
25 | 
26 | class StatePerturbation(AbstractAugmentor):
27 |     """
28 |     Data augmentation that perturbs the current ego position and generates a feasible trajectory history that
29 |     satisfies a set of kinematic constraints.
30 | 
31 |     This involves constrained minimization of the following objective:
32 |     * minimize dist(perturbed_trajectory, ground_truth_trajectory)
33 | 
34 | 
35 |     Uniform noise is drawn from [low, high] and added to the current ego state; samples that would collide with nearby agents are rejected.
36 |     """
37 | 
38 |     def __init__(
39 |         self,
40 |         dt: float = 0.1,
41 |         hist_len: int = 21,
42 |         low: List[float] = [0.0, -1.5, -0.55, -1, -0.5, -0.2, -0.2],
43 |         high: List[float] = [2.0, 1.5, 0.55, 1, 0.5, 0.2, 0.2],
44 |         augment_prob: float = 0.5,
45 |         normalize=True,
46 |     ) -> None:
47 |         """
48 |         Initialize the augmentor.
49 |         state: [x, y, yaw, vel, acc, steer, steer_rate, angular_vel, angular_acc]
50 |         :param dt: Time interval between trajectory points.
51 |         :param hist_len: Number of history steps in the feature.
52 |         :param low: Lower bound vector of the uniform noise added to the current ego state.
53 |         :param high: Upper bound vector of the uniform noise added to the current ego state.
54 |         :param augment_prob: Probability between 0 and 1 of applying the data augmentation.
55 |         """
56 |         self._dt = dt
57 |         self._hist_len = hist_len
58 |         self._random_offset_generator = UniformNoise(low, high)
59 |         self._augment_prob = augment_prob
60 |         self._normalize = normalize
61 |         self._collision_checker = CollisionChecker()
62 |         self._rear_to_cog = get_pacifica_parameters().rear_axle_to_center
63 | 
64 |     def safety_check(
65 |         self,
66 |         ego_position: npt.NDArray[np.float32],
67 |         ego_heading: npt.NDArray[np.float32],
68 |         agents_position: npt.NDArray[np.float32],
69 |         agents_heading: npt.NDArray[np.float32],
70 |         agents_shape: npt.NDArray[np.float32],
71 |     ) -> bool:
72 |         if len(agents_position) == 0:
73 |             return True
74 | 
75 |         ego_center = (
76 |             ego_position
77 |             + np.stack([np.cos(ego_heading), np.sin(ego_heading)], axis=-1)
78 |             * self._rear_to_cog
79 |         )
80 |         ego_state = torch.from_numpy(
81 |             np.concatenate([ego_center, [ego_heading]], axis=-1)
82 |         ).unsqueeze(0)
83 |         objects_state = torch.from_numpy(
84 |             np.concatenate([agents_position, agents_heading[..., None]], axis=-1)
85 |         ).unsqueeze(0)
86 | 
87 |         collisions = self._collision_checker.collision_check(
88 |             ego_state=ego_state,
89 |             objects=objects_state,
90 |             objects_width=torch.from_numpy(agents_shape[:, 0]).unsqueeze(0),
91 |             objects_length=torch.from_numpy(agents_shape[:, 1]).unsqueeze(0),
92 |         )
93 | 
94 |         return not collisions.any()
95 | 
96 |     def augment(
97 |         self,
98 |         features: FeaturesType,
99 |         targets: TargetsType = None,
100 |         scenario: Optional[AbstractScenario] = None,
101 |     ) -> Tuple[FeaturesType, TargetsType]:
102 |         """Inherited, see superclass."""
103 |         if np.random.rand() >= self._augment_prob:
104 |             return features, targets
105 | 
106 |         data = features["feature"].data
107 | 
108 |         current_state = data["current_state"]
109 |         new_state = current_state + self._random_offset_generator.sample()
110 |         new_state[3] = max(0.0, new_state[3])
111 | 
112 |         # consider the nearest 10 agents
113 |         agents_position = data["agent"]["position"][1:11, self._hist_len - 1]
114 |         agents_shape = data["agent"]["shape"][1:11, self._hist_len - 1]
115 |         agents_heading = data["agent"]["heading"][1:11, self._hist_len - 1]
116 | 
117 | 
118 |         if not self.safety_check(
119 |             ego_position=new_state[:2],
120 |             ego_heading=new_state[2],
121 |             agents_position=agents_position,
122 |             agents_heading=agents_heading,
123 |             agents_shape=agents_shape,
124 |         ):
125 |             return features, targets
126 | 
127 |         data["current_state"] = new_state
128 |         data["agent"]["position"][0, self._hist_len - 1] = new_state[:2]
129 |         data["agent"]["heading"][0, self._hist_len - 1] = new_state[2]
130 | 
131 |         if self._normalize:
132 |             features["feature"] = NuplanFeature.normalize(data)
133 | 
134 |         return features, targets
135 | 
136 |     @property
137 |     def required_features(self) -> List[str]:
138 |         """Inherited, see superclass."""
139 |         return []
140 | 
141 |     @property
142 |     def required_targets(self) -> List[str]:
143 |         """Inherited, see superclass."""
144 |         return []
145 | 
146 |     @property
147 |     def augmentation_probability(self) -> ParameterToScale:
148 |         """Inherited, see superclass."""
149 |         return ParameterToScale(
150 |             param=self._augment_prob,
151 |             param_name=f"{self._augment_prob=}".partition("=")[0].split(".")[1],
152 |             scaling_direction=ScalingDirection.MAX,
153 |         )
154 | 
155 |     @property
156 |     def get_schedulable_attributes(self) -> List[ParameterToScale]:
157 |         """Inherited, see superclass."""
158 |         return cast(
159 |             List[ParameterToScale],
160 |             self._random_offset_generator.get_schedulable_attributes(),
161 |         )
162 | 
--------------------------------------------------------------------------------
/src/planners/imitation_planner.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import List, Optional, Type
3 | 
4 | import numpy as np
5 | import torch
6 | from nuplan.common.actor_state.ego_state import EgoState
7 | from nuplan.planning.simulation.observation.observation_type import (
8 |     DetectionsTracks,
9 |     Observation,
10 | )
11 | from nuplan.planning.simulation.planner.abstract_planner import (
12 |     AbstractPlanner,
13 |     PlannerInitialization,
14 |     PlannerInput,
15 |     PlannerReport,
16 | )
17 | from nuplan.planning.simulation.planner.planner_report import MLPlannerReport
18 | from nuplan.planning.simulation.trajectory.abstract_trajectory import AbstractTrajectory
19 | from nuplan.planning.simulation.trajectory.interpolated_trajectory import (
20 |     InterpolatedTrajectory,
21 | )
22 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper
23 | 
24 | from src.feature_builders.common.utils import rotate_round_z_axis
25 | 
26 | from .planner_utils import global_trajectory_to_states, load_checkpoint
27 | 
28 | 
29 | class ImitationPlanner(AbstractPlanner):
30 |     """
31 |     Long-term IL-based trajectory planner, with short-term RL-based trajectory tracker.
32 |     """
33 | 
34 |     requires_scenario: bool = False
35 | 
36 |     def __init__(
37 |         self,
38 |         planner: TorchModuleWrapper,
39 |         planner_ckpt: str = None,
40 |         replan_interval: int = 1,
41 |         use_gpu: bool = True,
42 |     ) -> None:
43 |         """
44 |         Initializes the ML planner class.
45 |         :param planner: Model to use for inference.
46 |         """
47 |         if use_gpu:
48 |             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49 |         else:
50 |             self.device = torch.device("cpu")
51 | 
52 |         self._planner = planner
53 |         self._planner_feature_builder = planner.get_list_of_required_feature()[0]
54 |         self._planner_ckpt = planner_ckpt
55 |         self._initialization: Optional[PlannerInitialization] = None
56 | 
57 |         self._future_horizon = 8.0
58 |         self._step_interval = 0.1
59 | 
60 |         self._replan_interval = replan_interval
61 |         self._last_plan_elapsed_step = replan_interval  # force plan at first step
62 |         self._global_trajectory = None
63 |         self._start_time = None
64 | 
65 |         self._compute_trajectory_runtimes: List[float] = []  # runtime stats for the MLPlannerReport
66 |         self._feature_building_runtimes: List[float] = []
67 |         self._inference_runtimes: List[float] = []
68 | 
69 |     def initialize(self, initialization: PlannerInitialization) -> None:
70 |         """Inherited, see superclass."""
71 |         torch.set_grad_enabled(False)
72 | 
73 |         if self._planner_ckpt is not None:
74 |             self._planner.load_state_dict(load_checkpoint(self._planner_ckpt))
75 | 
76 |         self._planner.eval()
77 |         self._planner = self._planner.to(self.device)
78 |         self._initialization = initialization
79 | 
80 |         # just to trigger numba compilation; has no actual effect
81 |         rotate_round_z_axis(np.zeros((1, 2), dtype=np.float64), float(0.0))
82 | 
83 |     def name(self) -> str:
84 |         """Inherited, see superclass."""
85 |         return self.__class__.__name__
86 | 
87 |     def observation_type(self) -> Type[Observation]:
88 |         """Inherited, see superclass."""
89 |         return DetectionsTracks  # type: ignore
90 | 
91 |     def _planning(self, current_input: PlannerInput):
92 |         self._start_time = time.perf_counter()
93 |         planner_feature = self._planner_feature_builder.get_features_from_simulation(
94 |             current_input, self._initialization
95 |         )
96 | 
planner_feature_torch = planner_feature.collate( 97 | [planner_feature.to_feature_tensor().to_device(self.device)] 98 | ) 99 | self._feature_building_runtimes.append(time.perf_counter() - self._start_time) 100 | 101 | out = self._planner.forward(planner_feature_torch.data) 102 | local_trajectory = out["output_trajectory"][0].cpu().numpy() 103 | 104 | return local_trajectory.astype(np.float64) 105 | 106 | def compute_planner_trajectory( 107 | self, current_input: PlannerInput 108 | ) -> AbstractTrajectory: 109 | """ 110 | Infer relative trajectory poses from model and convert to absolute agent states wrapped in a trajectory. 111 | Inherited, see superclass. 112 | """ 113 | ego_state = current_input.history.ego_states[-1] 114 | 115 | if self._last_plan_elapsed_step >= self._replan_interval: 116 | local_trajectory = self._planning(current_input) 117 | self._global_trajectory = self._get_global_trajectory( 118 | local_trajectory, ego_state 119 | ) 120 | self._last_plan_elapsed_step = 0 121 | else: 122 | self._global_trajectory = self._global_trajectory[1:] 123 | 124 | trajectory = InterpolatedTrajectory( 125 | trajectory=global_trajectory_to_states( 126 | global_trajectory=self._global_trajectory, 127 | ego_history=current_input.history.ego_states, 128 | future_horizon=len(self._global_trajectory) * self._step_interval, 129 | step_interval=self._step_interval, 130 | ) 131 | ) 132 | 133 | self._inference_runtimes.append(time.perf_counter() - self._start_time) 134 | 135 | self._last_plan_elapsed_step += 1 136 | 137 | return trajectory 138 | 139 | def generate_planner_report(self, clear_stats: bool = True) -> PlannerReport: 140 | """Inherited, see superclass.""" 141 | report = MLPlannerReport( 142 | compute_trajectory_runtimes=self._compute_trajectory_runtimes, 143 | feature_building_runtimes=self._feature_building_runtimes, 144 | inference_runtimes=self._inference_runtimes, 145 | ) 146 | if clear_stats: 147 | self._compute_trajectory_runtimes: List[float] = [] 148 | self._feature_building_runtimes = [] 149 | self._inference_runtimes = [] 150 | 151 | return report 152 | 153 | def _get_global_trajectory(self, local_trajectory: np.ndarray, ego_state: EgoState): 154 | origin = ego_state.rear_axle.array 155 | angle = ego_state.rear_axle.heading 156 | 157 | global_position = ( 158 | rotate_round_z_axis(np.ascontiguousarray(local_trajectory[..., :2]), -angle) 159 | + origin 160 | ) 161 | global_heading = local_trajectory[..., 2] + angle 162 | 163 | global_trajectory = np.concatenate( 164 | [global_position, global_heading[..., None]], axis=1 165 | ) 166 | 167 | return global_trajectory 168 | -------------------------------------------------------------------------------- /config/scenario_filter/test14-hard.yaml: -------------------------------------------------------------------------------- 1 | _target_: nuplan.planning.scenario_builder.scenario_filter.ScenarioFilter 2 | _convert_: all 3 | scenario_types: 4 | - behind_long_vehicle 5 | - changing_lane 6 | - following_lane_with_lead 7 | - high_lateral_acceleration 8 | - high_magnitude_speed 9 | - low_magnitude_speed 10 | - near_multiple_vehicles 11 | - starting_left_turn 12 | - starting_right_turn 13 | - starting_straight_traffic_light_intersection_traversal 14 | - stationary_in_traffic 15 | - stopping_with_lead 16 | - traversing_pickup_dropoff 17 | - waiting_for_pedestrian_to_cross 18 | scenario_tokens: 19 | - a36c1b943871552a 20 | - "88aa2ad613205556" 21 | - "6db7f9f43c655149" 22 | - "45aa9e8713fa5bee" 23 | - "577d4f456cd65460" 24 | - ac286967ed895963 
25 | - "4836b2dd09895237" 26 | - f3702a1cc1cb5c64 27 | - "660d375c109f5eed" 28 | - fa6b31fc16f251c9 29 | - "96c975c46cac5a49" 30 | - "8a29aecff22b5657" 31 | - "990551bed2555351" 32 | - "014ad27ed9da5b86" 33 | - e7b473cea10954cb 34 | - beba883bb6285cee 35 | - a10ebe68a57c5dda 36 | - "2754fbd9c2445dce" 37 | - "48d3ee048cff55d6" 38 | - "03d25c49fbc6550e" 39 | - "5d392fa38ff65c3f" 40 | - "59199b4d340558b4" 41 | - d8a23f0cb78e5938 42 | - "9ef2cc51be4c51ed" 43 | - be2efbeede795568 44 | - "06533b6b947357f7" 45 | - b87bd5af8ff75bc5 46 | - cb0c11da557650db 47 | - b6f48c3ca32750d5 48 | - "96074214b9645952" 49 | - "2e56997063c057a0" 50 | - "217131c79adf588d" 51 | - "7995b9b53ef55e80" 52 | - "4236a1f3c9e35f9d" 53 | - be8d080bf7cf5835 54 | - a61d56930e9b5f36 55 | - "094db8d50cf95250" 56 | - "442523fcce0950e8" 57 | - f7700356e6d85410 58 | - c6277e74a4f054a7 59 | - "0d3490827eb45df7" 60 | - "5a79cd2161c754c9" 61 | - cf87d58064425ee7 62 | - "1c17b9dad1f65970" 63 | - "93c583b46398560e" 64 | - "4e8af2b28cea5133" 65 | - ff4d69f4dd0c5474 66 | - "74914e4a95025d59" 67 | - e8ec64ceb6d050ff 68 | - c0c385cbdd47536b 69 | - "938b0223319957ef" 70 | - d13124fa683654b4 71 | - "32aea40a777c5156" 72 | - "35c3787707ce5deb" 73 | - "5bebf2e252ec5367" 74 | - "6ff0d0bb90d852b9" 75 | - "75025613e3935595" 76 | - "7c34e3a807965fc1" 77 | - "752db4af7e2754e7" 78 | - a2fc81fb19985d60 79 | - "6d0875dbcec45b3a" 80 | - "75d56d4c7b6f5013" 81 | - "9a92b5bf8d735034" 82 | - "68e943391ab65773" 83 | - "63352660f02f5e3d" 84 | - "295432d96052578e" 85 | - "8c8e9b7de9bf541a" 86 | - "3e673b48bbf15727" 87 | - efa032e2bf055c96 88 | - "78fb87b881535792" 89 | - d10a8714a31b58b4 90 | - "45df14109a3a5937" 91 | - d07309ff567b5953 92 | - "3f9e713baeec55a9" 93 | - "11f773b0b19055e9" 94 | - "3a0b1987ce79508f" 95 | - a22b2b1e04595a0e 96 | - ae6691cb57175db6 97 | - "0a2be0a8f9c75775" 98 | - a9407c1ea38959d5 99 | - "88ac8fc4c8b25a22" 100 | - "070345fd69165e11" 101 | - "90a02e90433d5137" 102 | - "9c7f7922e3225dec" 103 | - cb57c23a2cc05786 104 | - "9c711bed4f175f6d" 105 | - "57441c7185bc5b30" 106 | - "56cdd36746c85788" 107 | - c5171ce6d7f05d80 108 | - "8629ab5781b853d3" 109 | - "420653bda4b2575b" 110 | - edc971e9dc7f5165 111 | - d173d3a1c8fb5e8a 112 | - "2cf6ac2d997a5dff" 113 | - "35265fcdd0be5579" 114 | - c136ba64909b5ffe 115 | - "69e91c91e4e65848" 116 | - "840f360aea765a57" 117 | - "88a42c466d2b5ebd" 118 | - ffbae8b71907545c 119 | - "15d232cebce05616" 120 | - be2029b5dc2c5b78 121 | - fd1a947d104956f8 122 | - "63aa10eaed9f53c5" 123 | - b1cef3eeb5445447 124 | - c24af34deb6f54bf 125 | - "33187fb09d0e52f8" 126 | - "2c5af3c1152c5e69" 127 | - "37947d83063255ea" 128 | - "698cc78af2d154d2" 129 | - ff02a16ae6c95a42 130 | - "3454842a96ef511e" 131 | - "61380daa1b275ce0" 132 | - "4f3e807c698a5335" 133 | - "4fde9018255c5f96" 134 | - "6bdb5f343c355f98" 135 | - c9f1e5d3c8325ac4 136 | - "2b34e2a1fc6b543c" 137 | - d63776b80fbd5d4b 138 | - "56022a77c9fd5a1e" 139 | - "91c272aacfc6511b" 140 | - "904819e474565d84" 141 | - "7f7cd65cacff5f22" 142 | - "1796b59d5f16581f" 143 | - e0834d617d8d5453 144 | - "18a8fe4997e053fe" 145 | - "9891d19fd245546a" 146 | - c88e29a426805b27 147 | - "71544041acf754bc" 148 | - "72a3637d375155d9" 149 | - "763bb5fb88d2556f" 150 | - "0586b3fb1ffb5fba" 151 | - "0cd6cff135c25fa9" 152 | - "577c248204055c02" 153 | - "1075ba60960c564e" 154 | - "044a3924d9e15bc1" 155 | - "60ad86e2dbf156b5" 156 | - "40e5fc0034035139" 157 | - "067d36d6568657f4" 158 | - "51b40354373a5fac" 159 | - "8e355342a0115145" 160 | - cbdc3e2a73e15687 161 | - eb2ac6cc24f55f6e 162 | - 
d3be3fa904135dd9 163 | - "338ca19e3fde5381" 164 | - e10289f39c5f5bea 165 | - "70ee68fbecd354b0" 166 | - "7c64bfeda7ca5f22" 167 | - "96b881f39d3a55f2" 168 | - "914f752db44d5143" 169 | - afd9603706985285 170 | - "88431143abe45625" 171 | - b703d99845f45666 172 | - c921e9258bdf5c0c 173 | - "66f5844016ec5134" 174 | - ba267dbeb0c853b5 175 | - "2b3f110c96995dd1" 176 | - ae0ace575c775f82 177 | - de3cc0dd7ac45a65 178 | - "230a1a270a105821" 179 | - c2c649fcd5325b3d 180 | - "6f2ce5e5530b5e3c" 181 | - "688caa7560c05d03" 182 | - "8de10fd86b825304" 183 | - "61f6e42143cf5eb9" 184 | - "00af7480d144507f" 185 | - cfa05ee317245c9f 186 | - "99ef05a2e6a45454" 187 | - d4edda3ab1d75034 188 | - "0b7bb38b72ff5034" 189 | - "179a1bf034c650ce" 190 | - "643114dca0825b3f" 191 | - edbc6fe0dfbf51f1 192 | - "6fd36832e6925f74" 193 | - "295154d51713573f" 194 | - "40e1f34c92255a0e" 195 | - "1c3d3a5bb86c5d97" 196 | - "6eb38f317f2251f2" 197 | - "0a1adb702c0f5949" 198 | - "003445cf99235331" 199 | - "46236093f497573f" 200 | - e9faae87fb83540d 201 | - "2fad34f49b825d6b" 202 | - "521a9f9977d351da" 203 | - "535516116fa35117" 204 | - d453b49dafae5dd5 205 | - "89b7bd3592505a26" 206 | - "20f2e6f4bf3f5e80" 207 | - "5fa2710018705849" 208 | - f9f0880607a15639 209 | - "9fd5ec2b453d556e" 210 | - "80a05e5b3037536a" 211 | - ca050f5421925415 212 | - "4095fd4af9a45b18" 213 | - "33b448fab23a5d03" 214 | - "5cfbb108eb5f584f" 215 | - "2ddc6c9887915bc9" 216 | - "00de3f6da9205a0c" 217 | - ebf09047e62a5fb8 218 | - "6639a4f7873a56dc" 219 | - "2703f12098405703" 220 | - "26c1cae14e5b5ce8" 221 | - a112ea9f3d2b5d5b 222 | - a1a3d628e53c5d75 223 | - a2ec5056da3c5c67 224 | - "7289414e82da5b00" 225 | - "1e3f5ab092335059" 226 | - f15c28ffb0ea5a7a 227 | - "9b039177b43b5260" 228 | - "7803717ee46c58cc" 229 | - "134e675d2807537f" 230 | - e569500a38a256bf 231 | - "6262500c70275443" 232 | - "0e165b03aae35700" 233 | - b56eda3b0d1c5f67 234 | - "5bc4f584c6325e50" 235 | - "5cb7891d29545bd4" 236 | - "62ec1599159e5af6" 237 | - f67f6478ccf65687 238 | - "0f1513c4f8285ab1" 239 | - "97237e8269415fb0" 240 | - "80ababf8dcfb5914" 241 | - "752323a35a825d22" 242 | - ba4883a772905c39 243 | - b327b4d0432a51a1 244 | - e4bc04cbc7eb5940 245 | - f8c717165ff15ff7 246 | - "01010f1fc37b5321" 247 | - "3150573de84e5ea9" 248 | - c532a0844ee35bde 249 | - df68e9d709e65b1e 250 | - "4570000dd1685c5d" 251 | - a3b20f60c7835df8 252 | - db363d844b3754f1 253 | - "65d4717562ea5d95" 254 | - d8834251f07a597f 255 | - "42ad59cd4b2b5205" 256 | - "698b38a35cc95956" 257 | - f9ac8947d2c55c1b 258 | - "0681f7fa37dd5f63" 259 | - "13f3028945475a79" 260 | - ae6e6bd7567d56bc 261 | - d91351e4615859fb 262 | - "0795bcd734235cf5" 263 | - d908073d216b5e04 264 | - bcf52e9ad12c5ce4 265 | - "62db9dc16ffd5135" 266 | - f352e2c4378b5ea1 267 | - "99e63d494deb5808" 268 | - "3cd758f0d51a55e5" 269 | - "74874f782e725aec" 270 | - "98c66f5373705fd7" 271 | - "0d17d16b86e65700" 272 | - "54da00bc66c7575d" 273 | - d417ec1ee7295c5f 274 | - "0488dbdf03b55f00" 275 | - "9e220701888b5ab2" 276 | - "855985a401ab59ea" 277 | - "16a890626fa9570d" 278 | - daa841dbda985ed0 279 | - "1e502e7fc8745e0a" 280 | - "9914f87b536b5c23" 281 | - c2b222d1f0715c00 282 | - e94b99e48b355de6 283 | - "7d365cae2cd45ad8" 284 | - "8181c455582c5623" 285 | - "9e7b8fa4248d55de" 286 | - "159942fd13675580" 287 | - "820bf3685bc955ef" 288 | - "7dde022f98be574c" 289 | - b8fc7d499e705b68 290 | - bc962c4b185859a6 291 | log_names: null 292 | map_names: null 293 | num_scenarios_per_type: null 294 | limit_total_scenarios: null 295 | timestamp_threshold_s: 15 296 | 
ego_displacement_minimum_m: null 297 | ego_start_speed_threshold: null 298 | ego_stop_speed_threshold: null 299 | speed_noise_tolerance: null 300 | expand_scenarios: null 301 | remove_invalid_goals: true 302 | shuffle: false 303 | -------------------------------------------------------------------------------- /src/models/planTF/layers/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from natten import NeighborhoodAttention1D 5 | from timm.models.layers import DropPath 6 | 7 | 8 | class NATSequenceEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | in_chans=3, 12 | embed_dim=32, 13 | mlp_ratio=3, 14 | kernel_size=[3, 3, 5], 15 | depths=[2, 2, 2], 16 | num_heads=[2, 4, 8], 17 | out_indices=[0, 1, 2], 18 | drop_rate=0.0, 19 | attn_drop_rate=0.0, 20 | drop_path_rate=0.2, 21 | norm_layer=nn.LayerNorm, 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.embed = ConvTokenizer(in_chans, embed_dim) 26 | self.num_levels = len(depths) 27 | self.num_features = [int(embed_dim * 2**i) for i in range(self.num_levels)] 28 | self.out_indices = out_indices 29 | 30 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 31 | self.levels = nn.ModuleList() 32 | for i in range(self.num_levels): 33 | level = NATBlock( 34 | dim=int(embed_dim * 2**i), 35 | depth=depths[i], 36 | num_heads=num_heads[i], 37 | kernel_size=kernel_size[i], 38 | dilations=None, 39 | mlp_ratio=mlp_ratio, 40 | drop=drop_rate, 41 | attn_drop=attn_drop_rate, 42 | drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], 43 | norm_layer=norm_layer, 44 | downsample=(i < self.num_levels - 1), 45 | ) 46 | self.levels.append(level) 47 | 48 | for i_layer in self.out_indices: 49 | layer = norm_layer(self.num_features[i_layer]) 50 | layer_name = f"norm{i_layer}" 51 | self.add_module(layer_name, layer) 52 | 53 | n = self.num_features[-1] 54 | self.lateral_convs = nn.ModuleList() 55 | for i_layer in self.out_indices: 56 | self.lateral_convs.append( 57 | nn.Conv1d(self.num_features[i_layer], n, 3, padding=1) 58 | ) 59 | 60 | self.fpn_conv = nn.Conv1d(n, n, 3, padding=1) 61 | 62 | def forward(self, x): 63 | """x: [B, C, T]""" 64 | x = self.embed(x) 65 | 66 | out = [] 67 | for idx, level in enumerate(self.levels): 68 | x, xo = level(x) 69 | if idx in self.out_indices: 70 | norm_layer = getattr(self, f"norm{idx}") 71 | x_out = norm_layer(xo) 72 | out.append(x_out.permute(0, 2, 1).contiguous()) 73 | 74 | laterals = [ 75 | lateral_conv(out[i]) for i, lateral_conv in enumerate(self.lateral_convs) 76 | ] 77 | for i in range(len(out) - 1, 0, -1): 78 | laterals[i - 1] = laterals[i - 1] + F.interpolate( 79 | laterals[i], 80 | scale_factor=(laterals[i - 1].shape[-1] / laterals[i].shape[-1]), 81 | mode="linear", 82 | align_corners=False, 83 | ) 84 | 85 | out = self.fpn_conv(laterals[0]) 86 | 87 | return out[:, :, -1] 88 | 89 | 90 | class ConvTokenizer(nn.Module): 91 | def __init__(self, in_chans=3, embed_dim=32, norm_layer=None): 92 | super().__init__() 93 | self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=3, stride=1, padding=1) 94 | 95 | if norm_layer is not None: 96 | self.norm = norm_layer(embed_dim) 97 | else: 98 | self.norm = None 99 | 100 | def forward(self, x): 101 | x = self.proj(x).permute(0, 2, 1) # B, C, L -> B, L, C 102 | if self.norm is not None: 103 | x = self.norm(x) 104 | return x 105 | 106 | 107 | class ConvDownsampler(nn.Module): 108 | def __init__(self, dim, norm_layer=nn.LayerNorm): 109 | 
super().__init__() 110 | self.reduction = nn.Conv1d( 111 | dim, 2 * dim, kernel_size=3, stride=2, padding=1, bias=False 112 | ) 113 | self.norm = norm_layer(2 * dim) 114 | 115 | def forward(self, x): 116 | x = self.reduction(x.permute(0, 2, 1)).permute(0, 2, 1) 117 | x = self.norm(x) 118 | return x 119 | 120 | 121 | class Mlp(nn.Module): 122 | def __init__( 123 | self, 124 | in_features, 125 | hidden_features=None, 126 | out_features=None, 127 | act_layer=nn.GELU, 128 | drop=0.0, 129 | ): 130 | super().__init__() 131 | out_features = out_features or in_features 132 | hidden_features = hidden_features or in_features 133 | self.fc1 = nn.Linear(in_features, hidden_features) 134 | self.act = act_layer() 135 | self.fc2 = nn.Linear(hidden_features, out_features) 136 | self.drop = nn.Dropout(drop) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) 140 | x = self.act(x) 141 | x = self.drop(x) 142 | x = self.fc2(x) 143 | x = self.drop(x) 144 | return x 145 | 146 | 147 | class NATLayer(nn.Module): 148 | def __init__( 149 | self, 150 | dim, 151 | num_heads, 152 | kernel_size=7, 153 | dilation=None, 154 | mlp_ratio=4.0, 155 | qkv_bias=True, 156 | qk_scale=None, 157 | drop=0.0, 158 | attn_drop=0.0, 159 | drop_path=0.0, 160 | act_layer=nn.GELU, 161 | norm_layer=nn.LayerNorm, 162 | ): 163 | super().__init__() 164 | self.dim = dim 165 | self.num_heads = num_heads 166 | self.mlp_ratio = mlp_ratio 167 | 168 | self.norm1 = norm_layer(dim) 169 | self.attn = NeighborhoodAttention1D( 170 | dim, 171 | kernel_size=kernel_size, 172 | dilation=dilation, 173 | num_heads=num_heads, 174 | qkv_bias=qkv_bias, 175 | qk_scale=qk_scale, 176 | attn_drop=attn_drop, 177 | proj_drop=drop, 178 | ) 179 | 180 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 181 | self.norm2 = norm_layer(dim) 182 | self.mlp = Mlp( 183 | in_features=dim, 184 | hidden_features=int(dim * mlp_ratio), 185 | act_layer=act_layer, 186 | drop=drop, 187 | ) 188 | 189 | def forward(self, x): 190 | shortcut = x 191 | x = self.norm1(x) 192 | x = self.attn(x) 193 | x = shortcut + self.drop_path(x) 194 | x = x + self.drop_path(self.mlp(self.norm2(x))) 195 | return x 196 | 197 | 198 | class NATBlock(nn.Module): 199 | def __init__( 200 | self, 201 | dim, 202 | depth, 203 | num_heads, 204 | kernel_size, 205 | dilations=None, 206 | downsample=True, 207 | mlp_ratio=4.0, 208 | qkv_bias=True, 209 | qk_scale=None, 210 | drop=0.0, 211 | attn_drop=0.0, 212 | drop_path=0.0, 213 | norm_layer=nn.LayerNorm, 214 | act_layer=nn.GELU, 215 | ): 216 | super().__init__() 217 | self.dim = dim 218 | self.depth = depth 219 | 220 | self.blocks = nn.ModuleList( 221 | [ 222 | NATLayer( 223 | dim=dim, 224 | num_heads=num_heads, 225 | kernel_size=kernel_size, 226 | dilation=None if dilations is None else dilations[i], 227 | mlp_ratio=mlp_ratio, 228 | qkv_bias=qkv_bias, 229 | qk_scale=qk_scale, 230 | drop=drop, 231 | attn_drop=attn_drop, 232 | drop_path=drop_path[i] 233 | if isinstance(drop_path, list) 234 | else drop_path, 235 | norm_layer=norm_layer, 236 | act_layer=act_layer, 237 | ) 238 | for i in range(depth) 239 | ] 240 | ) 241 | 242 | self.downsample = ( 243 | None if not downsample else ConvDownsampler(dim=dim, norm_layer=norm_layer) 244 | ) 245 | 246 | def forward(self, x): 247 | for blk in self.blocks: 248 | x = blk(x) 249 | if self.downsample is None: 250 | return x, x 251 | return self.downsample(x), x 252 | 253 | 254 | class PointsEncoder(nn.Module): 255 | def __init__(self, feat_channel, encoder_channel): 256 | super().__init__() 257 | 
self.encoder_channel = encoder_channel 258 | self.first_mlp = nn.Sequential( 259 | nn.Linear(feat_channel, 128), 260 | nn.BatchNorm1d(128), 261 | nn.ReLU(inplace=True), 262 | nn.Linear(128, 256), 263 | ) 264 | self.second_mlp = nn.Sequential( 265 | nn.Linear(512, 256), 266 | nn.BatchNorm1d(256), 267 | nn.ReLU(inplace=True), 268 | nn.Linear(256, self.encoder_channel), 269 | ) 270 | 271 | def forward(self, x, mask=None): 272 | """ 273 | x : B M C (C = feat_channel) 274 | mask: B M 275 | ----------------- 276 | feature_global : B C 277 | """ 278 | 279 | bs, n, _ = x.shape 280 | device = x.device 281 | 282 | x_valid = self.first_mlp(x[mask]) # (num_valid, 256) 283 | x_features = torch.zeros(bs, n, 256, device=device) 284 | x_features[mask] = x_valid 285 | 286 | pooled_feature = x_features.max(dim=1)[0] 287 | x_features = torch.cat( 288 | [x_features, pooled_feature.unsqueeze(1).repeat(1, n, 1)], dim=-1 289 | ) 290 | 291 | x_features_valid = self.second_mlp(x_features[mask]) 292 | res = torch.zeros(bs, n, self.encoder_channel, device=device) 293 | res[mask] = x_features_valid 294 | 295 | res = res.max(dim=1)[0] 296 | return res 297 | -------------------------------------------------------------------------------- /src/models/planTF/lightning_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Dict, Tuple, Union 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper 10 | from nuplan.planning.training.modeling.types import ( 11 | FeaturesType, 12 | ScenarioListType, 13 | TargetsType, 14 | ) 15 | from torch.optim import Optimizer 16 | from torch.optim.lr_scheduler import _LRScheduler 17 | from torchmetrics import MetricCollection 18 | 19 | from src.metrics import MR, minADE, minFDE 20 | from src.optim.warmup_cos_lr import WarmupCosLR 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class LightningTrainer(pl.LightningModule): 26 | def __init__( 27 | self, 28 | model: TorchModuleWrapper, 29 | lr, 30 | weight_decay, 31 | epochs, 32 | warmup_epochs, 33 | ) -> None: 34 | super().__init__() 35 | self.save_hyperparameters(ignore=["model"]) 36 | 37 | self.model = model 38 | self.lr = lr 39 | self.weight_decay = weight_decay 40 | self.epochs = epochs 41 | self.warmup_epochs = warmup_epochs 42 | 43 | def on_fit_start(self) -> None: 44 | metrics_collection = MetricCollection( 45 | { 46 | "minADE1": minADE(k=1).to(self.device), 47 | "minADE6": minADE(k=6).to(self.device), 48 | "minFDE1": minFDE(k=1).to(self.device), 49 | "minFDE6": minFDE(k=6).to(self.device), 50 | "MR": MR().to(self.device), 51 | } 52 | ) 53 | self.metrics = { 54 | "train": metrics_collection.clone(prefix="train/"), 55 | "val": metrics_collection.clone(prefix="val/"), 56 | } 57 | 58 | def _step( 59 | self, batch: Tuple[FeaturesType, TargetsType, ScenarioListType], prefix: str 60 | ) -> torch.Tensor: 61 | features, _, _ = batch 62 | res = self.forward(features["feature"].data) 63 | 64 | losses = self._compute_objectives(res, features["feature"].data) 65 | metrics = self._compute_metrics(res, features["feature"].data, prefix) 66 | self._log_step(losses["loss"], losses, metrics, prefix) 67 | 68 | return losses["loss"] 69 | 70 | def _compute_objectives(self, res, data) -> Dict[str, torch.Tensor]: 71 | trajectory, probability, prediction = ( 72 | res["trajectory"], 73 | res["probability"], 74 | res["prediction"], 75
| )  # trajectory: [B, M, T, 4] (x, y, cos, sin); probability: [B, M] mode logits; prediction: agent (x, y) forecasts 76 | targets = data["agent"]["target"]  # x, y, heading along the last dim 77 | valid_mask = data["agent"]["valid_mask"][:, :, -trajectory.shape[-2] :] 78 | 79 | ego_target_pos, ego_target_heading = targets[:, 0, :, :2], targets[:, 0, :, 2] 80 | ego_target = torch.cat( 81 | [ 82 | ego_target_pos, 83 | torch.stack( 84 | [ego_target_heading.cos(), ego_target_heading.sin()], dim=-1 85 | ), 86 | ], 87 | dim=-1, 88 | ) 89 | agent_target, agent_mask = targets[:, 1:], valid_mask[:, 1:] 90 | 91 | ade = torch.norm(trajectory[..., :2] - ego_target[:, None, :, :2], dim=-1) 92 | best_mode = torch.argmin(ade.sum(-1), dim=-1)  # winner-takes-all: regress only the mode closest to the target 93 | best_traj = trajectory[torch.arange(trajectory.shape[0]), best_mode] 94 | ego_reg_loss = F.smooth_l1_loss(best_traj, ego_target) 95 | ego_cls_loss = F.cross_entropy(probability, best_mode.detach()) 96 | 97 | agent_reg_loss = F.smooth_l1_loss( 98 | prediction[agent_mask], agent_target[agent_mask][:, :2] 99 | ) 100 | 101 | loss = ego_reg_loss + ego_cls_loss + agent_reg_loss 102 | 103 | return { 104 | "loss": loss, 105 | "reg_loss": ego_reg_loss, 106 | "cls_loss": ego_cls_loss, 107 | "prediction_loss": agent_reg_loss, 108 | } 109 | 110 | def _compute_metrics(self, output, data, prefix) -> Dict[str, torch.Tensor]: 111 | if prefix not in self.metrics:  # only "train"/"val" metrics are registered in on_fit_start 112 | return None 113 | return self.metrics[prefix](output, data["agent"]["target"][:, 0]) 114 | def _log_step( 115 | self, 116 | loss: torch.Tensor, 117 | objectives: Dict[str, torch.Tensor], 118 | metrics: Dict[str, torch.Tensor], 119 | prefix: str, 120 | loss_name: str = "loss", 121 | ) -> None: 122 | self.log( 123 | f"loss/{prefix}_{loss_name}", 124 | loss, 125 | on_step=True, 126 | on_epoch=True, 127 | sync_dist=True, 128 | ) 129 | 130 | for key, value in objectives.items(): 131 | self.log( 132 | f"objectives/{prefix}_{key}", 133 | value, 134 | on_step=False, 135 | on_epoch=True, 136 | sync_dist=True, 137 | ) 138 | 139 | if metrics is not None: 140 | self.log_dict( 141 | metrics, 142 | prog_bar=(prefix == "val"), 143 | on_step=False, 144 | on_epoch=True, 145 | batch_size=1, 146 | sync_dist=True, 147 | ) 148 | 149 | def training_step( 150 | self, batch: Tuple[FeaturesType, TargetsType, ScenarioListType], batch_idx: int 151 | ) -> torch.Tensor: 152 | """ 153 | Step called for each batch example during training. 154 | 155 | :param batch: example batch 156 | :param batch_idx: batch's index (unused) 157 | :return: model's loss tensor 158 | """ 159 | return self._step(batch, "train") 160 | 161 | def validation_step( 162 | self, batch: Tuple[FeaturesType, TargetsType, ScenarioListType], batch_idx: int 163 | ) -> torch.Tensor: 164 | """ 165 | Step called for each batch example during validation. 166 | 167 | :param batch: example batch 168 | :param batch_idx: batch's index (unused) 169 | :return: model's loss tensor 170 | """ 171 | return self._step(batch, "val") 172 | 173 | def test_step( 174 | self, batch: Tuple[FeaturesType, TargetsType, ScenarioListType], batch_idx: int 175 | ) -> torch.Tensor: 176 | """ 177 | Step called for each batch example during testing. 178 | 179 | :param batch: example batch 180 | :param batch_idx: batch's index (unused) 181 | :return: model's loss tensor 182 | """ 183 | return self._step(batch, "test") 184 | 185 | def forward(self, features: FeaturesType) -> TargetsType: 186 | """ 187 | Propagates a batch of features through the model.
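The returned dictionary carries the "trajectory", "probability" and "prediction" entries consumed by _compute_objectives above.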
188 | 189 | :param features: features batch 190 | :return: model's predictions 191 | """ 192 | return self.model(features) 193 | 194 | def configure_optimizers( 195 | self, 196 | ) -> Union[Optimizer, Dict[str, Union[Optimizer, _LRScheduler]]]: 197 | """ 198 | Configures the optimizers and learning schedules for the training. 199 | 200 | :return: optimizer or dictionary of optimizers and schedules 201 | """ 202 | decay = set() 203 | no_decay = set() 204 | whitelist_weight_modules = ( 205 | nn.Linear, 206 | nn.Conv1d, 207 | nn.Conv2d, 208 | nn.Conv3d, 209 | nn.MultiheadAttention, 210 | nn.LSTM, 211 | nn.GRU, 212 | ) 213 | blacklist_weight_modules = ( 214 | nn.BatchNorm1d, 215 | nn.BatchNorm2d, 216 | nn.BatchNorm3d, 217 | nn.SyncBatchNorm, 218 | nn.LayerNorm, 219 | nn.Embedding, 220 | ) 221 | for module_name, module in self.named_modules(): 222 | for param_name, param in module.named_parameters(): 223 | full_param_name = ( 224 | "%s.%s" % (module_name, param_name) if module_name else param_name 225 | ) 226 | if "bias" in param_name: 227 | no_decay.add(full_param_name) 228 | elif "weight" in param_name: 229 | if isinstance(module, whitelist_weight_modules): 230 | decay.add(full_param_name) 231 | elif isinstance(module, blacklist_weight_modules): 232 | no_decay.add(full_param_name) 233 | elif not ("weight" in param_name or "bias" in param_name): 234 | no_decay.add(full_param_name) 235 | param_dict = { 236 | param_name: param for param_name, param in self.named_parameters() 237 | } 238 | inter_params = decay & no_decay 239 | union_params = decay | no_decay 240 | assert len(inter_params) == 0 241 | assert len(param_dict.keys() - union_params) == 0 242 | 243 | optim_groups = [ 244 | { 245 | "params": [ 246 | param_dict[param_name] for param_name in sorted(list(decay)) 247 | ], 248 | "weight_decay": self.weight_decay, 249 | }, 250 | { 251 | "params": [ 252 | param_dict[param_name] for param_name in sorted(list(no_decay)) 253 | ], 254 | "weight_decay": 0.0, 255 | }, 256 | ] 257 | 258 | # Get optimizer 259 | optimizer = torch.optim.AdamW( 260 | optim_groups, lr=self.lr, weight_decay=self.weight_decay 261 | ) 262 | 263 | # Get lr_scheduler 264 | scheduler = WarmupCosLR( 265 | optimizer=optimizer, 266 | lr=self.lr, 267 | min_lr=1e-6, 268 | epochs=self.epochs, 269 | warmup_epochs=self.warmup_epochs, 270 | ) 271 | 272 | return [optimizer], [scheduler] 273 | -------------------------------------------------------------------------------- /src/custom_training/custom_training_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from shutil import rmtree 6 | from typing import cast 7 | 8 | import pytorch_lightning as pl 9 | from hydra.utils import instantiate 10 | from nuplan.planning.script.builders.data_augmentation_builder import ( 11 | build_agent_augmentor, 12 | ) 13 | from nuplan.planning.script.builders.model_builder import build_torch_module_wrapper 14 | from nuplan.planning.script.builders.objectives_builder import build_objectives 15 | from nuplan.planning.script.builders.scenario_builder import build_scenarios 16 | from nuplan.planning.script.builders.splitter_builder import build_splitter 17 | from nuplan.planning.script.builders.training_metrics_builder import ( 18 | build_training_metrics, 19 | ) 20 | from nuplan.planning.training.modeling.lightning_module_wrapper import ( 21 | LightningModuleWrapper, 22 | ) 23 | from 
nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper 24 | from nuplan.planning.training.preprocessing.feature_preprocessor import ( 25 | FeaturePreprocessor, 26 | ) 27 | from nuplan.planning.utils.multithreading.worker_pool import WorkerPool 28 | from omegaconf import DictConfig, OmegaConf 29 | from pytorch_lightning.callbacks import ( 30 | LearningRateMonitor, 31 | ModelCheckpoint, 32 | RichModelSummary, 33 | RichProgressBar, 34 | ) 35 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 36 | from pytorch_lightning.loggers.wandb import WandbLogger 37 | 38 | from .custom_datamodule import CustomDataModule 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | def update_config_for_training(cfg: DictConfig) -> None: 44 | """ 45 | Updates the config based on some conditions. 46 | :param cfg: omegaconf dictionary that is used to run the experiment. 47 | """ 48 | # Make the configuration editable. 49 | OmegaConf.set_struct(cfg, False) 50 | 51 | if cfg.cache.cache_path is None: 52 | logger.warning("Parameter cache_path is not set, caching is disabled") 53 | else: 54 | if not str(cfg.cache.cache_path).startswith("s3://"): 55 | if cfg.cache.cleanup_cache and Path(cfg.cache.cache_path).exists(): 56 | rmtree(cfg.cache.cache_path) 57 | 58 | Path(cfg.cache.cache_path).mkdir(parents=True, exist_ok=True) 59 | 60 | if cfg.lightning.trainer.overfitting.enable: 61 | cfg.data_loader.params.num_workers = 0 62 | 63 | OmegaConf.resolve(cfg) 64 | 65 | # Finalize the configuration and make it non-editable. 66 | OmegaConf.set_struct(cfg, True) 67 | 68 | # Log the final configuration after all overrides, interpolations and updates. 69 | if cfg.log_config: 70 | logger.info( 71 | f"Creating experiment name [{cfg.experiment}] in group [{cfg.group}] with config..." 72 | ) 73 | logger.info("\n" + OmegaConf.to_yaml(cfg)) 74 | 75 | 76 | @dataclass(frozen=True) 77 | class TrainingEngine: 78 | """Lightning training engine dataclass wrapping the lightning trainer, model and datamodule.""" 79 | 80 | trainer: pl.Trainer # Trainer for models 81 | model: pl.LightningModule # Module describing NN model, loss, metrics, visualization 82 | datamodule: pl.LightningDataModule # Loading data 83 | 84 | def __repr__(self) -> str: 85 | """ 86 | :return: String representation of class without expanding the fields. 87 | """ 88 | return f"<{type(self).__module__}.{type(self).__qualname__} object at {hex(id(self))}>" 89 | 90 | 91 | def build_lightning_datamodule( 92 | cfg: DictConfig, worker: WorkerPool, model: TorchModuleWrapper 93 | ) -> pl.LightningDataModule: 94 | """ 95 | Build the lightning datamodule from the config. 96 | :param cfg: Omegaconf dictionary. 97 | :param model: NN model used for training. 98 | :param worker: Worker to submit tasks which can be executed in parallel. 99 | :return: Instantiated datamodule object. 
100 | """ 101 | # Build features and targets 102 | feature_builders = model.get_list_of_required_feature() 103 | target_builders = model.get_list_of_computed_target() 104 | 105 | # Build splitter 106 | splitter = build_splitter(cfg.splitter) 107 | 108 | # Create feature preprocessor 109 | feature_preprocessor = FeaturePreprocessor( 110 | cache_path=cfg.cache.cache_path, 111 | force_feature_computation=cfg.cache.force_feature_computation, 112 | feature_builders=feature_builders, 113 | target_builders=target_builders, 114 | ) 115 | 116 | # Create data augmentation 117 | augmentors = ( 118 | build_agent_augmentor(cfg.data_augmentation) 119 | if "data_augmentation" in cfg 120 | else None 121 | ) 122 | 123 | # Build dataset scenarios 124 | scenarios = build_scenarios(cfg, worker, model) 125 | 126 | # Create datamodule 127 | datamodule: pl.LightningDataModule = CustomDataModule( 128 | feature_preprocessor=feature_preprocessor, 129 | splitter=splitter, 130 | all_scenarios=scenarios, 131 | dataloader_params=cfg.data_loader.params, 132 | augmentors=augmentors, 133 | worker=worker, 134 | scenario_type_sampling_weights=cfg.scenario_type_weights.scenario_type_sampling_weights, 135 | **cfg.data_loader.datamodule, 136 | ) 137 | 138 | return datamodule 139 | 140 | 141 | def build_lightning_module( 142 | cfg: DictConfig, torch_module_wrapper: TorchModuleWrapper 143 | ) -> pl.LightningModule: 144 | """ 145 | Builds the lightning module from the config. 146 | :param cfg: omegaconf dictionary 147 | :param torch_module_wrapper: NN model used for training 148 | :return: built object. 149 | """ 150 | # Create the complete Module 151 | if "custom_trainer" in cfg: 152 | model = instantiate( 153 | cfg.custom_trainer, 154 | model=torch_module_wrapper, 155 | lr=cfg.lr, 156 | weight_decay=cfg.weight_decay, 157 | epochs=cfg.epochs, 158 | warmup_epochs=cfg.warmup_epochs, 159 | ) 160 | else: 161 | objectives = build_objectives(cfg) 162 | metrics = build_training_metrics(cfg) 163 | model = LightningModuleWrapper( 164 | model=torch_module_wrapper, 165 | objectives=objectives, 166 | metrics=metrics, 167 | batch_size=cfg.data_loader.params.batch_size, 168 | optimizer=cfg.optimizer, 169 | lr_scheduler=cfg.lr_scheduler if "lr_scheduler" in cfg else None, 170 | warm_up_lr_scheduler=cfg.warm_up_lr_scheduler 171 | if "warm_up_lr_scheduler" in cfg 172 | else None, 173 | objective_aggregate_mode=cfg.objective_aggregate_mode, 174 | ) 175 | 176 | return cast(pl.LightningModule, model) 177 | 178 | 179 | def build_custom_trainer(cfg: DictConfig) -> pl.Trainer: 180 | """ 181 | Builds the lightning trainer from the config. 182 | :param cfg: omegaconf dictionary 183 | :return: built object. 
184 | """ 185 | params = cfg.lightning.trainer.params 186 | 187 | # callbacks = build_callbacks(cfg) 188 | callbacks = [ 189 | ModelCheckpoint( 190 | dirpath=os.path.join(os.getcwd(), "checkpoints"), 191 | filename="{epoch}-{val_minFDE:.3f}", 192 | monitor=cfg.lightning.trainer.checkpoint.monitor, 193 | mode=cfg.lightning.trainer.checkpoint.mode, 194 | save_top_k=cfg.lightning.trainer.checkpoint.save_top_k, 195 | save_last=True, 196 | ), 197 | RichModelSummary(max_depth=1), 198 | RichProgressBar(), 199 | LearningRateMonitor(logging_interval="epoch"), 200 | ] 201 | 202 | if cfg.wandb.mode == "disable": 203 | training_logger = TensorBoardLogger( 204 | save_dir=cfg.group, 205 | name=cfg.experiment, 206 | log_graph=False, 207 | version="", 208 | prefix="", 209 | ) 210 | else: 211 | if cfg.wandb.artifact is not None: 212 | os.system(f"wandb artifact get {cfg.wandb.artifact}") 213 | _, _, artifact = cfg.wandb.artifact.split("/") 214 | checkpoint = os.path.join(os.getcwd(), f"artifacts/{artifact}/model.ckpt") 215 | run_id = artifact.split(":")[0][-8:] 216 | cfg.checkpoint = checkpoint 217 | cfg.wandb.run_id = run_id 218 | 219 | training_logger = WandbLogger( 220 | save_dir=cfg.group, 221 | project=cfg.wandb.project, 222 | name=cfg.wandb.name, 223 | mode=cfg.wandb.mode, 224 | log_model=cfg.wandb.log_model, 225 | resume=cfg.checkpoint is not None, 226 | id=cfg.wandb.run_id, 227 | ) 228 | 229 | trainer = pl.Trainer( 230 | callbacks=callbacks, 231 | logger=training_logger, 232 | **params, 233 | ) 234 | 235 | return trainer 236 | 237 | 238 | def build_training_engine(cfg: DictConfig, worker: WorkerPool) -> TrainingEngine: 239 | """ 240 | Build the three core lightning modules: LightningDataModule, LightningModule and Trainer 241 | :param cfg: omegaconf dictionary 242 | :param worker: Worker to submit tasks which can be executed in parallel 243 | :return: TrainingEngine 244 | """ 245 | logger.info("Building training engine...") 246 | 247 | trainer = build_custom_trainer(cfg) 248 | 249 | # Create model 250 | torch_module_wrapper = build_torch_module_wrapper(cfg.model) 251 | 252 | # Build the datamodule 253 | datamodule = build_lightning_datamodule(cfg, worker, torch_module_wrapper) 254 | 255 | # Build lightning module 256 | model = build_lightning_module(cfg, torch_module_wrapper) 257 | 258 | engine = TrainingEngine(trainer=trainer, datamodule=datamodule, model=model) 259 | 260 | return engine 261 | -------------------------------------------------------------------------------- /src/feature_builders/common/route_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import numpy as np 4 | from nuplan.common.actor_state.ego_state import EgoState 5 | from nuplan.common.actor_state.state_representation import StateSE2 6 | from nuplan.common.maps.abstract_map import AbstractMap 7 | from nuplan.common.maps.abstract_map_objects import RoadBlockGraphEdgeMapObject 8 | from nuplan.common.maps.maps_datatypes import SemanticMapLayer 9 | from nuplan.planning.simulation.occupancy_map.strtree_occupancy_map import ( 10 | STRTreeOccupancyMapFactory, 11 | ) 12 | 13 | from .bfs_roadblock import BreadthFirstSearchRoadBlock 14 | from .utils import normalize_angle 15 | 16 | 17 | def get_current_roadblock_candidates( 18 | ego_state: EgoState, 19 | map_api: AbstractMap, 20 | route_roadblocks_dict: Dict[str, RoadBlockGraphEdgeMapObject], 21 | heading_error_thresh: float = np.pi / 4, 22 | displacement_error_thresh: float = 3, 23 | ) -> 
Tuple[RoadBlockGraphEdgeMapObject, List[RoadBlockGraphEdgeMapObject]]: 24 | """ 25 | Determines the set of roadblock candidates in which the ego is currently located 26 | :param ego_state: class containing ego state 27 | :param map_api: map object 28 | :param route_roadblocks_dict: dictionary of on-route roadblocks 29 | :param heading_error_thresh: maximum heading error, defaults to np.pi/4 30 | :param displacement_error_thresh: maximum displacement, defaults to 3 31 | :return: tuple of most promising roadblock and other candidates 32 | """ 33 | ego_pose: StateSE2 = ego_state.rear_axle 34 | roadblock_candidates = [] 35 | 36 | layers = [SemanticMapLayer.ROADBLOCK, SemanticMapLayer.ROADBLOCK_CONNECTOR] 37 | roadblock_dict = map_api.get_proximal_map_objects( 38 | point=ego_pose.point, radius=1.0, layers=layers 39 | ) 40 | roadblock_candidates = ( 41 | roadblock_dict[SemanticMapLayer.ROADBLOCK] 42 | + roadblock_dict[SemanticMapLayer.ROADBLOCK_CONNECTOR] 43 | ) 44 | 45 | if not roadblock_candidates: 46 | for layer in layers: 47 | roadblock_id_, distance = map_api.get_distance_to_nearest_map_object( 48 | point=ego_pose.point, layer=layer 49 | ) 50 | roadblock = map_api.get_map_object(roadblock_id_, layer) 51 | 52 | if roadblock: 53 | roadblock_candidates.append(roadblock) 54 | 55 | on_route_candidates, on_route_candidate_displacement_errors = [], [] 56 | candidates, candidate_displacement_errors = [], [] 57 | 58 | roadblock_displacement_errors = [] 59 | roadblock_heading_errors = [] 60 | 61 | for idx, roadblock in enumerate(roadblock_candidates): 62 | lane_displacement_error, lane_heading_error = np.inf, np.inf 63 | 64 | for lane in roadblock.interior_edges: 65 | lane_discrete_path: List[StateSE2] = lane.baseline_path.discrete_path 66 | lane_discrete_points = np.array( 67 | [state.point.array for state in lane_discrete_path], dtype=np.float64 68 | ) 69 | lane_state_distances = ( 70 | (lane_discrete_points - ego_pose.point.array[None, ...]) ** 2.0 71 | ).sum(axis=-1) ** 0.5 72 | argmin = np.argmin(lane_state_distances) 73 | 74 | heading_error = np.abs( 75 | normalize_angle(lane_discrete_path[argmin].heading - ego_pose.heading) 76 | ) 77 | displacement_error = lane_state_distances[argmin] 78 | 79 | if displacement_error < lane_displacement_error: 80 | lane_heading_error, lane_displacement_error = ( 81 | heading_error, 82 | displacement_error, 83 | ) 84 | 85 | if ( 86 | heading_error < heading_error_thresh 87 | and displacement_error < displacement_error_thresh 88 | ): 89 | if roadblock.id in route_roadblocks_dict.keys(): 90 | on_route_candidates.append(roadblock) 91 | on_route_candidate_displacement_errors.append(displacement_error) 92 | else: 93 | candidates.append(roadblock) 94 | candidate_displacement_errors.append(displacement_error) 95 | 96 | roadblock_displacement_errors.append(lane_displacement_error) 97 | roadblock_heading_errors.append(lane_heading_error) 98 | 99 | if on_route_candidates: # prefer on-route roadblocks 100 | return ( 101 | on_route_candidates[np.argmin(on_route_candidate_displacement_errors)], 102 | on_route_candidates, 103 | ) 104 | elif candidates: # fallback to most promising candidate 105 | return candidates[np.argmin(candidate_displacement_errors)], candidates 106 | 107 | # otherwise, just find any close roadblock 108 | return ( 109 | roadblock_candidates[np.argmin(roadblock_displacement_errors)], 110 | roadblock_candidates, 111 | ) 112 | 113 | 114 | def route_roadblock_correction( 115 | ego_state: EgoState, 116 | map_api: AbstractMap, 117 | route_roadblock_ids: List[str], 118 |
search_depth_backward: int = 15, 119 | search_depth_forward: int = 30, 120 | ) -> List[str]: 121 | """ 122 | Applies several methods to correct route roadblocks. 123 | :param ego_state: class containing ego state 124 | :param map_api: map object 125 | :param route_roadblock_ids: list of roadblock ids of the original route 126 | :param search_depth_backward: depth of backward BFS search, defaults to 15 127 | :param search_depth_forward: depth of forward BFS search, defaults to 30 128 | :return: list of roadblock ids of the corrected route 129 | """ 130 | 131 | route_roadblock_dict = {} 132 | for id_ in route_roadblock_ids: 133 | block = map_api.get_map_object(id_, SemanticMapLayer.ROADBLOCK) 134 | block = block or map_api.get_map_object( 135 | id_, SemanticMapLayer.ROADBLOCK_CONNECTOR 136 | ) 137 | route_roadblock_dict[id_] = block 138 | 139 | starting_block, starting_block_candidates = get_current_roadblock_candidates( 140 | ego_state, map_api, route_roadblock_dict 141 | ) 142 | starting_block_ids = [roadblock.id for roadblock in starting_block_candidates] 143 | 144 | route_roadblocks = list(route_roadblock_dict.values()) 145 | route_roadblock_ids = list(route_roadblock_dict.keys()) 146 | 147 | # Fix 1: when agent starts off-route 148 | if starting_block.id not in route_roadblock_ids: 149 | # Backward search if current roadblock not in route 150 | graph_search = BreadthFirstSearchRoadBlock( 151 | route_roadblock_ids[0], map_api, forward_search=False 152 | ) 153 | (path, path_id), path_found = graph_search.search( 154 | starting_block_ids, max_depth=search_depth_backward 155 | ) 156 | 157 | if path_found: 158 | route_roadblocks[:0] = path[:-1] 159 | route_roadblock_ids[:0] = path_id[:-1] 160 | 161 | else: 162 | # Forward search to any route roadblock 163 | graph_search = BreadthFirstSearchRoadBlock( 164 | starting_block.id, map_api, forward_search=True 165 | ) 166 | (path, path_id), path_found = graph_search.search( 167 | route_roadblock_ids[:3], max_depth=search_depth_forward 168 | ) 169 | 170 | if path_found: 171 | end_roadblock_idx = np.argmax( 172 | np.array(route_roadblock_ids) == path_id[-1] 173 | ) 174 | 175 | route_roadblocks = route_roadblocks[end_roadblock_idx + 1 :] 176 | route_roadblock_ids = route_roadblock_ids[end_roadblock_idx + 1 :] 177 | 178 | route_roadblocks[:0] = path 179 | route_roadblock_ids[:0] = path_id 180 | 181 | # Fix 2: check if roadblocks are linked, search for links if not 182 | roadblocks_to_append = {} 183 | for i in range(len(route_roadblocks) - 1): 184 | next_incoming_block_ids = [ 185 | _roadblock.id for _roadblock in route_roadblocks[i + 1].incoming_edges 186 | ] 187 | is_incoming = route_roadblock_ids[i] in next_incoming_block_ids 188 | 189 | if is_incoming: 190 | continue 191 | 192 | graph_search = BreadthFirstSearchRoadBlock( 193 | route_roadblock_ids[i], map_api, forward_search=True 194 | ) 195 | (path, path_id), path_found = graph_search.search( 196 | route_roadblock_ids[i + 1], max_depth=search_depth_forward 197 | ) 198 | 199 | if path_found and path and len(path) >= 3: 200 | path, path_id = path[1:-1], path_id[1:-1] 201 | roadblocks_to_append[i] = (path, path_id) 202 | 203 | # append missing intermediate roadblocks 204 | offset = 1 205 | for i, (path, path_id) in roadblocks_to_append.items(): 206 | route_roadblocks[i + offset : i + offset] = path 207 | route_roadblock_ids[i + offset : i + offset] = path_id 208 | offset += len(path) 209 | 210 | # Fix 3: cut route-loops 211 | route_roadblocks, route_roadblock_ids = remove_route_loops( 212 | route_roadblocks,
route_roadblock_ids 213 | ) 214 | 215 | return route_roadblock_ids 216 | 217 | 218 | def remove_route_loops( 219 | route_roadblocks: List[RoadBlockGraphEdgeMapObject], 220 | route_roadblock_ids: List[str], 221 | ) -> Tuple[List[RoadBlockGraphEdgeMapObject], List[str]]: 222 | """ 223 | Removes the ending of the route if a roadblock intersects an earlier part of the route (forming a loop). 224 | :param route_roadblocks: input route roadblocks 225 | :param route_roadblock_ids: input route roadblock ids 226 | :return: tuple of roadblocks and ids of the route without loops 227 | """ 228 | 229 | roadblock_occupancy_map = None 230 | loop_idx = None 231 | 232 | for idx, roadblock in enumerate(route_roadblocks): 233 | # loops only occur at intersections, thus search only roadblock connectors. 234 | if str(roadblock.__class__.__name__) == "NuPlanRoadBlockConnector": 235 | if not roadblock_occupancy_map: 236 | roadblock_occupancy_map = STRTreeOccupancyMapFactory.get_from_geometry( 237 | [roadblock.polygon], [roadblock.id] 238 | ) 239 | continue 240 | 241 | strtree, index_by_id = roadblock_occupancy_map._build_strtree() 242 | indices = strtree.query(roadblock.polygon) 243 | if len(indices) > 0: 244 | for geom in strtree.geometries.take(indices): 245 | area = geom.intersection(roadblock.polygon).area 246 | if area > 1: 247 | loop_idx = idx 248 | break 249 | if loop_idx: 250 | break 251 | 252 | roadblock_occupancy_map.insert(roadblock.id, roadblock.polygon) 253 | 254 | if loop_idx: 255 | route_roadblocks = route_roadblocks[:loop_idx] 256 | route_roadblock_ids = route_roadblock_ids[:loop_idx] 257 | 258 | return route_roadblocks, route_roadblock_ids 259 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the official repository of 2 | 3 | **Rethinking Imitation-based Planner for Autonomous Driving**, 4 | *Jie Cheng, Yingbing Chen, Xiaodong Mei, Bowen Yang, Bo Li and Ming Liu*, arXiv 2023 5 | 6 |

7 | 8 | 9 | 10 | 11 | [arXiv PDF](https://arxiv.org/abs/2309.10443) 12 | 13 |

14 | 15 | 16 | 17 | 18 | ## Highlight 19 | - A good starting point for research on learning-based planners on the [nuPlan](https://www.nuscenes.org/nuplan) dataset. This repo provides detailed instructions on data preprocessing, training and benchmarking. 20 | - A simple, purely learning-based baseline model, **planTF**, which achieves decent performance **without** any rule-based strategies or post-optimization. 21 | 22 | ## Get Started 23 | 24 | - [Get Started](#get-started) 25 | - [Setup Environment](#setup-environment) 26 | - [Feature cache](#feature-cache) 27 | - [Training](#training) 28 | - [Trained models](#trained-models) 29 | - [Evaluation](#evaluation) 30 | - [Results](#results) 31 | - [Acknowledgements](#acknowledgements) 32 | - [Citation](#citation) 33 | 34 | 35 | ## Setup Environment 36 | 37 | - set up the nuPlan dataset following the [official documentation](https://nuplan-devkit.readthedocs.io/en/latest/dataset_setup.html) 38 | - set up the conda environment 39 | ``` 40 | conda create -n plantf python=3.9 41 | conda activate plantf 42 | 43 | # install nuplan-devkit 44 | git clone https://github.com/motional/nuplan-devkit.git && cd nuplan-devkit 45 | pip install -e . 46 | pip install -r ./requirements.txt 47 | 48 | # setup planTF 49 | cd .. 50 | git clone https://github.com/jchengai/planTF.git && cd planTF 51 | sh ./script/setup_env.sh 52 | ``` 53 | 54 | ## Feature cache 55 | 56 | Preprocess the dataset to accelerate training. The following command generates 1M frames of training data from the whole nuPlan training set. You may need to: 57 | - change `cache.cache_path` to suit your setup 58 | - decrease/increase `worker.threads_per_node` depending on your RAM and CPU. 59 | 60 | ```sh 61 | export PYTHONPATH=$PYTHONPATH:$(pwd) 62 | 63 | python run_training.py \ 64 | py_func=cache +training=train_planTF \ 65 | scenario_builder=nuplan \ 66 | cache.cache_path=/nuplan/exp/cache_plantf_1M \ 67 | cache.cleanup_cache=true \ 68 | scenario_filter=training_scenarios_1M \ 69 | worker.threads_per_node=40 70 | ``` 71 | 72 | This process may take some time, so be patient (20+ hours in our setting). 73 | 74 | ## Training 75 | 76 | We slightly modified the training script provided by [nuplan-devkit](https://github.com/motional/nuplan-devkit) for more flexible training. 77 | By default, the training script will use all visible GPUs. PlanTF is quite lightweight, taking about 4-6 GB of GPU memory at a batch size of 32 (per GPU). 78 | 79 | ```sh 80 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_training.py \ 81 | py_func=train +training=train_planTF \ 82 | worker=single_machine_thread_pool worker.max_workers=32 \ 83 | scenario_builder=nuplan cache.cache_path=/nuplan/exp/cache_plantf_1M cache.use_cache_without_dataset=true \ 84 | data_loader.params.batch_size=32 data_loader.params.num_workers=32 \ 85 | lr=1e-3 epochs=25 warmup_epochs=3 weight_decay=0.0001 \ 86 | lightning.trainer.params.val_check_interval=0.5 \ 87 | wandb.mode=online wandb.project=nuplan wandb.name=plantf 88 | ``` 89 | 90 | You can remove the wandb-related configurations if you prefer TensorBoard, as sketched below.
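For example, a minimal sketch of a TensorBoard-only run (assuming the same cache and hyper-parameters as above; `build_custom_trainer` in `src/custom_training/custom_training_builder.py` falls back to a `TensorBoardLogger` when `cfg.wandb.mode` equals `disable`):

```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 python run_training.py \
    py_func=train +training=train_planTF \
    worker=single_machine_thread_pool worker.max_workers=32 \
    scenario_builder=nuplan cache.cache_path=/nuplan/exp/cache_plantf_1M cache.use_cache_without_dataset=true \
    data_loader.params.batch_size=32 data_loader.params.num_workers=32 \
    lr=1e-3 epochs=25 warmup_epochs=3 weight_decay=0.0001 \
    wandb.mode=disable
```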
91 | 92 | ## Trained models 93 | 94 | Place the trained models at `planTF/checkpoints/` 95 | 96 | | Model | Document | Download | 97 | | ---------------------- | ------------------------------------------------------ | -------- | 98 | | PlanTF (state6+SDE) | - | [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchengai_connect_ust_hk/EW7HbklkAhVNpcDUEga2aLABxioVA1S98vyqk2VbziYfTw?e=fe3CxI) | 99 | | RasterModel | [Doc](./docs/other_baselines.md#rastermodel) | [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchengai_connect_ust_hk/EcfVyHFUoV1KhAv7D_JPqtwBlwR-2zT2suGHD1rLXsBtKA?e=PIwD7U) | 100 | | UrbanDriver (openloop) | [Doc](./docs/other_baselines.md#urbandriver-open-loop) | [OneDrive](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jchengai_connect_ust_hk/EbM_BSpFS9NBqIWuhlVHMrYBMrSOtusHjH6hwfamZCuI_Q?e=Q2bN75) | 101 | 102 | 103 | 104 | ## Evaluation 105 | 106 | 107 | - run a single scenario simulation (for sanity check): `sh ./script/plantf_single_scenarios.sh` 108 | - run **Test14-random**: `sh ./script/plantf_benchmarks.sh test14-random` 109 | - run **Test14-hard**: `sh ./script/plantf_benchmarks.sh test14-hard` 110 | - run **Val14** (this may take a long time): `sh ./script/plantf_benchmarks.sh val14` 111 | 112 | ## Results 113 | 114 | ### Test14-random and Test14-hard benchmarks 115 | 116 | In each row, the first OLS/NR-CLS/R-CLS group is measured on Test14-random and the second on Test14-hard. OLS is the open-loop score; NR-CLS and R-CLS are the closed-loop scores with non-reactive and reactive agents, matching the three challenges run in `script/plantf_benchmarks.sh`.

| Type | Method | OLS↑ | NR-CLS↑ | R-CLS↑ | OLS↑ | NR-CLS↑ | R-CLS↑ | Time |
| -------------- | ------------- | ----- | ------- | ------ | ----- | ------- | ------ | ---- |
| Expert | LogReplay | 100.0 | 94.03 | 75.86 | 100.0 | 85.96 | 68.80 | - |
| Rule-based | IDM | 34.15 | 70.39 | 72.42 | 20.07 | 56.16 | 62.26 | 32 |
| Rule-based | PDM-Closed | 46.32 | 90.05 | 91.64 | 26.43 | 65.07 | 75.18 | 140 |
| Hybrid | GameFormer | 79.35 | 80.80 | 79.31 | 75.27 | 66.59 | 68.83 | 443 |
| Hybrid | PDM-Hybrid | 82.21 | 90.20 | 91.56 | 73.81 | 65.95 | 75.79 | 152 |
| Learning-based | PlanCNN | 62.93 | 69.66 | 67.54 | 52.4 | 49.47 | 52.16 | 82 |
| Learning-based | UrbanDriver\* | 82.44 | 63.27 | 61.02 | 76.9 | 51.54 | 49.07 | 124 |
| Learning-based | GC-PGP | 77.33 | 55.99 | 51.39 | 73.78 | 43.22 | 39.63 | 160 |
| Learning-based | PDM-Open | 84.14 | 52.80 | 57.23 | 79.06 | 33.51 | 35.83 | 101 |
| Learning-based | PlanTF (Ours) | 87.07 | 86.48 | 80.59 | 83.32 | 72.68 | 61.71 | 55 |

\* open-loop re-implementation
247 | 248 | ### Val14 benchmark 249 | 250 | | Method | OLS | NR-CLS | R-CLS | 251 | | ------------- | ----- | ------ | ----- | 252 | | Log-replay | 100 | 94 | 80 | 253 | | IDM | 38 | 77 | 76 | 254 | | GC-PGP | 82 | 57 | 54 | 255 | | PlanCNN | 64 | 73 | 72 | 256 | | PDM-Hybrid | 84 | 93 | 92 | 257 | | PlanTF (Ours) | 89.18 | 84.83 | 76.78 | 258 | 259 | ## Acknowledgements 260 | 261 | Many thanks to the open-source community; also check out these works: 262 | - [tuplan_garage](https://github.com/autonomousvision/tuplan_garage) 263 | - [GameFormer-Planner](https://github.com/MCZhi/GameFormer-Planner) 264 | 265 | ## Citation 266 | 267 | If you find this repo useful, please consider giving us a star 🌟 and citing our related paper. 268 | 269 | ```bibtex 270 | @misc{cheng2023plantf, 271 | title={Rethinking Imitation-based Planner for Autonomous Driving}, 272 | author={Jie Cheng and Yingbing Chen and Xiaodong Mei and Bowen Yang and Bo Li and Ming Liu}, 273 | year={2023}, 274 | eprint={2309.10443}, 275 | archivePrefix={arXiv}, 276 | primaryClass={cs.RO} 277 | } 278 | ``` 279 | -------------------------------------------------------------------------------- /src/custom_training/custom_datamodule.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.utils.data 8 | from omegaconf import DictConfig 9 | from torch.utils.data.sampler import WeightedRandomSampler 10 | 11 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario 12 | from nuplan.planning.training.data_augmentation.abstract_data_augmentation import ( 13 | AbstractAugmentor, 14 | ) 15 | from nuplan.planning.training.data_loader.distributed_sampler_wrapper import ( 16 | DistributedSamplerWrapper, 17 | ) 18 | from nuplan.planning.training.data_loader.scenario_dataset import ScenarioDataset 19 | from nuplan.planning.training.data_loader.splitter import AbstractSplitter 20 | from nuplan.planning.training.modeling.types import ( 21 | FeaturesType, 22 | move_features_type_to_device, 23 | ) 24 | from nuplan.planning.training.preprocessing.feature_collate import FeatureCollate 25 | from nuplan.planning.training.preprocessing.feature_preprocessor import ( 26 | FeaturePreprocessor, 27 | ) 28 | from nuplan.planning.utils.multithreading.worker_pool import WorkerPool 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | DataModuleNotSetupError = RuntimeError('Data module has not been setup, call "setup()"') 33 | 34 | 35 | def create_dataset( 36 | samples: List[AbstractScenario], 37 | feature_preprocessor: FeaturePreprocessor, 38 | dataset_fraction: float, 39 | dataset_name: str, 40 | augmentors: Optional[List[AbstractAugmentor]] = None, 41 | ) -> torch.utils.data.Dataset: 42 | """ 43 | Create a dataset from a list of samples. 44 | :param samples: List of dataset candidate samples. 45 | :param feature_preprocessor: Feature preprocessor object. 46 | :param dataset_fraction: Fraction of the dataset to load. 47 | :param dataset_name: Set name (train/val/test). 48 | :param augmentors: List of augmentor objects for providing data augmentation to data samples. 49 | :return: The instantiated torch dataset. 50 |
51 | """ 52 | # Sample the desired fraction from the total samples 53 | num_keep = int(len(samples) * dataset_fraction) 54 | selected_scenarios = random.sample(samples, num_keep) 55 | 56 | logger.info(f"Number of samples in {dataset_name} set: {len(selected_scenarios)}") 57 | return ScenarioDataset( 58 | scenarios=selected_scenarios, 59 | feature_preprocessor=feature_preprocessor, 60 | augmentors=augmentors, 61 | ) 62 | 63 | 64 | def distributed_weighted_sampler_init( 65 | scenario_dataset: ScenarioDataset, 66 | scenario_sampling_weights: Dict[str, float], 67 | replacement: bool = True, 68 | ) -> WeightedRandomSampler: 69 | """ 70 | Initiliazes WeightedSampler object with sampling weights for each scenario_type and returns it. 71 | :param scenario_dataset: ScenarioDataset object 72 | :param replacement: Samples with replacement if True. By default set to True. 73 | return: Initialized Weighted sampler 74 | """ 75 | scenarios = scenario_dataset._scenarios 76 | if ( 77 | not replacement 78 | ): # If we don't sample with replacement, then all sample weights must be nonzero 79 | assert all( 80 | w > 0 for w in scenario_sampling_weights.values() 81 | ), "All scenario sampling weights must be positive when sampling without replacement." 82 | 83 | default_scenario_sampling_weight = 1.0 84 | 85 | scenario_sampling_weights_per_idx = [ 86 | scenario_sampling_weights[scenario.scenario_type] 87 | if scenario.scenario_type in scenario_sampling_weights 88 | else default_scenario_sampling_weight 89 | for scenario in scenarios 90 | ] 91 | 92 | # Create weighted sampler 93 | weighted_sampler = WeightedRandomSampler( 94 | weights=scenario_sampling_weights_per_idx, 95 | num_samples=len(scenarios), 96 | replacement=replacement, 97 | ) 98 | 99 | distributed_weighted_sampler = DistributedSamplerWrapper(weighted_sampler) 100 | return distributed_weighted_sampler 101 | 102 | 103 | class CustomDataModule(pl.LightningDataModule): 104 | """ 105 | Datamodule wrapping all preparation and dataset creation functionality. 106 | """ 107 | 108 | def __init__( 109 | self, 110 | feature_preprocessor: FeaturePreprocessor, 111 | splitter: AbstractSplitter, 112 | all_scenarios: List[AbstractScenario], 113 | train_fraction: float, 114 | val_fraction: float, 115 | test_fraction: float, 116 | dataloader_params: Dict[str, Any], 117 | scenario_type_sampling_weights: DictConfig, 118 | worker: WorkerPool, 119 | augmentors: Optional[List[AbstractAugmentor]] = None, 120 | ) -> None: 121 | """ 122 | Initialize the class. 123 | :param feature_preprocessor: Feature preprocessor object. 124 | :param splitter: Splitter object used to retrieve lists of samples to construct train/val/test sets. 125 | :param train_fraction: Fraction of training examples to load. 126 | :param val_fraction: Fraction of validation examples to load. 127 | :param test_fraction: Fraction of test examples to load. 128 | :param dataloader_params: Parameter dictionary passed to the dataloaders. 129 | :param augmentors: Augmentor object for providing data augmentation to data samples. 130 | """ 131 | super().__init__() 132 | 133 | assert train_fraction > 0.0, "Train fraction has to be larger than 0!" 134 | assert val_fraction > 0.0, "Validation fraction has to be larger than 0!" 135 | assert test_fraction >= 0.0, "Test fraction has to be larger/equal than 0!" 
136 | 137 | # Datasets 138 | self._train_set: Optional[torch.utils.data.Dataset] = None 139 | self._val_set: Optional[torch.utils.data.Dataset] = None 140 | self._test_set: Optional[torch.utils.data.Dataset] = None 141 | 142 | # Feature computation 143 | self._feature_preprocessor = feature_preprocessor 144 | 145 | # Data splitter train/test/val 146 | self._splitter = splitter 147 | 148 | # Fractions 149 | self._train_fraction = train_fraction 150 | self._val_fraction = val_fraction 151 | self._test_fraction = test_fraction 152 | 153 | # Data loader for train/val/test 154 | self._dataloader_params = dataloader_params 155 | 156 | # Extract all samples 157 | self._all_samples = all_scenarios 158 | assert len(self._all_samples) > 0, "No samples were passed to the datamodule" 159 | 160 | # Scenario sampling weights 161 | self._scenario_type_sampling_weights = scenario_type_sampling_weights 162 | 163 | # Augmentation setup 164 | self._augmentors = augmentors 165 | 166 | # Worker for multiprocessing to speed up initialization of datasets 167 | self._worker = worker 168 | 169 | @property 170 | def feature_and_targets_builder(self) -> FeaturePreprocessor: 171 | """Get feature and target builders.""" 172 | return self._feature_preprocessor 173 | 174 | def setup(self, stage: Optional[str] = None) -> None: 175 | """ 176 | Set up the dataset for each target set depending on the training stage. 177 | This is called by every process in distributed training. 178 | :param stage: Stage of training, can be "fit" or "test". 179 | """ 180 | if stage is None: 181 | return 182 | 183 | if stage == "fit": 184 | # Training Dataset 185 | train_samples = self._splitter.get_train_samples( 186 | self._all_samples, self._worker 187 | ) 188 | assert len(train_samples) > 0, "Splitter returned no training samples" 189 | 190 | self._train_set = create_dataset( 191 | train_samples, 192 | self._feature_preprocessor, 193 | self._train_fraction, 194 | "train", 195 | self._augmentors, 196 | ) 197 | 198 | # Validation Dataset 199 | val_samples = self._splitter.get_val_samples( 200 | self._all_samples, self._worker 201 | ) 202 | assert len(val_samples) > 0, "Splitter returned no validation samples" 203 | 204 | self._val_set = create_dataset( 205 | val_samples, 206 | self._feature_preprocessor, 207 | self._val_fraction, 208 | "validation", 209 | ) 210 | elif stage == "test": 211 | # Testing Dataset 212 | test_samples = self._splitter.get_test_samples( 213 | self._all_samples, self._worker 214 | ) 215 | assert len(test_samples) > 0, "Splitter returned no test samples" 216 | 217 | self._test_set = create_dataset( 218 | test_samples, self._feature_preprocessor, self._test_fraction, "test" 219 | ) 220 | else: 221 | raise ValueError(f'Stage must be one of ["fit", "test"], got {stage}.') 222 | 223 | def teardown(self, stage: Optional[str] = None) -> None: 224 | """ 225 | Clean up after a training stage. 226 | This is called by every process in distributed training. 227 | :param stage: Stage of training, can be "fit" or "test". 228 | """ 229 | pass 230 | 231 | def train_dataloader(self) -> torch.utils.data.DataLoader: 232 | """ 233 | Create the training dataloader. 234 | :raises RuntimeError: If this method is called without calling "setup()" first. 235 | :return: The instantiated torch dataloader.
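When scenario_type_sampling_weights is enabled, a WeightedRandomSampler (wrapped in DistributedSamplerWrapper) replaces plain shuffling; see distributed_weighted_sampler_init above.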
236 | """ 237 | if self._train_set is None: 238 | raise DataModuleNotSetupError 239 | 240 | # Initialize weighted sampler 241 | if self._scenario_type_sampling_weights.enable: 242 | weighted_sampler = distributed_weighted_sampler_init( 243 | scenario_dataset=self._train_set, 244 | scenario_sampling_weights=self._scenario_type_sampling_weights.scenario_type_weights, 245 | ) 246 | else: 247 | weighted_sampler = None 248 | 249 | return torch.utils.data.DataLoader( 250 | dataset=self._train_set, 251 | shuffle=weighted_sampler is None, 252 | collate_fn=FeatureCollate(), 253 | sampler=weighted_sampler, 254 | **self._dataloader_params, 255 | ) 256 | 257 | def val_dataloader(self) -> torch.utils.data.DataLoader: 258 | """ 259 | Create the validation dataloader. 260 | :raises RuntimeError: if this method is called without calling "setup()" first. 261 | :return: The instantiated torch dataloader. 262 | """ 263 | if self._val_set is None: 264 | raise DataModuleNotSetupError 265 | 266 | return torch.utils.data.DataLoader( 267 | dataset=self._val_set, 268 | **self._dataloader_params, 269 | collate_fn=FeatureCollate(), 270 | ) 271 | 272 | def test_dataloader(self) -> torch.utils.data.DataLoader: 273 | """ 274 | Create the test dataloader. 275 | :raises RuntimeError: if this method is called without calling "setup()" first. 276 | :return: The instantiated torch dataloader. 277 | """ 278 | if self._test_set is None: 279 | raise DataModuleNotSetupError 280 | 281 | return torch.utils.data.DataLoader( 282 | dataset=self._test_set, 283 | **self._dataloader_params, 284 | collate_fn=FeatureCollate(), 285 | ) 286 | 287 | # ! Modified to adapt to newer version of pytorch-lightning 288 | def transfer_batch_to_device( 289 | self, batch: Tuple[FeaturesType, ...], device: torch.device, dataloader_idx: int 290 | ) -> Tuple[FeaturesType, ...]: 291 | """ 292 | Transfer a batch to device. 293 | :param batch: Batch on origin device. 294 | :param device: Desired device. 295 | :return: Batch in new device. 
296 | """ 297 | return tuple( 298 | ( 299 | move_features_type_to_device(batch[0], device), 300 | move_features_type_to_device(batch[1], device), 301 | batch[2], 302 | ) 303 | ) 304 | -------------------------------------------------------------------------------- /src/feature_builders/nuplan_feature_builder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Type 2 | 3 | import numpy as np 4 | import shapely 5 | from nuplan.common.actor_state.ego_state import EgoState 6 | from nuplan.common.actor_state.state_representation import Point2D, StateSE2 7 | from nuplan.common.actor_state.tracked_objects import TrackedObjects 8 | from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType 9 | from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters 10 | from nuplan.common.maps.abstract_map import AbstractMap, PolygonMapObject 11 | from nuplan.common.maps.maps_datatypes import ( 12 | SemanticMapLayer, 13 | TrafficLightStatusData, 14 | TrafficLightStatusType, 15 | ) 16 | from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario 17 | from nuplan.planning.simulation.planner.abstract_planner import ( 18 | PlannerInitialization, 19 | PlannerInput, 20 | ) 21 | from nuplan.planning.training.preprocessing.feature_builders.abstract_feature_builder import ( 22 | AbstractFeatureBuilder, 23 | ) 24 | from nuplan.planning.training.preprocessing.features.abstract_model_feature import ( 25 | AbstractModelFeature, 26 | ) 27 | 28 | from ..features.nuplan_feature import NuplanFeature 29 | from .common.route_utils import route_roadblock_correction 30 | from .common.utils import interpolate_polyline, rotate_round_z_axis 31 | 32 | 33 | class NuplanFeatureBuilder(AbstractFeatureBuilder): 34 | def __init__( 35 | self, 36 | radius: float = 100, 37 | history_horizon: float = 2, 38 | future_horizon: float = 8, 39 | sample_interval: float = 0.1, 40 | max_agents: int = 64, 41 | ) -> None: 42 | super().__init__() 43 | 44 | self.radius = radius 45 | self.history_horizon = history_horizon 46 | self.future_horizon = future_horizon 47 | self.history_samples = int(self.history_horizon / sample_interval) 48 | self.future_samples = int(self.future_horizon / sample_interval) 49 | self.sample_interval = sample_interval 50 | self.ego_params = get_pacifica_parameters() 51 | self.length = self.ego_params.length 52 | self.width = self.ego_params.width 53 | self.max_agents = max_agents 54 | 55 | self.interested_objects_types = [ 56 | TrackedObjectType.EGO, 57 | TrackedObjectType.VEHICLE, 58 | TrackedObjectType.PEDESTRIAN, 59 | TrackedObjectType.BICYCLE, 60 | ] 61 | self.polygon_types = [ 62 | SemanticMapLayer.LANE, 63 | SemanticMapLayer.LANE_CONNECTOR, 64 | SemanticMapLayer.CROSSWALK, 65 | ] 66 | 67 | def get_feature_type(self) -> Type[AbstractModelFeature]: 68 | """Inherited, see superclass.""" 69 | return NuplanFeature # type: ignore 70 | 71 | @classmethod 72 | def get_feature_unique_name(cls) -> str: 73 | """Inherited, see superclass.""" 74 | return "feature" 75 | 76 | def get_features_from_scenario( 77 | self, scenario: AbstractScenario 78 | ) -> AbstractModelFeature: 79 | ego_cur_state = scenario.initial_ego_state 80 | 81 | # ego features 82 | past_ego_trajectory = scenario.get_ego_past_trajectory( 83 | iteration=0, 84 | time_horizon=self.history_horizon, 85 | num_samples=self.history_samples, 86 | ) 87 | future_ego_trajectory = scenario.get_ego_future_trajectory( 88 | iteration=0, 89 | 
time_horizon=self.future_horizon, 90 | num_samples=self.future_samples, 91 | ) 92 | ego_state_list = ( 93 | list(past_ego_trajectory) + [ego_cur_state] + list(future_ego_trajectory) 94 | ) 95 | 96 | # agents features 97 | present_tracked_objects = scenario.initial_tracked_objects.tracked_objects 98 | past_tracked_objects = [ 99 | tracked_objects.tracked_objects 100 | for tracked_objects in scenario.get_past_tracked_objects( 101 | iteration=0, 102 | time_horizon=self.history_horizon, 103 | num_samples=self.history_samples, 104 | ) 105 | ] 106 | future_tracked_objects = [ 107 | tracked_objects.tracked_objects 108 | for tracked_objects in scenario.get_future_tracked_objects( 109 | iteration=0, 110 | time_horizon=self.future_horizon, 111 | num_samples=self.future_samples, 112 | ) 113 | ] 114 | tracked_objects_list = ( 115 | past_tracked_objects + [present_tracked_objects] + future_tracked_objects 116 | ) 117 | 118 | return self._build_feature( 119 | present_idx=self.history_samples, 120 | ego_state_list=ego_state_list, 121 | tracked_objects_list=tracked_objects_list, 122 | route_roadblocks_ids=scenario.get_route_roadblock_ids(), 123 | map_api=scenario.map_api, 124 | mission_goal=scenario.get_mission_goal(), 125 | traffic_light_status=scenario.get_traffic_light_status_at_iteration(0), 126 | ) 127 | 128 | def get_features_from_simulation( 129 | self, current_input: PlannerInput, initialization: PlannerInitialization 130 | ) -> AbstractModelFeature: 131 | history = current_input.history 132 | tracked_objects_list = [ 133 | observation.tracked_objects for observation in history.observations 134 | ] 135 | 136 | horizon = self.history_samples + 1 137 | return self._build_feature( 138 | present_idx=-1, 139 | ego_state_list=history.ego_states[-horizon:], 140 | tracked_objects_list=tracked_objects_list[-horizon:], 141 | route_roadblocks_ids=initialization.route_roadblock_ids, 142 | map_api=initialization.map_api, 143 | mission_goal=initialization.mission_goal, 144 | traffic_light_status=current_input.traffic_light_data, 145 | ) 146 | 147 | def _build_feature( 148 | self, 149 | present_idx: int, 150 | ego_state_list: List[EgoState], 151 | tracked_objects_list: List[TrackedObjects], 152 | route_roadblocks_ids: List[str],  # roadblock ids are string tokens 153 | map_api: AbstractMap, 154 | mission_goal: StateSE2, 155 | traffic_light_status: List[TrafficLightStatusData] = None, 156 | ): 157 | present_ego_state = ego_state_list[present_idx] 158 | query_xy = present_ego_state.center 159 | 160 | route_roadblocks_ids = route_roadblock_correction( 161 | present_ego_state, map_api, route_roadblocks_ids 162 | ) 163 | 164 | data = {} 165 | data["current_state"] = self._get_ego_current_state( 166 | ego_state_list[present_idx], ego_state_list[present_idx - 1] 167 | ) 168 | 169 | ego_features = self._get_ego_features(ego_states=ego_state_list) 170 | agent_features = self._get_agent_features( 171 | query_xy=query_xy, 172 | present_idx=present_idx, 173 | tracked_objects_list=tracked_objects_list, 174 | ) 175 | 176 | data["agent"] = {} 177 | for k in agent_features.keys(): 178 | data["agent"][k] = np.concatenate( 179 | [ego_features[k][None, ...], agent_features[k]], axis=0 180 | ) 181 | 182 | data["map"] = self._get_map_features( 183 | map_api=map_api, 184 | query_xy=query_xy, 185 | route_roadblock_ids=route_roadblocks_ids, 186 | traffic_light_status=traffic_light_status, 187 | radius=self.radius, 188 | ) 189 | 190 | return NuplanFeature.normalize(data, first_time=True, radius=self.radius) 191 | 192 | def _get_ego_current_state(self, ego_state:
EgoState, prev_state: EgoState): 193 | steering_angle, yaw_rate = self.calculate_additional_ego_states( 194 | ego_state, prev_state 195 | ) 196 | 197 | state = np.zeros(7, dtype=np.float64) 198 | state[0:2] = ego_state.rear_axle.array 199 | state[2] = ego_state.rear_axle.heading 200 | state[3] = ego_state.dynamic_car_state.rear_axle_velocity_2d.x 201 | state[4] = ego_state.dynamic_car_state.rear_axle_acceleration_2d.x 202 | state[5] = steering_angle 203 | state[6] = yaw_rate 204 | return state 205 | 206 | def _get_ego_features(self, ego_states: List[EgoState]): 207 | """note that rear axle velocity and acceleration are in ego local frame, 208 | and need to be transformed to the global frame. 209 | """ 210 | T = len(ego_states) 211 | 212 | position = np.zeros((T, 2), dtype=np.float64) 213 | heading = np.zeros((T), dtype=np.float64) 214 | velocity = np.zeros((T, 2), dtype=np.float64) 215 | acceleration = np.zeros((T, 2), dtype=np.float64) 216 | shape = np.zeros((T, 2), dtype=np.float64) 217 | valid_mask = np.ones(T, dtype=bool)  # builtin bool: np.bool is removed in modern NumPy 218 | 219 | for t, state in enumerate(ego_states): 220 | position[t] = state.rear_axle.array 221 | heading[t] = state.rear_axle.heading 222 | velocity[t] = rotate_round_z_axis( 223 | state.dynamic_car_state.rear_axle_velocity_2d.array, 224 | -state.rear_axle.heading, 225 | ) 226 | acceleration[t] = rotate_round_z_axis( 227 | state.dynamic_car_state.rear_axle_acceleration_2d.array, 228 | -state.rear_axle.heading, 229 | ) 230 | shape[t] = np.array([self.width, self.length]) 231 | 232 | category = np.array( 233 | self.interested_objects_types.index(TrackedObjectType.EGO), dtype=np.int8 234 | ) 235 | 236 | return { 237 | "position": position, 238 | "heading": heading, 239 | "velocity": velocity, 240 | "acceleration": acceleration, 241 | "shape": shape, 242 | "category": category, 243 | "valid_mask": valid_mask, 244 | } 245 | 246 | def _get_agent_features( 247 | self, 248 | query_xy: Point2D, 249 | present_idx: int, 250 | tracked_objects_list: List[TrackedObjects], 251 | ): 252 | present_tracked_objects = tracked_objects_list[present_idx] 253 | present_agents = present_tracked_objects.get_tracked_objects_of_types( 254 | self.interested_objects_types 255 | ) 256 | N, T = min(len(present_agents), self.max_agents), len(tracked_objects_list) 257 | 258 | position = np.zeros((N, T, 2), dtype=np.float64) 259 | heading = np.zeros((N, T), dtype=np.float64) 260 | velocity = np.zeros((N, T, 2), dtype=np.float64) 261 | shape = np.zeros((N, T, 2), dtype=np.float64) 262 | category = np.zeros((N,), dtype=np.int8) 263 | valid_mask = np.zeros((N, T), dtype=bool) 264 | 265 | if N == 0: 266 | return { 267 | "position": position, 268 | "heading": heading, 269 | "velocity": velocity, 270 | "shape": shape, 271 | "category": category, 272 | "valid_mask": valid_mask, 273 | } 274 | 275 | agent_ids = np.array([agent.track_token for agent in present_agents]) 276 | agent_cur_pos = np.array([agent.center.array for agent in present_agents]) 277 | distance = np.linalg.norm(agent_cur_pos - query_xy.array[None, :], axis=1) 278 | agent_ids_sorted = agent_ids[np.argsort(distance)[: self.max_agents]] 279 | agent_ids_sorted = {agent_id: i for i, agent_id in enumerate(agent_ids_sorted)} 280 | 281 | for t, tracked_objects in enumerate(tracked_objects_list): 282 | for agent in tracked_objects.get_tracked_objects_of_types( 283 | self.interested_objects_types 284 | ): 285 | if agent.track_token not in agent_ids_sorted: 286 | continue 287 | 288 | idx = agent_ids_sorted[agent.track_token] 289 | position[idx,
    def _get_agent_features(
        self,
        query_xy: Point2D,
        present_idx: int,
        tracked_objects_list: List[TrackedObjects],
    ):
        present_tracked_objects = tracked_objects_list[present_idx]
        present_agents = present_tracked_objects.get_tracked_objects_of_types(
            self.interested_objects_types
        )
        N, T = min(len(present_agents), self.max_agents), len(tracked_objects_list)

        position = np.zeros((N, T, 2), dtype=np.float64)
        heading = np.zeros((N, T), dtype=np.float64)
        velocity = np.zeros((N, T, 2), dtype=np.float64)
        shape = np.zeros((N, T, 2), dtype=np.float64)
        category = np.zeros((N,), dtype=np.int8)
        valid_mask = np.zeros((N, T), dtype=np.bool_)

        if N == 0:
            return {
                "position": position,
                "heading": heading,
                "velocity": velocity,
                "shape": shape,
                "category": category,
                "valid_mask": valid_mask,
            }

        # keep only the max_agents tracks closest to the ego at the present frame
        agent_ids = np.array([agent.track_token for agent in present_agents])
        agent_cur_pos = np.array([agent.center.array for agent in present_agents])
        distance = np.linalg.norm(agent_cur_pos - query_xy.array[None, :], axis=1)
        agent_ids_sorted = agent_ids[np.argsort(distance)[: self.max_agents]]
        agent_ids_sorted = {agent_id: i for i, agent_id in enumerate(agent_ids_sorted)}

        for t, tracked_objects in enumerate(tracked_objects_list):
            for agent in tracked_objects.get_tracked_objects_of_types(
                self.interested_objects_types
            ):
                if agent.track_token not in agent_ids_sorted:
                    continue

                idx = agent_ids_sorted[agent.track_token]
                position[idx, t] = agent.center.array
                heading[idx, t] = agent.center.heading
                velocity[idx, t] = agent.velocity.array
                shape[idx, t] = np.array([agent.box.width, agent.box.length])
                valid_mask[idx, t] = True

                if t == present_idx:
                    category[idx] = self.interested_objects_types.index(
                        agent.tracked_object_type
                    )

        return {
            "position": position,
            "heading": heading,
            "velocity": velocity,
            "shape": shape,
            "category": category,
            "valid_mask": valid_mask,
        }
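
    # Every map element below is encoded as three polylines of `sample_points`
    # segments each: for lanes the centerline plus left/right boundaries, for
    # crosswalks the analogous center/left/right edges of the rotated bounding
    # box. Point-level arrays are therefore shaped [M, 3, P, ...], while
    # polygon-level attributes are [M, ...]; this point/polygon split is
    # presumably what the map encoder
    # (src/models/planTF/modules/map_encoder.py) consumes.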
    def _get_map_features(
        self,
        map_api: AbstractMap,
        query_xy: Point2D,
        route_roadblock_ids: List[str],
        traffic_light_status: List[TrafficLightStatusData],
        radius: float,
        sample_points: int = 20,
    ):
        route_ids = set(int(route_id) for route_id in route_roadblock_ids)
        tls = {tl.lane_connector_id: tl.status for tl in traffic_light_status}

        map_objects = map_api.get_proximal_map_objects(
            query_xy,
            radius,
            [
                SemanticMapLayer.LANE,
                SemanticMapLayer.LANE_CONNECTOR,
                SemanticMapLayer.CROSSWALK,
            ],
        )
        lane_objects = (
            map_objects[SemanticMapLayer.LANE]
            + map_objects[SemanticMapLayer.LANE_CONNECTOR]
        )
        crosswalk_objects = map_objects[SemanticMapLayer.CROSSWALK]

        object_ids = [int(obj.id) for obj in lane_objects + crosswalk_objects]
        object_types = (
            [SemanticMapLayer.LANE] * len(map_objects[SemanticMapLayer.LANE])
            + [SemanticMapLayer.LANE_CONNECTOR]
            * len(map_objects[SemanticMapLayer.LANE_CONNECTOR])
            + [SemanticMapLayer.CROSSWALK]
            * len(map_objects[SemanticMapLayer.CROSSWALK])
        )

        M, P = len(lane_objects) + len(crosswalk_objects), sample_points
        point_position = np.zeros((M, 3, P, 2), dtype=np.float64)
        point_vector = np.zeros((M, 3, P, 2), dtype=np.float64)
        point_side = np.zeros((M, 3), dtype=np.int8)
        point_orientation = np.zeros((M, 3, P), dtype=np.float64)
        polygon_center = np.zeros((M, 3), dtype=np.float64)
        polygon_position = np.zeros((M, 2), dtype=np.float64)
        polygon_orientation = np.zeros(M, dtype=np.float64)
        polygon_type = np.zeros(M, dtype=np.int8)
        polygon_on_route = np.zeros(M, dtype=np.bool_)
        polygon_tl_status = np.zeros(M, dtype=np.int8)
        polygon_speed_limit = np.zeros(M, dtype=np.float64)
        polygon_has_speed_limit = np.zeros(M, dtype=np.bool_)

        for lane in lane_objects:
            object_id = int(lane.id)
            idx = object_ids.index(object_id)
            speed_limit = lane.speed_limit_mps

            centerline = self._sample_discrete_path(
                lane.baseline_path.discrete_path, sample_points + 1
            )
            left_bound = self._sample_discrete_path(
                lane.left_boundary.discrete_path, sample_points + 1
            )
            right_bound = self._sample_discrete_path(
                lane.right_boundary.discrete_path, sample_points + 1
            )
            edges = np.stack([centerline, left_bound, right_bound], axis=0)

            point_vector[idx] = edges[:, 1:] - edges[:, :-1]
            point_position[idx] = edges[:, :-1]
            point_orientation[idx] = np.arctan2(
                point_vector[idx, :, :, 1], point_vector[idx, :, :, 0]
            )
            point_side[idx] = np.arange(3)

            polygon_center[idx] = np.concatenate(
                [
                    centerline[int(sample_points / 2)],
                    [point_orientation[idx, 0, int(sample_points / 2)]],
                ],
                axis=-1,
            )
            polygon_position[idx] = centerline[0]
            polygon_orientation[idx] = point_orientation[idx, 0, 0]
            polygon_type[idx] = self.polygon_types.index(object_types[idx])
            polygon_on_route[idx] = int(lane.get_roadblock_id()) in route_ids
            polygon_tl_status[idx] = tls.get(object_id, TrafficLightStatusType.UNKNOWN)
            polygon_has_speed_limit[idx] = speed_limit is not None
            polygon_speed_limit[idx] = speed_limit if speed_limit is not None else 0.0

        for crosswalk in crosswalk_objects:
            idx = object_ids.index(int(crosswalk.id))
            edges = self._get_crosswalk_edges(crosswalk)
            point_vector[idx] = edges[:, 1:] - edges[:, :-1]
            point_position[idx] = edges[:, :-1]
            point_orientation[idx] = np.arctan2(
                point_vector[idx, :, :, 1], point_vector[idx, :, :, 0]
            )
            point_side[idx] = np.arange(3)
            polygon_center[idx] = np.concatenate(
                [
                    edges[0, int(sample_points / 2)],
                    [point_orientation[idx, 0, int(sample_points / 2)]],
                ],
                axis=-1,
            )
            polygon_position[idx] = edges[0, 0]
            polygon_orientation[idx] = point_orientation[idx, 0, 0]
            polygon_type[idx] = self.polygon_types.index(object_types[idx])
            polygon_on_route[idx] = False
            polygon_tl_status[idx] = TrafficLightStatusType.UNKNOWN
            polygon_has_speed_limit[idx] = False

        return {
            "point_position": point_position,
            "point_vector": point_vector,
            "point_orientation": point_orientation,
            "point_side": point_side,
            "polygon_center": polygon_center,
            "polygon_position": polygon_position,
            "polygon_orientation": polygon_orientation,
            "polygon_type": polygon_type,
            "polygon_on_route": polygon_on_route,
            "polygon_tl_status": polygon_tl_status,
            "polygon_has_speed_limit": polygon_has_speed_limit,
            "polygon_speed_limit": polygon_speed_limit,
        }

    def _sample_discrete_path(self, discrete_path: List[StateSE2], num_points: int):
        path = np.stack([point.array for point in discrete_path], axis=0)
        return interpolate_polyline(path, num_points)

    def _get_crosswalk_edges(
        self, crosswalk: PolygonMapObject, sample_points: int = 21
    ):
        bbox = shapely.minimum_rotated_rectangle(crosswalk.polygon)
        coords = np.stack(bbox.exterior.coords.xy, axis=-1)
        edge1 = coords[[3, 0]]  # right boundary
        edge2 = coords[[2, 1]]  # left boundary

        # [3, 2, 2]: (center, left, right) x (start, end) x (x, y)
        edges = np.stack([(edge1 + edge2) * 0.5, edge2, edge1], axis=0)
        vector = edges[:, 1] - edges[:, 0]  # [3, 2]
        # 21 sampled points yield 20 segments, matching the lane sampling above
        steps = np.linspace(0, 1, sample_points, endpoint=True)[None, :]
        points = edges[:, 0][:, None, :] + vector[:, None, :] * steps[:, :, None]

        return points
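
    # The steering angle below is recovered from a kinematic bicycle model,
    # which relates yaw rate and steering angle through the wheel base:
    #
    #   yaw_rate = v * tan(steering_angle) / wheel_base
    #   =>  steering_angle = arctan(yaw_rate * wheel_base / v)
    #
    # Worked example with hypothetical numbers: a heading change of 0.05 rad
    # over dt = 0.1 s gives yaw_rate = 0.5 rad/s; at v = 5 m/s with a ~3 m
    # wheel base, steering_angle = arctan(0.5 * 3 / 5) ≈ 0.29 rad.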
    def calculate_additional_ego_states(
        self, current_state: EgoState, prev_state: EgoState, dt=0.1
    ):
        cur_velocity = current_state.dynamic_car_state.rear_axle_velocity_2d.x
        angle_diff = current_state.rear_axle.heading - prev_state.rear_axle.heading
        angle_diff = (angle_diff + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
        yaw_rate = angle_diff / dt

        if abs(cur_velocity) < 0.2:
            # if the car is almost stopped, the yaw-rate estimate is unreliable
            return 0.0, 0.0
        else:
            steering_angle = np.arctan(
                yaw_rate * self.ego_params.wheel_base / abs(cur_velocity)
            )
            steering_angle = np.clip(steering_angle, -2 / 3 * np.pi, 2 / 3 * np.pi)
            yaw_rate = np.clip(yaw_rate, -0.95, 0.95)

            return steering_angle, yaw_rate
--------------------------------------------------------------------------------