├── input
│   └── .keep
├── src
│   ├── __init__.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── block
│   │   │   │   ├── __init__.py
│   │   │   │   ├── conv_block.py
│   │   │   │   ├── linear_block.py
│   │   │   │   └── feat_module.py
│   │   │   └── models
│   │   │       ├── __init__.py
│   │   │       ├── yaw
│   │   │       │   ├── __init__.py
│   │   │       │   ├── yaw_predictor.py
│   │   │       │   └── lyft_yaw_regressor.py
│   │   │       ├── multi
│   │   │       │   ├── __init__.py
│   │   │       │   ├── multi_utils.py
│   │   │       │   ├── lyft_multi_regressor.py
│   │   │       │   ├── multi_model_predictor.py
│   │   │       │   ├── lyft_multi_model.py
│   │   │       │   ├── efficientnet_multi.py
│   │   │       │   ├── timm_multi.py
│   │   │       │   ├── resnest_multi.py
│   │   │       │   └── pretrained_cnn_multi.py
│   │   │       ├── single
│   │   │       │   ├── __init__.py
│   │   │       │   ├── lyft_regressor.py
│   │   │       │   └── lyft_model.py
│   │   │       ├── deep_ensemble
│   │   │       │   ├── __init__.py
│   │   │       │   ├── lyft_multi_deep_ensemble_regressor.py
│   │   │       │   └── lyft_multi_deep_ensemble_predictor.py
│   │   │       ├── multi_agent
│   │   │       │   ├── __init__.py
│   │   │       │   ├── lyft_multi_agent_regressor.py
│   │   │       │   └── smp_multi_agent_model.py
│   │   │       ├── cnn_collections
│   │   │       │   ├── __init__.py
│   │   │       │   └── efficient_net_wrapper.py
│   │   │       └── rnn_head_multi
│   │   │           ├── __init__.py
│   │   │           ├── rnn_head_multi_regressor.py
│   │   │           ├── target_scale_wrapper.py
│   │   │           └── lstm_head_multi_predictor.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   └── tuned_map_api.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── transform_dataset.py
│   │   │   ├── custom_ego_dataset.py
│   │   │   ├── multi_agent_dataset.py
│   │   │   ├── fast_agent_dataset.py
│   │   │   └── faster_agent_dataset.py
│   │   ├── mixture
│   │   │   ├── __init__.py
│   │   │   └── gmm.py
│   │   ├── sampling
│   │   │   └── __init__.py
│   │   ├── training
│   │   │   ├── __init__.py
│   │   │   ├── ignite_utils.py
│   │   │   ├── lr_scheduler.py
│   │   │   ├── distributed_evaluator.py
│   │   │   ├── snapshot_object_when_lr_increase.py
│   │   │   ├── exponential_moving_average.py
│   │   │   └── scene_sampler.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── dotdict.py
│   │   │   ├── yaml_utils.py
│   │   │   ├── timer_utils.py
│   │   │   ├── resumable_distributed_sampler.py
│   │   │   ├── distributed_utils.py
│   │   │   └── numba_utils.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   └── mask.py
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   ├── residual_add.py
│   │   │   ├── mse.py
│   │   │   ├── mish_activation.py
│   │   │   └── transform.py
│   │   ├── rasterization
│   │   │   ├── __init__.py
│   │   │   ├── combined_rasterizer.py
│   │   │   ├── rasterizer_builder.py
│   │   │   └── channel_semantic_rasterizer.py
│   │   └── transforms
│   │       ├── yaw.py
│   │       ├── single_agent.py
│   │       ├── rnn_head_single_agent.py
│   │       ├── __init__.py
│   │       ├── multi_agent.py
│   │       ├── cross_drop.py
│   │       └── augmentation.py
│   ├── ensemble
│   │   ├── results
│   │   │   └── .keep
│   │   ├── flags
│   │   │   └── 20201126_ensemble.yaml
│   │   ├── ensemble_test.py
│   │   ├── ensemble_val.py
│   │   └── ensemble_batch.py
│   └── modeling
│       ├── __init__.py
│       ├── results
│       │   └── .keep
│       ├── flags
│       │   ├── 20201104_cosine_aug.yaml
│       │   ├── 20201110_cosine_aug_2_min_history_0.yaml
│       │   ├── 20201104_vanilla_cosine_min_history_0.yaml
│       │   ├── 20201105_cosine_seresnext50.yaml
│       │   ├── 20201115_cosine_seresnext50_im224_aug.yaml
│       │   ├── 20201112_snapshot_ensemble_resnet18.yaml
│       │   ├── 20201115_cosine_seresnext50_im224.yaml
│       │   ├── 20201113_res50_im224_incvalid.yaml
│       │   └── 20201118_cosine_seresnext50_im224_aug_val.yaml
│       ├── configs
│       │   ├── 0905_cfg.yaml
│       │   ├── 0905_cfg_full.yaml
│       │   ├── 0927_cfg_full_im128.yaml
│       │   ├── 1119_cfg_full_im224_tuned.yaml
│       │   ├── 1111_cfg_full_agenttypebox.yaml
│       │   └── 1120_cfg_full_im224_numhistory10_aug.yaml
│       ├── check_history_avail.py
│       ├── calc_num_history_vs_loss.py
│       ├── load_flag.py
│       ├── calc_target_scale.py
│       ├── calc_history_avail.py
│       ├── calc_num_history.py
│       └── builder.py
├── pyproject.toml
├── LICENSE
├── .gitignore
├── requirements.txt
└── README.md
/input/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/ensemble/results/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/dataset/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/mixture/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/block/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/sampling/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/training/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/modeling/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/modeling/results/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/evaluation/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/functions/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/yaw/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/rasterization/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/multi/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/single/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/deep_ensemble/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/multi_agent/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/cnn_collections/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/nn/models/rnn_head_multi/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/lib/transforms/yaw.py:
--------------------------------------------------------------------------------
def transform_yaw(batch):
    """Used in pred_mode='yaw', predict only the single target agent's yaw."""
    return batch["image"], batch["yaw"]
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
# from l5kit library
[tool.black]
line-length = 119

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 119
--------------------------------------------------------------------------------
/src/lib/transforms/single_agent.py:
--------------------------------------------------------------------------------
def transform_single_agent(batch):
    """Used in pred_mode='single' or 'multi', predict only single agent. This is baseline."""
    return batch["image"], batch["target_positions"], batch["target_availabilities"]
--------------------------------------------------------------------------------
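These transform functions only repackage a dataset's sample dict into the positional tuple the regressor modules expect. A minimal sketch of that contract; the shapes here are assumptions taken from the 224x224, history_num_frames=10 configs later in this dump:

import numpy as np
from lib.transforms.single_agent import transform_single_agent

# Fake batch mimicking one sample from l5kit's AgentDataset
# (assumed shapes: 25-channel raster, 50 future frames).
batch = {
    "image": np.zeros((25, 224, 224), dtype=np.float32),
    "target_positions": np.zeros((50, 2), dtype=np.float32),
    "target_availabilities": np.ones((50,), dtype=np.float32),
}
image, targets, avails = transform_single_agent(batch)
assert image.shape == (25, 224, 224) and targets.shape == (50, 2)
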
/src/lib/utils/dotdict.py:
--------------------------------------------------------------------------------
class DotDict(dict):
    """dot.notation access to dictionary attributes

    Refer: https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary/23689767#23689767
    """  # NOQA

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
--------------------------------------------------------------------------------
/src/lib/utils/yaml_utils.py:
--------------------------------------------------------------------------------
import yaml


def save_yaml(filepath, content, width=120):
    with open(filepath, 'w') as f:
        yaml.dump(content, f, width=width)


def load_yaml(filepath):
    with open(filepath, 'r') as f:
        # content = yaml.safe_load(f)
        content = yaml.full_load(f)
    return content
--------------------------------------------------------------------------------
/src/lib/functions/residual_add.py:
--------------------------------------------------------------------------------
import torch


def residual_add(lhs, rhs):
    lhs_ch, rhs_ch = lhs.shape[1], rhs.shape[1]
    if lhs_ch < rhs_ch:
        out = lhs + rhs[:, :lhs_ch]
    elif lhs_ch > rhs_ch:
        out = torch.cat([lhs[:, :rhs_ch] + rhs, lhs[:, rhs_ch:]], dim=1)
    else:
        out = lhs + rhs
    return out
--------------------------------------------------------------------------------
/src/lib/transforms/rnn_head_single_agent.py:
--------------------------------------------------------------------------------
def transform_rnn_head_single_agent(batch):
    """Used in pred_mode='rnn_head_multi', predict only single agent."""
    return (
        batch["image"],
        batch["history_positions"],
        batch["history_availabilities"],
        batch["target_positions"],
        batch["target_availabilities"]
    )
--------------------------------------------------------------------------------
/src/lib/functions/mse.py:
--------------------------------------------------------------------------------
from torch import Tensor
import torch.nn.functional as F


# --- Single ---
def mse_loss(gt: Tensor, pred: Tensor, avails: Tensor) -> Tensor:
    loss = F.mse_loss(gt[avails > 0.], pred[avails > 0.])
    return loss


# --- Multi ---
def mse_loss_multi(gt: Tensor, pred: Tensor, confidences: Tensor, avails: Tensor) -> Tensor:
    # NOTE: `confidences` is accepted only to match the multi-mode loss signature;
    # it is not used by this loss.
    loss = F.mse_loss(gt[avails > 0.], pred[avails > 0.])
    return loss
--------------------------------------------------------------------------------
/src/lib/training/ignite_utils.py:
--------------------------------------------------------------------------------
from ignite.engine import Engine


def create_trainer(model, optimizer, device) -> Engine:
    model.to(device)

    def update_fn(engine, batch):
        model.train()
        optimizer.zero_grad()
        loss, metrics = model(*[elem.to(device) for elem in batch])
        loss.backward()
        optimizer.step()
        return metrics

    trainer = Engine(update_fn)
    return trainer
--------------------------------------------------------------------------------
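create_trainer builds an ignite Engine around any module whose forward returns a (loss, metrics) tuple, which is the convention all the regressor classes in this repo follow. A self-contained sketch with a toy stand-in model (the real models are built elsewhere in src/modeling/builder.py):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lib.training.ignite_utils import create_trainer


class ToyRegressor(nn.Module):
    """Stand-in for the repo's regressors: forward returns (loss, metrics)."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 1)

    def forward(self, x, y):
        loss = nn.functional.mse_loss(self.lin(x), y)
        return loss, {"loss": loss.item()}


model = ToyRegressor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = create_trainer(model, optimizer, device="cpu")
loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
trainer.run(loader, max_epochs=1)
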
/src/lib/utils/timer_utils.py:
--------------------------------------------------------------------------------
from contextlib import contextmanager
from time import perf_counter


@contextmanager
def timer(name):
    t0 = perf_counter()
    yield
    t1 = perf_counter()
    print("[{}] done in {:.3f} s".format(name, t1 - t0))


@contextmanager
def timer_ms(name):
    t0 = perf_counter()
    yield
    t1 = perf_counter()
    print("[{}] done in {:.3f} ms".format(name, (t1 - t0) * 1000.))
--------------------------------------------------------------------------------
/src/lib/dataset/transform_dataset.py:
--------------------------------------------------------------------------------
from typing import Callable

from torch.utils.data.dataset import Dataset


class TransformDataset(Dataset):
    def __init__(self, dataset: Dataset, transform: Callable):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        batch = self.dataset[index]
        return self.transform(batch)

    def __len__(self):
        return len(self.dataset)
--------------------------------------------------------------------------------
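TransformDataset is the glue between an l5kit-style dataset of sample dicts and the tuple transforms above; together with timer it is easy to measure per-sample cost. A minimal self-contained sketch (the list-backed dataset is a stand-in for AgentDataset):

from lib.dataset.transform_dataset import TransformDataset
from lib.utils.timer_utils import timer

base = [{"image": i, "yaw": 0.1 * i} for i in range(100)]  # stand-in dataset
dataset = TransformDataset(base, transform=lambda b: (b["image"], b["yaw"]))
with timer("fetch 100 samples"):
    samples = [dataset[i] for i in range(len(dataset))]
print(samples[0])  # -> (0, 0.0)
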
/src/modeling/flags/20201104_cosine_aug.yaml:
--------------------------------------------------------------------------------
cfg_filepath: configs/0905_cfg_full.yaml
debug: false
device: cuda:0
ema_decay: 0.999
epoch: 1
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs: {}
model_name: resnet18
n_valid_data: 10000
out_dir: results/20201104_cosine_aug
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/lib/nn/models/yaw/yaw_predictor.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, Tensor
from typing import Dict, Optional


class LyftYawPredictor(nn.Module):

    target_scale: Optional[Tensor]

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base_model = base_model
        self.in_channels = base_model.in_channels

    def forward(self, x, x_feat=None):
        if x_feat is None:
            h = self.base_model(x)
        else:
            h = self.base_model(x, x_feat)
        # h: (bs, num_preds + num_modes); only the first 2 channels are kept
        # as the yaw prediction.
        return h[:, :2]
--------------------------------------------------------------------------------
/src/modeling/flags/20201110_cosine_aug_2_min_history_0.yaml:
--------------------------------------------------------------------------------
cfg_filepath: configs/0905_cfg_full.yaml
debug: false
device: cuda:0
ema_decay: 0.999
epoch: 1
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs: {}
model_name: resnet18
n_valid_data: 10000
out_dir: results/20201110_cosine_aug_2_min_history_0
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
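The flag files above are plain YAML dumps of one run's training configuration. A likely way to consume them, given the load_yaml and DotDict utilities earlier in this dump (the repo's actual loader is src/modeling/load_flag.py, whose contents are not shown here):

from lib.utils.yaml_utils import load_yaml
from lib.utils.dotdict import DotDict

flags = DotDict(load_yaml("flags/20201104_cosine_aug.yaml"))
print(flags.model_name)        # resnet18
print(flags.scheduler_kwargs)  # {'T_max': 2100000}
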
/src/lib/nn/models/multi/multi_utils.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple


def calc_in_out_channels(cfg: Dict, num_modes: int = 3) -> Tuple[int, int, int, int]:
    # DEPRECATED. The in_channels calculation only works for the "py_semantic" rasterizer...
    num_history_channels = (cfg["model_params"]["history_num_frames"] + 1) * 2
    in_channels = 3 + num_history_channels
    future_len = cfg["model_params"]["future_num_frames"]
    num_targets = 2 * future_len
    num_preds = num_targets * num_modes
    out_dim = num_preds + num_modes
    return in_channels, out_dim, num_preds, future_len


def calc_out_channels(cfg: Dict, num_modes: int = 3) -> Tuple[int, int, int]:
    future_len = cfg["model_params"]["future_num_frames"]
    num_targets = 2 * future_len
    num_preds = num_targets * num_modes
    out_dim = num_preds + num_modes
    return out_dim, num_preds, future_len
--------------------------------------------------------------------------------
/src/lib/training/lr_scheduler.py:
--------------------------------------------------------------------------------
from typing import Mapping, Any

from torch import optim

from pytorch_pfn_extras.training.extension import Extension, PRIORITY_READER
from pytorch_pfn_extras.training.manager import ExtensionsManager


class LRScheduler(Extension):
    """A thin wrapper to resume the lr_scheduler"""

    trigger = 1, 'iteration'
    priority = PRIORITY_READER
    name = None

    def __init__(self, optimizer: optim.Optimizer, scheduler_type: str, scheduler_kwargs: Mapping[str, Any]) -> None:
        super().__init__()
        self.scheduler = getattr(optim.lr_scheduler, scheduler_type)(optimizer, **scheduler_kwargs)

    def __call__(self, manager: ExtensionsManager) -> None:
        self.scheduler.step()

    def state_dict(self) -> Mapping[str, Any]:
        return self.scheduler.state_dict()

    def load_state_dict(self, to_load) -> None:
        self.scheduler.load_state_dict(to_load)
--------------------------------------------------------------------------------
/src/lib/functions/mish_activation.py:
--------------------------------------------------------------------------------
"""
Copy from https://www.kaggle.com/iafoss/mish-activation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class MishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.tanh(F.softplus(x))  # x * tanh(ln(1 + exp(x)))

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sigmoid = torch.sigmoid(x)
        tanh_sp = torch.tanh(F.softplus(x))
        return grad_output * (tanh_sp + x * sigmoid * (1 - tanh_sp * tanh_sp))


class Mish(nn.Module):
    def forward(self, x):
        return MishFunction.apply(x)


def to_Mish(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, Mish())
        else:
            to_Mish(child)
--------------------------------------------------------------------------------
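The head sizes these helpers produce are worth spelling out for this repo's default configs (history_num_frames=10, future_num_frames=50, 3 modes): the raster has 3 map channels plus (10+1)*2 ego/agent history channels, and the head emits 50*2 coordinates per mode plus one confidence logit per mode:

from lib.nn.models.multi.multi_utils import calc_in_out_channels

cfg = {"model_params": {"history_num_frames": 10, "future_num_frames": 50}}
in_channels, out_dim, num_preds, future_len = calc_in_out_channels(cfg, num_modes=3)
assert in_channels == 3 + (10 + 1) * 2  # 25 input raster channels
assert num_preds == 2 * 50 * 3          # 300 coordinate outputs
assert out_dim == 300 + 3               # plus one confidence logit per mode
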
/src/modeling/flags/20201104_vanilla_cosine_min_history_0.yaml:
--------------------------------------------------------------------------------
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs: {}
model_name: resnet18
n_valid_data: 10000
out_dir: results/20201104_vanilla_cosine_min_history_0
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/ensemble/flags/20201126_ensemble.yaml:
--------------------------------------------------------------------------------
outdir: "src/ensemble/results/20201126_ensemble_1_submission.csv"
weight: [1, 1, 1, 1, 1, 1, 1, 1, 1]
sigma: 0.001
N_sample: 3000
covariance_type: "spherical"
file_list:
- src/modeling/results/20201110_cosine_aug_2_min_history_0/prediction_ema/submission.csv
- src/modeling/results/20201104_vanilla_cosine_min_history_0/prediction_ema/submission.csv
- src/modeling/results/20201104_cosine_aug/prediction_ema/submission.csv
- src/modeling/results/20201105_cosine_seresnext50/prediction_ema/submission.csv
- src/modeling/results/20201113_res50_im224_incvalid/prediction_ema/submission.csv
- src/modeling/results/20201112_snapshot_ensemble_resnet18/4th_cycle/prediction/submission.csv
- src/modeling/results/20201115_cosine_seresnext50_im224_aug/prediction_ema/submission.csv
- src/modeling/results/20201115_cosine_seresnext50_im224/prediction_ema/submission.csv
- src/modeling/results/20201118_cosine_seresnext50_im224_aug_val/prediction_ema/submission.csv
--------------------------------------------------------------------------------
/src/lib/transforms/__init__.py:
--------------------------------------------------------------------------------
import sys
import os
sys.path.append(os.pardir)
sys.path.append(os.path.join(os.pardir, os.pardir))
from lib.transforms.single_agent import transform_single_agent
from lib.transforms.multi_agent import transform_multi_agent, collate_fn_multi_agent
from lib.transforms.rnn_head_single_agent import transform_rnn_head_single_agent
from lib.transforms.yaw import transform_yaw

pred_mode_to_transform = {
    "single": transform_single_agent,
    "multi": transform_single_agent,
    "multi_agent": transform_multi_agent,
    "rnn_head_multi": transform_rnn_head_single_agent,
    "multi_deep_ensemble": transform_single_agent,
    "yaw": transform_yaw,
}

pred_mode_to_collate_fn = {
    "single": None,  # Use default
    "multi": None,  # Use default
    "multi_agent": collate_fn_multi_agent,
    "rnn_head_multi": None,  # Use default
    "multi_deep_ensemble": None,  # Use default
    "yaw": None,  # Use default
}
--------------------------------------------------------------------------------
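These two dicts are the single dispatch point from the pred_mode flag to dataset plumbing. A sketch of how a training script would likely wire them up with TransformDataset (the list of dicts is a stand-in for l5kit's AgentDataset):

import numpy as np
from torch.utils.data import DataLoader
from lib.transforms import pred_mode_to_transform, pred_mode_to_collate_fn
from lib.dataset.transform_dataset import TransformDataset

agent_dataset = [{
    "image": np.zeros((25, 224, 224), dtype=np.float32),
    "target_positions": np.zeros((50, 2), dtype=np.float32),
    "target_availabilities": np.ones((50,), dtype=np.float32),
} for _ in range(4)]

pred_mode = "multi"
dataset = TransformDataset(agent_dataset, pred_mode_to_transform[pred_mode])
loader = DataLoader(dataset, batch_size=2,
                    collate_fn=pred_mode_to_collate_fn[pred_mode])  # None -> default collate
image, targets, avails = next(iter(loader))
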
/src/lib/nn/models/single/lyft_regressor.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
import pytorch_pfn_extras as ppe

from lib.functions.nll import pytorch_neg_multi_log_likelihood_single
from lib.functions.mse import mse_loss


class LyftRegressor(nn.Module):
    """Single mode prediction"""

    def __init__(self, predictor, lossfun=mse_loss):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(self, image, targets, target_availabilities):
        outputs = self.predictor(image).reshape(targets.shape)
        loss = self.lossfun(targets, outputs, target_availabilities)
        metrics = {
            f"{self.prefix}loss": loss.item(),
            f"{self.prefix}nll": pytorch_neg_multi_log_likelihood_single(targets, outputs, target_availabilities).item()
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics
--------------------------------------------------------------------------------
/src/lib/transforms/multi_agent.py:
--------------------------------------------------------------------------------
import torch


def transform_multi_agent(batch):
    """Used in pred_mode='multi_agent', predict all agents in the frame at once."""
    image = torch.as_tensor(batch["image"])
    centroid_pixel = torch.as_tensor(batch["centroid_pixel"])
    target_positions = torch.as_tensor(batch["target_positions"])
    target_availabilities = torch.as_tensor(batch["target_availabilities"])
    return image, centroid_pixel, target_positions, target_availabilities


def collate_fn_multi_agent(data_list):
    image = torch.stack([d[0] for d in data_list], dim=0)
    centroid_pixel = torch.cat([d[1] for d in data_list], dim=0).type(torch.long)
    target_positions = torch.cat([d[2] for d in data_list], dim=0)
    target_availabilities = torch.cat([d[3] for d in data_list], dim=0)
    batch_agents = torch.cat([torch.full((d[1].shape[0],), i, dtype=torch.long) for i, d in enumerate(data_list)])
    return image, centroid_pixel, batch_agents, target_positions, target_availabilities
--------------------------------------------------------------------------------
/src/modeling/flags/20201105_cosine_seresnext50.yaml:
--------------------------------------------------------------------------------
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0927_cfg_full_im128.yaml
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs:
  hdim: 4096
  use_bn: false
model_name: se_resnext50_32x4d
n_valid_data: 10000
out_dir: results/20201105_cosine_seresnext50
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
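Because each frame contributes a variable number of agents, collate_fn_multi_agent stacks the rasters but concatenates the per-agent tensors, and batch_agents records which raster each agent row came from. A small shape walkthrough (two frames with 3 and 1 agents; raster/future shapes assumed from the configs):

import torch
from lib.transforms.multi_agent import collate_fn_multi_agent

def fake_sample(num_agents):
    image = torch.zeros(25, 224, 224)
    centroid_pixel = torch.zeros(num_agents, 2)
    target_positions = torch.zeros(num_agents, 50, 2)
    target_availabilities = torch.ones(num_agents, 50)
    return image, centroid_pixel, target_positions, target_availabilities

image, centroid, batch_agents, targets, avails = collate_fn_multi_agent(
    [fake_sample(3), fake_sample(1)])
assert image.shape == (2, 25, 224, 224)   # one raster per frame
assert targets.shape == (4, 50, 2)        # 3 + 1 agents concatenated
assert batch_agents.tolist() == [0, 0, 0, 1]
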
/src/modeling/flags/20201115_cosine_seresnext50_im224_aug.yaml:
--------------------------------------------------------------------------------
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs:
  hdim: 4096
  use_bn: false
model_name: se_resnext50_32x4d
n_valid_data: 10000
out_dir: results/20201115_cosine_seresnext50_im224_aug
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/modeling/flags/20201112_snapshot_ensemble_resnet18.yaml:
--------------------------------------------------------------------------------
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
feat_mode: none
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs: {}
model_name: resnet18
n_valid_data: 10000
out_dir: results/20201112_snapshot_ensemble_resnet18
override_sample_function_name: ''
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_0: 413488
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingWarmRestarts
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Preferred Networks, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/src/modeling/flags/20201115_cosine_seresnext50_im224.yaml:
--------------------------------------------------------------------------------
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
feat_mode: none
include_valid: false
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs:
  use_bn: False
  hdim: 4096
model_name: se_resnext50_32x4d
n_valid_data: 10000
out_dir: results/20201115_cosine_seresnext50_im224
override_sample_function_name: ''
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2100000
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/lib/nn/models/multi/lyft_multi_regressor.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import pytorch_pfn_extras as ppe

from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch
from lib.functions.mse import mse_loss_multi


class LyftMultiRegressor(nn.Module):
    """Multi mode prediction"""

    def __init__(self, predictor, lossfun=pytorch_neg_multi_log_likelihood_batch):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(self, image, targets, target_availabilities, x_feat=None):
        if x_feat is None:
            pred, confidences = self.predictor(image)
        else:
            pred, confidences = self.predictor(image, x_feat)
        loss = self.lossfun(targets, pred, confidences, target_availabilities)
        metrics = {
            f"{self.prefix}loss": loss.item(),
            f"{self.prefix}nll": pytorch_neg_multi_log_likelihood_batch(targets, pred, confidences, target_availabilities).item()
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics
--------------------------------------------------------------------------------
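For pred_mode: multi, the regressor wraps any predictor that maps a raster batch to (pred, confidences) of shapes (bs, num_modes, future_len, 2) and (bs, num_modes). A shape-level sketch with a hypothetical stand-in predictor (the repo's real predictors live in lyft_multi_model.py and friends, not shown here; lib.functions.nll is also omitted from this dump but is imported above):

import torch
from torch import nn
from lib.nn.models.multi.lyft_multi_regressor import LyftMultiRegressor


class ToyMultiPredictor(nn.Module):
    """Stand-in predictor emitting (pred, confidences)."""
    def __init__(self, num_modes=3, future_len=50):
        super().__init__()
        self.num_modes, self.future_len = num_modes, future_len
        self.head = nn.Linear(25 * 8 * 8, num_modes * future_len * 2 + num_modes)

    def forward(self, image):
        h = self.head(image.flatten(start_dim=1))
        pred = h[:, :self.num_modes * self.future_len * 2].reshape(
            -1, self.num_modes, self.future_len, 2)
        confidences = torch.softmax(h[:, -self.num_modes:], dim=1)
        return pred, confidences


regressor = LyftMultiRegressor(ToyMultiPredictor())
loss, metrics = regressor(
    torch.randn(4, 25, 8, 8), torch.randn(4, 50, 2), torch.ones(4, 50))
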
/src/lib/functions/transform.py:
--------------------------------------------------------------------------------
import torch
from torch import Tensor


def transform_points_batch(points: Tensor, transf_matrix: Tensor) -> Tensor:
    """
    Transform points using transformation matrix.
    Note this function assumes points.shape[1] == matrix.shape[1] - 1, which means that the last row of the matrix
    does not influence the final result.
    For 2D points only the first 2x3 part of the matrix will be used.

    Args:
        points (Tensor): Input points (Nx2).
        transf_matrix (Tensor): Nx3x3 transformation matrix for 2D input

    Returns:
        transformed_points (Tensor): array of shape (N, 2) for 2D input points
    """
    bs, cdim = points.shape
    assert cdim == 2
    assert transf_matrix.shape == (bs, 3, 3)

    num_dims = 2
    transf_matrix = transf_matrix.transpose(1, 2)  # same as transf_matrix.permute(0, 2, 1)

    transf_points = torch.bmm(points.unsqueeze(1), transf_matrix[:, :num_dims, :num_dims]).squeeze(1)
    assert transf_points.shape == (bs, 2)
    return transf_points + transf_matrix[:, -1, :num_dims]
--------------------------------------------------------------------------------
/src/lib/nn/models/rnn_head_multi/rnn_head_multi_regressor.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

from torch import nn, Tensor
import pytorch_pfn_extras as ppe

from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch


class RNNHeadMultiRegressor(nn.Module):

    def __init__(self, predictor, lossfun=pytorch_neg_multi_log_likelihood_batch):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(
        self,
        image: Tensor,
        history_positions: Tensor,
        history_availabilities: Tensor,
        targets: Tensor,
        target_availabilities: Tensor,
    ) -> Tuple[Tensor, Dict[str, float]]:
        pred, confidences = self.predictor(image, history_positions, history_availabilities)
        loss = self.lossfun(targets, pred, confidences, target_availabilities)
        metrics = {
            f"{self.prefix}loss": loss.item(),
            f"{self.prefix}nll": pytorch_neg_multi_log_likelihood_batch(targets, pred, confidences, target_availabilities).item()
        }
        ppe.reporting.report(metrics, self)
        return loss, metrics
--------------------------------------------------------------------------------
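A quick numerical check of transform_points_batch: each point is treated as a row vector, so the transposed matrix applies the linear 2x2 part first and its last row then adds the translation. For a 90-degree rotation plus a (10, 20) translation:

import torch
from lib.functions.transform import transform_points_batch

m = torch.tensor([[[0., -1., 10.],
                   [1.,  0., 20.],
                   [0.,  0.,  1.]]])   # one homogeneous 2D affine matrix
p = torch.tensor([[1., 0.]])
print(transform_points_batch(p, m))    # tensor([[10., 21.]])
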
/src/modeling/flags/20201113_res50_im224_incvalid.yaml:
--------------------------------------------------------------------------------
augmentation_in_validation: false
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
crossdrop:
  fill_value: 0
  max_h_cut: 0.3
  max_w_cut: 0.3
  p: 0.0
cutout:
  p: 0.0
  scale_max: 0.99
  scale_min: 0.75
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
feat_mode: none
include_valid: 'true'
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs:
  hdim: 4096
  use_bn: false
model_name: resnet50
n_valid_data: 10000
out_dir: results/20201113_res50_im224_incvalid
override_sample_function_name: ''
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2301400
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/modeling/flags/20201118_cosine_seresnext50_im224_aug_val.yaml:
--------------------------------------------------------------------------------
augmentation_in_validation: false
blur:
  blur_limit:
  - 3
  - 5
  p: 0.0
cfg_filepath: configs/0905_cfg_full.yaml
crossdrop:
  fill_value: 0
  max_h_cut: 0.3
  max_w_cut: 0.3
  p: 0.0
cutout:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 10
  p: 0.9
debug: false
device: cuda:0
downscale:
  fill_value: 0
  max_h_size: 20
  max_w_size: 20
  num_holes: 5
  p: 0.0
ema_decay: 0.999
epoch: 1
feat_mode: none
include_valid: true
l5kit_data_folder: ../../input/lyft-motion-prediction-autonomous-vehicles
load_predictor_filepath: ''
min_frame_future: 10
min_frame_history: 0
model_kwargs:
  hdim: 4096
  use_bn: false
model_name: se_resnext50_32x4d
n_valid_data: 10000
out_dir: results/20201118_cosine_seresnext50_im224_aug_val
override_sample_function_name: ''
pred_mode: multi
resume_if_possible: true
scene_sampler: false
scene_sampler_min_state_index: 0
scheduler_kwargs:
  T_max: 2301400
scheduler_trigger:
- 1
- iteration
scheduler_type: CosineAnnealingLR
snapshot_freq: 5000
target_scale_filepath: ''
validation_chopped: true
validation_freq: 20000
--------------------------------------------------------------------------------
/src/lib/nn/models/rnn_head_multi/target_scale_wrapper.py:
--------------------------------------------------------------------------------
import numpy as np
from torch import nn, Tensor


class TargetScaleWrapper(nn.Module):

    def __init__(
        self,
        predictor: nn.Module,
        target_scale: Tensor,
    ) -> None:
        super().__init__()
        self.predictor = predictor
        self.target_scale: Tensor
        self.register_buffer("target_scale", target_scale)  # (future_len, 2)

    def forward(self, image: Tensor, history_positions: Tensor, history_availabilities: Tensor):
        """
        Args:
            image:
            history_positions: (batch_size, history_len, 2)
            history_availabilities: (batch_size, history_len)
        Returns:
            pred: (batch_size, num_modes, future_len, 2)
            confidence: (batch_size, num_modes)
        """
        assert history_positions.shape[1] <= self.target_scale.shape[0]
        history_positions = (
            history_positions / self.target_scale[np.newaxis, np.newaxis, :history_positions.shape[1], :]
        )
        pred, confidence = self.predictor(image, history_positions, history_availabilities)
        pred = pred * self.target_scale[np.newaxis, np.newaxis, :pred.shape[2], :]
        return pred, confidence
--------------------------------------------------------------------------------
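TargetScaleWrapper normalizes the history coordinates going into the RNN head and un-normalizes the predictions coming out, using a per-timestep (future_len, 2) scale buffer (computed elsewhere by calc_target_scale.py). A shape-level sketch with a trivial all-ones scale and a stub predictor:

import torch
from torch import nn
from lib.nn.models.rnn_head_multi.target_scale_wrapper import TargetScaleWrapper


class StubPredictor(nn.Module):
    def forward(self, image, history_positions, history_availabilities):
        bs = image.shape[0]
        return torch.zeros(bs, 3, 50, 2), torch.full((bs, 3), 1 / 3)


wrapper = TargetScaleWrapper(StubPredictor(), target_scale=torch.ones(50, 2))
pred, conf = wrapper(torch.zeros(2, 25, 224, 224),
                     torch.zeros(2, 11, 2), torch.ones(2, 11))
assert pred.shape == (2, 3, 50, 2) and conf.shape == (2, 3)
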
/src/lib/nn/block/conv_block.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F

from lib.functions.residual_add import residual_add


class ConvBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 bias=True, use_bn=True,
                 activation=F.relu, dropout_ratio=-1, residual=False, padding_mode='zeros'):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride,
            padding=padding, bias=bias, padding_mode=padding_mode)
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        if dropout_ratio > 0.:
            self.dropout = nn.Dropout2d(p=dropout_ratio)
        else:
            self.dropout = None
        self.activation = activation
        self.use_bn = use_bn
        self.dropout_ratio = dropout_ratio
        self.residual = residual

    def forward(self, x):
        if self.use_bn:
            h = self.bn(self.conv(x))
        else:
            h = self.conv(x)
        if self.activation is not None:
            h = self.activation(h)
        if self.residual:
            h = residual_add(h, x)
        if self.dropout_ratio > 0:
            h = self.dropout(h)
        return h
--------------------------------------------------------------------------------
/src/lib/nn/models/deep_ensemble/lyft_multi_deep_ensemble_regressor.py:
--------------------------------------------------------------------------------
from torch import nn
import pytorch_pfn_extras as ppe

from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch
from lib.nn.models.deep_ensemble.lyft_multi_deep_ensemble_predictor import LyftMultiDeepEnsemblePredictor


class LyftMultiDeepEnsembleRegressor(nn.Module):

    def __init__(self, predictor: LyftMultiDeepEnsemblePredictor, lossfun=pytorch_neg_multi_log_likelihood_batch):
        super().__init__()
        self.predictor = predictor
        self.lossfun = lossfun
        self.prefix = ""

    def forward(self, image, targets, target_availabilities, x_feat=None):
        if x_feat is None:
            predictions = self.predictor(image)
        else:
            predictions = self.predictor(image, x_feat)

        metrics = {}
        total_loss = 0.0
        for name, (pred, confidences) in zip(self.predictor.names, predictions):
            loss = self.lossfun(targets, pred, confidences, target_availabilities)
            total_loss += loss
            metrics[f"{name}/{self.prefix}loss"] = loss.item()
            metrics[f"{name}/{self.prefix}nll"] = pytorch_neg_multi_log_likelihood_batch(
                targets, pred, confidences, target_availabilities
            ).item()

        ppe.reporting.report(metrics, self)
        return total_loss, metrics
--------------------------------------------------------------------------------
/src/lib/mixture/gmm.py:
--------------------------------------------------------------------------------
import numba as nb
import numpy as np
from sklearn.mixture import GaussianMixture
# from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters


@nb.jit(nb.types.Tuple(
    (nb.float64[:], nb.float64[:, :])
)(nb.float64[:, :], nb.float64[:, :]), nopython=True, nogil=True)
def _estimate_gaussian_parameters(X, resp):
    nk = resp.sum(axis=0) + 10 * np.finfo(resp.dtype).eps
    means = np.dot(np.ascontiguousarray(resp.T), X) / np.ascontiguousarray(np.expand_dims(nk, 1))
    return nk, means


class GaussianMixtureIdentity(GaussianMixture):
    def _initialize(self, X, resp):
        n_samples, _ = X.shape
        self.covariances_ = np.zeros(self.n_components) + 1.0
        self.precisions_cholesky_ = np.zeros(self.n_components) + 1.0
        weights, means = _estimate_gaussian_parameters(X, resp)
        weights /= n_samples

        self.weights_ = (weights if self.weights_init is None
                         else self.weights_init)
        self.means_ = means if self.means_init is None else self.means_init

    def _m_step(self, X, log_resp):
        n_samples, _ = X.shape
        self.covariances_ = np.zeros(self.n_components) + 1.0
        self.precisions_cholesky_ = np.zeros(self.n_components) + 1.0
        self.weights_, self.means_ = _estimate_gaussian_parameters(X, np.exp(log_resp))
        self.weights_ /= n_samples
--------------------------------------------------------------------------------
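GaussianMixtureIdentity pins every component covariance to the identity, so EM only re-estimates weights and means; the covariance_type: "spherical" and N_sample: 3000 knobs in 20201126_ensemble.yaml suggest this is the step that merges many models' sampled trajectories back into a fixed number of modes. A minimal fit sketch under that assumption, flattening each sampled trajectory to a feature vector:

import numpy as np
from lib.mixture.gmm import GaussianMixtureIdentity

rng = np.random.default_rng(0)
# 3000 sampled trajectories, each 50 steps of (x, y), flattened to 100-dim.
samples = rng.normal(size=(3000, 100)).astype(np.float64)
gmm = GaussianMixtureIdentity(n_components=3, covariance_type="spherical")
gmm.fit(samples)
modes = gmm.means_.reshape(3, 50, 2)  # back to 3 trajectory modes
confidences = gmm.weights_            # mixture weights as confidences
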
/src/lib/nn/block/linear_block.py:
--------------------------------------------------------------------------------
from pytorch_pfn_extras.nn import LazyLinear
from torch import nn
import torch.nn.functional as F

from lib.functions.residual_add import residual_add


class LinearBlock(nn.Module):

    def __init__(self, in_features, out_features, bias=True,
                 use_bn=True, activation=F.relu, dropout_ratio=-1, residual=False):
        super(LinearBlock, self).__init__()
        if in_features is None:
            self.linear = LazyLinear(in_features, out_features, bias=bias)
        else:
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        if use_bn:
            self.bn = nn.BatchNorm1d(out_features)
        if dropout_ratio > 0.:
            self.dropout = nn.Dropout(p=dropout_ratio)
        else:
            self.dropout = None
        self.activation = activation
        self.use_bn = use_bn
        self.dropout_ratio = dropout_ratio
        self.residual = residual

    def forward(self, x):
        if x.ndim != 2:
            raise ValueError(f'Only x.ndim == 2 is supported! got x.shape {x.shape}')
        h = self.linear(x)
        if self.use_bn:
            h = self.bn(h)
        if self.activation is not None:
            h = self.activation(h)
        if self.residual:
            h = residual_add(h, x)
        if self.dropout_ratio > 0:
            h = self.dropout(h)
        return h
--------------------------------------------------------------------------------
/src/lib/training/distributed_evaluator.py:
--------------------------------------------------------------------------------
from typing import Union

import torch
from pytorch_pfn_extras.training.extensions import Evaluator


class DistributedEvaluator(Evaluator):
    def __init__(
        self,
        iterator,
        target,
        eval_hook=None,
        eval_func=None,
        local_rank=0,
        world_size=1,
        device: Union[torch.device, str] = "cpu",
        **kwargs,
    ):
        super(DistributedEvaluator, self).__init__(
            iterator, target, eval_hook=eval_hook, eval_func=eval_func, **kwargs
        )
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = device

    def evaluate(self):
        local_rank = self.local_rank
        world_size = self.world_size

        result = super(DistributedEvaluator, self).evaluate()
        print(f"[DEBUG] evaluate: local_rank {local_rank}, result {result}")
        keys_list = list(result.keys())
        keys_list.sort()  # To make the order the same for all processes.
        for key in keys_list:
            value = torch.as_tensor(result[key], device=self.device)
            torch.distributed.all_reduce(value, op=torch.distributed.ReduceOp.SUM)
            result[key] = float(value.item() / world_size)
        if local_rank == 0:
            print(
                f"[DEBUG] evaluate: After all_reduce: local_rank {local_rank}, result {result}"
            )
        return result
--------------------------------------------------------------------------------
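ConvBlock and LinearBlock are thin conv/linear + BN + activation + dropout bundles. A typical multi-mode head built from LinearBlock (hdim=4096 matches the model_kwargs in several flag files; this exact composition is an illustration, not necessarily the repo's head):

import torch
from torch import nn
from lib.nn.block.linear_block import LinearBlock

head = nn.Sequential(
    LinearBlock(2048, 4096, use_bn=False, dropout_ratio=0.2),
    LinearBlock(4096, 303, use_bn=False, activation=None),  # 303 = 300 preds + 3 confidences
)
out = head(torch.randn(8, 2048))
assert out.shape == (8, 303)
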
/src/lib/nn/block/feat_module.py:
--------------------------------------------------------------------------------
import torch
from torch import nn


class FeatModule(nn.Module):
    def __init__(
        self,
        feat_module_type: str = "none",
        channels: int = -1,
        feat_channels: int = -1,
    ):
        super(FeatModule, self).__init__()
        self.feat_module_type = feat_module_type
        if feat_module_type == "none":
            self.lin_feat = None
        elif feat_module_type == "sigmoid":
            self.lin_feat = nn.Linear(feat_channels, channels)
        elif feat_module_type == "film":
            self.lin_feat = nn.Linear(feat_channels, 2 * channels)
        else:
            raise ValueError(f"[ERROR] Unexpected value feat_module_type={feat_module_type}")

    def forward(self, h, h_feat=None):
        if self.feat_module_type == "none":
            # Do nothing
            return h
        elif self.feat_module_type == "sigmoid":
            assert h_feat is not None
            h *= torch.sigmoid(self.lin_feat(h_feat))
            return h
        elif self.feat_module_type == "film":
            assert h_feat is not None
            ch = h.shape[-1]
            h_feat = self.lin_feat(h_feat)
            gamma, beta = h_feat[:, :ch], h_feat[:, ch:]
            # gamma = torch.tanh(gamma)
            h = gamma * h + beta
            return h
        else:
            raise ValueError(f"[ERROR] Unexpected value self.feat_module_type={self.feat_module_type}")
--------------------------------------------------------------------------------
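FeatModule injects a side feature vector into the network either as a sigmoid gate or as FiLM (feature-wise affine) conditioning; it modulates the last dimension, so it expects (bs, channels) inputs such as a pooled CNN feature vector. A sketch of the FiLM path:

import torch
from lib.nn.block.feat_module import FeatModule

film = FeatModule(feat_module_type="film", channels=512, feat_channels=16)
h = torch.randn(8, 512)      # pooled CNN features
h_feat = torch.randn(8, 16)  # side features
out = film(h, h_feat)        # gamma * h + beta, per channel
assert out.shape == (8, 512)
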
/src/modeling/configs/0905_cfg.yaml:
--------------------------------------------------------------------------------
format_version: 4
raster_params:
  # raster image size [pixels]
  raster_size:
  - 224
  - 224
  # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
  pixel_size:
  - 0.5
  - 0.5
  # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.
  ego_center:
  - 0.25
  - 0.5
  map_type: "py_semantic"

  # the keys are relative to the dataset environment variable
  satellite_map_key: "aerial_map/aerial_map.png"
  semantic_map_key: "semantic_map/semantic_map.pb"
  dataset_meta_key: "meta.json"

  # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being
  # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.
  filter_agents_threshold: 0.5

  # whether to completely disable traffic light faces in the semantic rasterizer
  disable_traffic_light_faces: False
model_params:
  future_delta_time: 0.1
  future_num_frames: 50
  future_step_size: 1
  history_delta_time: 0.1
  history_num_frames: 10
  history_step_size: 1
  # not used.
  model_architecture: resnet50
train_data_loader:
  batch_size: 12
  key: scenes/train.zarr
  num_workers: 4
  shuffle: true
valid_data_loader:
  batch_size: 32
  key: scenes/validate.zarr
  num_workers: 4
  shuffle: false
train_params:
  checkpoint_every_n_steps: 5000
  max_num_steps: 10000
--------------------------------------------------------------------------------
/src/modeling/configs/0905_cfg_full.yaml:
--------------------------------------------------------------------------------
format_version: 4
raster_params:
  # raster image size [pixels]
  raster_size:
  - 224
  - 224
  # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
  pixel_size:
  - 0.5
  - 0.5
  # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.
  ego_center:
  - 0.25
  - 0.5
  map_type: "py_semantic"

  # the keys are relative to the dataset environment variable
  satellite_map_key: "aerial_map/aerial_map.png"
  semantic_map_key: "semantic_map/semantic_map.pb"
  dataset_meta_key: "meta.json"

  # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being
  # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.
  filter_agents_threshold: 0.5

  # whether to completely disable traffic light faces in the semantic rasterizer
  disable_traffic_light_faces: False
model_params:
  future_delta_time: 0.1
  future_num_frames: 50
  future_step_size: 1
  history_delta_time: 0.1
  history_num_frames: 10
  history_step_size: 1
  # not used.
  model_architecture: resnet50
train_data_loader:
  batch_size: 12
  key: scenes/train_full.zarr
  num_workers: 4
  shuffle: true
valid_data_loader:
  batch_size: 32
  key: scenes/validate.zarr
  num_workers: 4
  shuffle: false
train_params:
  checkpoint_every_n_steps: 5000
  max_num_steps: 10000
--------------------------------------------------------------------------------
/src/modeling/configs/0927_cfg_full_im128.yaml:
--------------------------------------------------------------------------------
format_version: 4
raster_params:
  # raster image size [pixels]
  raster_size:
  - 128
  - 128
  # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
  pixel_size:
  - 0.5
  - 0.5
  # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.
  ego_center:
  - 0.25
  - 0.5
  map_type: "py_semantic"

  # the keys are relative to the dataset environment variable
  satellite_map_key: "aerial_map/aerial_map.png"
  semantic_map_key: "semantic_map/semantic_map.pb"
  dataset_meta_key: "meta.json"

  # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being
  # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.
  filter_agents_threshold: 0.5

  # whether to completely disable traffic light faces in the semantic rasterizer
  disable_traffic_light_faces: False
model_params:
  future_delta_time: 0.1
  future_num_frames: 50
  future_step_size: 1
  history_delta_time: 0.1
  history_num_frames: 10
  history_step_size: 1
  # not used.
  model_architecture: resnet50
train_data_loader:
  batch_size: 12
  key: scenes/train_full.zarr
  num_workers: 4
  shuffle: true
valid_data_loader:
  batch_size: 32
  key: scenes/validate.zarr
  num_workers: 4
  shuffle: false
train_params:
  checkpoint_every_n_steps: 5000
  max_num_steps: 10000
--------------------------------------------------------------------------------
/src/lib/utils/resumable_distributed_sampler.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.distributed import DistributedSampler


class ResumableDistributedSampler(DistributedSampler):

    def __iter__(self):
        if not hasattr(self, "seed"):
            # For pytorch==1.5.0
            self.seed = 0
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if hasattr(self, "iteration"):
            indices = indices[self.iteration:]

        return iter(indices)

    def __len__(self) -> int:
        if hasattr(self, "iteration"):
            return self.num_samples - self.iteration
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        if hasattr(self, "iteration"):
            delattr(self, "iteration")
        self.epoch = epoch

    def resume(self, iteration: int, epoch: int) -> None:
        self.iteration = iteration
        self.epoch = epoch
--------------------------------------------------------------------------------
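The sampler extends PyTorch's DistributedSampler so a killed run can restart mid-epoch: resume() records how many samples to skip, and the next set_epoch() clears it. A single-process sketch (rank 0 of world size 1; in real use the ranks come from the distributed helpers below):

from lib.utils.resumable_distributed_sampler import ResumableDistributedSampler

dataset = list(range(10))
sampler = ResumableDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
sampler.resume(iteration=7, epoch=0)  # skip the 7 samples already consumed
print(list(sampler))                  # [7, 8, 9]
sampler.set_epoch(1)                  # next epoch starts from scratch again
print(len(sampler))                   # 10
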
36 | model_architecture: resnet50 37 | train_data_loader: 38 | batch_size: 12 39 | key: scenes/train_full.zarr 40 | num_workers: 4 41 | shuffle: true 42 | valid_data_loader: 43 | batch_size: 32 44 | key: scenes/validate.zarr 45 | num_workers: 4 46 | shuffle: false 47 | train_params: 48 | checkpoint_every_n_steps: 5000 49 | max_num_steps: 10000 50 | -------------------------------------------------------------------------------- /src/lib/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | 6 | def check_is_mpi() -> bool: 7 | return "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ.keys() 8 | 9 | 10 | def init_distributed(master_addr: str = "localhost", master_port: str = "8899"): 11 | os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] 12 | os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] 13 | 14 | # TODO: Check!! Below 2 settings are fine?? 15 | os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", master_addr) 16 | os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", master_port) 17 | 18 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 19 | torch.cuda.set_device(local_rank) 20 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 21 | rank = torch.distributed.get_rank() 22 | world_size = torch.distributed.get_world_size() 23 | return rank, world_size, local_rank 24 | 25 | 26 | def setup_distributed(): 27 | is_mpi = check_is_mpi() 28 | if is_mpi: 29 | rank, world_size, local_rank = init_distributed() 30 | else: 31 | rank, world_size, local_rank = None, None, 0 32 | return is_mpi, rank, world_size, local_rank 33 | 34 | 35 | def split_valid_dataset(valid_dataset, rank, world_size): 36 | valid_dataset_indices = list(range(len(valid_dataset))) 37 | local_valid_dataset_indices = valid_dataset_indices[ 38 | rank: len(valid_dataset_indices): world_size 39 | ] 40 | local_valid_dataset = torch.utils.data.Subset(valid_dataset, local_valid_dataset_indices) 41 | return local_valid_dataset 42 | -------------------------------------------------------------------------------- /src/lib/transforms/cross_drop.py: -------------------------------------------------------------------------------- 1 | from albumentations.core.transforms_interface import ImageOnlyTransform 2 | import albumentations.augmentations.functional as F 3 | import random 4 | 5 | class CrossDrop(ImageOnlyTransform): 6 | 7 | def __init__(self, max_h_cut=0.2, max_w_cut=0.2, fill_value=0, always_apply=False, p=0.5): 8 | super(CrossDrop, self).__init__(always_apply, p) 9 | self.max_h_cut = max_h_cut 10 | self.max_w_cut = max_w_cut 11 | self.fill_value = fill_value 12 | 13 | def apply(self, image, fill_value=0, holes=(), **params): 14 | return F.cutout(image, holes, fill_value) 15 | 16 | def get_params_dependent_on_targets(self, params): 17 | img = params["image"] 18 | height, width = img.shape[:2] 19 | 20 | y1 = int(random.random() * self.max_h_cut * height) 21 | x1 = int(random.random() * self.max_w_cut * width) 22 | 23 | y2 = int(random.random() * self.max_h_cut * height) 24 | x2 = int(random.random() * self.max_w_cut * width) 25 | 26 | y3 = int(random.random() * self.max_h_cut * height) 27 | x3 = int(random.random() * self.max_w_cut * width) 28 | 29 | y4 = int(random.random() * self.max_h_cut * height) 30 | x4 = int(random.random() * self.max_w_cut * width) 31 | 32 | return {"holes": [ 33 | (0, 0, x1, y1), 34 | (width-x2, 0, width, y2), 35 | (0, height-y3, x3, height), 36 | (width-x4, 
height-y4, width, height) 37 | ]} 38 | 39 | @property 40 | def targets_as_params(self): 41 | return ["image"] 42 | 43 | def get_transform_init_args_names(self): 44 | return ("max_h_cut", "max_w_cut") 45 | -------------------------------------------------------------------------------- /src/modeling/check_history_avail.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate num history for chopped valid/test data... 3 | """ 4 | import argparse 5 | from distutils.util import strtobool 6 | import numpy as np 7 | import torch 8 | from pathlib import Path 9 | 10 | from l5kit.dataset import AgentDataset 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.dataset import Subset 13 | 14 | from l5kit.data import LocalDataManager, ChunkedDataset 15 | 16 | import sys 17 | import os 18 | 19 | from tqdm import tqdm 20 | sys.path.append(os.pardir) 21 | sys.path.append(os.path.join(os.pardir, os.pardir)) 22 | from lib.dataset.faster_agent_dataset import FasterAgentDataset 23 | from lib.evaluation.mask import load_mask_chopped 24 | from modeling.load_flag import Flags 25 | from lib.rasterization.rasterizer_builder import build_custom_rasterizer 26 | from lib.utils.yaml_utils import save_yaml, load_yaml 27 | 28 | 29 | def parse(): 30 | parser = argparse.ArgumentParser(description='') 31 | parser.add_argument('--out', '-o', default='results/tmp', 32 | help='Directory to output the result') 33 | parser.add_argument('--debug', '-d', type=strtobool, default='false', 34 | help='Debug mode') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | if __name__ == '__main__': 40 | args = parse() 41 | 42 | processed_dir = Path("../../input/processed_data") 43 | npz_path = processed_dir / f"history_avail.npz" 44 | print(f"Load from {npz_path}") 45 | results = np.load(npz_path) 46 | his_avail_valid = results["his_avail_valid"] 47 | his_avail_test = results["his_avail_test"] 48 | import IPython; IPython.embed() 49 | n_his_valid = np.sum(his_avail_valid, axis=1) 50 | -------------------------------------------------------------------------------- /src/modeling/configs/1111_cfg_full_agenttypebox.yaml: -------------------------------------------------------------------------------- 1 | format_version: 4 2 | raster_params: 3 | # raster image size [pixels] 4 | raster_size: 5 | - 224 6 | - 224 7 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 8 | pixel_size: 9 | - 0.5 10 | - 0.5 11 | # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image. 12 | ego_center: 13 | - 0.25 14 | - 0.5 15 | map_type: "semantic_debug+agent_type_box" 16 | 17 | # the keys are relative to the dataset environment variable 18 | satellite_map_key: "aerial_map/aerial_map.png" 19 | semantic_map_key: "semantic_map/semantic_map.pb" 20 | dataset_meta_key: "meta.json" 21 | 22 | # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being 23 | # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents. 
24 | filter_agents_threshold: 0.5 25 | 26 | # whether to completely disable traffic light faces in the semantic rasterizer 27 | disable_traffic_light_faces: False 28 | 29 | # Whether to add channels only for the selected agent 30 | enable_selected_agent_channels: True 31 | 32 | model_params: 33 | future_delta_time: 0.1 34 | future_num_frames: 50 35 | future_step_size: 1 36 | history_delta_time: 0.1 37 | history_num_frames: 10 38 | history_step_size: 1 39 | # not used. 40 | model_architecture: resnet50 41 | train_data_loader: 42 | batch_size: 12 43 | key: scenes/train_full.zarr 44 | num_workers: 4 45 | shuffle: true 46 | valid_data_loader: 47 | batch_size: 32 48 | key: scenes/validate.zarr 49 | num_workers: 4 50 | shuffle: false 51 | train_params: 52 | checkpoint_every_n_steps: 5000 53 | max_num_steps: 10000 54 | -------------------------------------------------------------------------------- /src/lib/training/snapshot_object_when_lr_increase.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from torch import optim 4 | from pytorch_pfn_extras.training.extension import Extension, PRIORITY_READER 5 | from pytorch_pfn_extras.training.manager import ExtensionsManager 6 | from pytorch_pfn_extras.training.extensions import snapshot_object 7 | 8 | 9 | class SnapshotObjectWhenLRIncrease(Extension): 10 | 11 | trigger = 1, 'iteration' 12 | priority = PRIORITY_READER 13 | name = None 14 | 15 | def __init__( 16 | self, 17 | target: Any, 18 | optimizer: optim.Optimizer, 19 | param_group: int = 0, 20 | filename: str = "snapshot_{cycle_count}th_cycle.pt", 21 | saver_rank: Optional[int] = None 22 | ) -> None: 23 | super().__init__() 24 | self.cycle_count = 0 25 | self.lr_before = float("inf") 26 | self.target = target 27 | self.optimizer = optimizer 28 | self.param_group = param_group 29 | self.filename = filename 30 | self.saver_rank = saver_rank 31 | 32 | def __call__(self, manager: ExtensionsManager) -> None: 33 | lr_after = self.optimizer.param_groups[self.param_group]['lr'] 34 | if self.lr_before < lr_after: 35 | filename = self.filename.format(cycle_count=self.cycle_count) 36 | save_func = snapshot_object(self.target, filename, saver_rank=self.saver_rank) 37 | save_func(manager) 38 | self.cycle_count += 1 39 | self.lr_before = lr_after 40 | 41 | def state_dict(self) -> None: 42 | return {"cycle_count": self.cycle_count, "lr_before": self.lr_before} 43 | 44 | def load_state_dict(self, to_load) -> None: 45 | self.cycle_count = to_load["cycle_count"] 46 | self.lr_before = to_load["lr_before"] 47 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi_agent/lyft_multi_agent_regressor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict 2 | 3 | import torch 4 | from torch import nn, Tensor 5 | import pytorch_pfn_extras as ppe 6 | 7 | from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch 8 | 9 | 10 | class LyftMultiAgentRegressor(nn.Module): 11 | """Multi agent, multi mode prediction""" 12 | 13 | def __init__(self, predictor, lossfun=pytorch_neg_multi_log_likelihood_batch): 14 | super().__init__() 15 | self.predictor = predictor 16 | self.lossfun = lossfun 17 | self.prefix = "" 18 | 19 | def forward( 20 | self, 21 | image: Tensor, 22 | centroid_pixel: Tensor, 23 | batch_agents: Tensor, 24 | targets: Tensor, 25 | target_availabilities: Tensor 26 | ) -> Tuple[Tensor, Dict]: 27 | """ 28 | 29 | Args: 
30 | image: (batch_size=n_frames, ch, height, width)
31 | centroid_pixel: (n_agents, coords=2)
32 | batch_agents: (n_agents,)
33 | targets: (n_agents, future_len=50, coords=2)
34 | target_availabilities: (n_agents, future_len=50)
35 | 
36 | Returns:
37 | loss:
38 | metrics:
39 | """
40 | # pred (n_agents)x(modes)x(time)x(2D coords)
41 | # confidences (n_agents)x(modes)
42 | pred, confidences = self.predictor(image, centroid_pixel, batch_agents)
43 | loss = self.lossfun(targets, pred, confidences, target_availabilities)
44 | metrics = {
45 | f"{self.prefix}loss": loss.item(),
46 | f"{self.prefix}nll": pytorch_neg_multi_log_likelihood_batch(
47 | targets, pred, confidences, target_availabilities).item()
48 | }
49 | ppe.reporting.report(metrics, self)
50 | return loss, metrics
51 | 
-------------------------------------------------------------------------------- /src/modeling/configs/1120_cfg_full_im224_numhistory10_aug.yaml: --------------------------------------------------------------------------------
1 | format_version: 4
2 | raster_params:
3 | # raster image size [pixels]
4 | raster_size:
5 | - 224
6 | - 224
7 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
8 | pixel_size:
9 | - 0.5
10 | - 0.5
11 | # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.
12 | ego_center:
13 | - 0.25
14 | - 0.5
15 | map_type: "augmented_box+tuned_semantic"
16 | 
17 | # the keys are relative to the dataset environment variable
18 | satellite_map_key: "aerial_map/aerial_map.png"
19 | semantic_map_key: "semantic_map/semantic_map.pb"
20 | dataset_meta_key: "meta.json"
21 | 
22 | # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being
23 | # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.
24 | filter_agents_threshold: 0.5
25 | 
26 | # whether to completely disable traffic light faces in the semantic rasterizer
27 | disable_traffic_light_faces: False
28 | # Optional configs only for AugmentedBoxRasterizer
29 | agent_drop_ratio: 0.5
30 | agent_drop_prob: 0.25
31 | min_extent_ratio: 0.9
32 | max_extent_ratio: 1.1
33 | model_params:
34 | future_delta_time: 0.1
35 | future_num_frames: 50
36 | future_step_size: 1
37 | history_delta_time: 0.1
38 | history_num_frames: 10
39 | history_step_size: 1
40 | # not used.
41 | model_architecture: resnet50 42 | train_data_loader: 43 | batch_size: 12 44 | key: scenes/train_full.zarr 45 | num_workers: 4 46 | shuffle: true 47 | valid_data_loader: 48 | batch_size: 32 49 | key: scenes/validate.zarr 50 | num_workers: 4 51 | shuffle: false 52 | train_params: 53 | checkpoint_every_n_steps: 5000 54 | max_num_steps: 10000 55 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/multi_model_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from typing import Dict, Optional 4 | 5 | 6 | class LyftMultiModelPredictor(nn.Module): 7 | 8 | target_scale: Optional[Tensor] 9 | 10 | def __init__(self, base_model: nn.Module, cfg: Dict, num_modes: int = 3, 11 | target_scale: Optional[Tensor] = None): 12 | super().__init__() 13 | self.base_model = base_model 14 | self.in_channels = base_model.in_channels 15 | # X, Y coords for the future positions (output shape: Bx50x2) 16 | self.future_len = cfg["model_params"]["future_num_frames"] 17 | num_targets = 2 * self.future_len 18 | self.num_preds = num_targets * num_modes 19 | self.num_modes = num_modes 20 | 21 | if target_scale is None: 22 | self.target_scale = None 23 | else: 24 | assert target_scale.shape == (self.future_len, 2) 25 | self.register_buffer("target_scale", target_scale) 26 | 27 | def forward(self, x, x_feat=None): 28 | if x_feat is None: 29 | h = self.base_model(x) 30 | else: 31 | h = self.base_model(x, x_feat) 32 | # h: (bs, num_preds(pred) + num_modes(confidence) ) 33 | assert h.shape[1] == self.num_modes + self.num_preds 34 | 35 | # pred (bs)x(modes)x(time)x(2D coords) 36 | # confidences (bs)x(modes) 37 | bs, _ = h.shape 38 | pred, confidences = torch.split(h, self.num_preds, dim=1) 39 | pred = pred.view(bs, self.num_modes, self.future_len, 2) 40 | if self.target_scale is not None: 41 | pred = pred * self.target_scale[None, None, :, :] 42 | assert confidences.shape == (bs, self.num_modes) 43 | confidences = torch.softmax(confidences, dim=1) 44 | return pred, confidences 45 | -------------------------------------------------------------------------------- /src/lib/nn/models/yaw/lyft_yaw_regressor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import pytorch_pfn_extras as ppe 5 | 6 | 7 | def _mse_cos_sin_loss(target_yaw, pred): 8 | # target_yaw in radians (bs,) --> target_cos_sin (bs, 2) 9 | # pred (bs, 2) 10 | target_cos_sin = torch.stack([torch.cos(target_yaw), torch.sin(target_yaw)], dim=1) 11 | return F.mse_loss(target_cos_sin, pred) 12 | 13 | 14 | def _mae_cos_sin_loss(target_yaw, pred): 15 | # target_yaw in radians (bs,) --> target_cos_sin (bs, 2) 16 | # pred (bs, 2) 17 | target_cos_sin = torch.stack([torch.cos(target_yaw), torch.sin(target_yaw)], dim=1) 18 | return F.l1_loss(target_cos_sin, pred) 19 | 20 | 21 | class LyftYawRegressor(nn.Module): 22 | """Single mode prediction""" 23 | 24 | def __init__(self, predictor, lossfun: str = ""): 25 | super().__init__() 26 | self.predictor = predictor 27 | if lossfun == "mse": 28 | self.lossfun = _mse_cos_sin_loss 29 | elif lossfun == "mae": 30 | self.lossfun = _mae_cos_sin_loss 31 | else: 32 | print(f"[WARNING] Unknown lossfun {lossfun}, use mse loss...") 33 | self.lossfun = _mse_cos_sin_loss 34 | 35 | self.prefix = "" 36 | 37 | def forward(self, image, target_yaw, x_feat=None): 38 | if x_feat is None: 
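# (x_feat, when provided, carries extra per-agent inputs such as the agent-type one-hot built in lib/transforms/augmentation.py)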
39 | pred = self.predictor(image) 40 | else: 41 | pred = self.predictor(image, x_feat) 42 | loss = self.lossfun(target_yaw, pred) 43 | metrics = { 44 | f"{self.prefix}loss": loss.item(), 45 | f"{self.prefix}mse": _mse_cos_sin_loss(target_yaw, pred).item(), 46 | f"{self.prefix}mae": _mae_cos_sin_loss(target_yaw, pred).item(), 47 | } 48 | ppe.reporting.report(metrics, self) 49 | return loss, metrics 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | /lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm 107 | .idea 108 | 109 | # --- project specific setting --- 110 | /input 111 | jobs/log/ 112 | src/modeling/results/ 113 | src/eda/agent_type_images/ 114 | src/postprocess/results/ 115 | src/ensemble/results/ 116 | -------------------------------------------------------------------------------- /src/lib/nn/models/single/lyft_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet18 3 | from torch import nn 4 | from typing import Dict 5 | 6 | 7 | class LyftModel(nn.Module): 8 | 9 | def __init__(self, cfg: Dict): 10 | super().__init__() 11 | 12 | # TODO: support other than resnet18? 
13 | backbone = resnet18(pretrained=True, progress=True)
14 | self.backbone = backbone
15 | 
16 | num_history_channels = (cfg["model_params"]["history_num_frames"] + 1) * 2
17 | num_in_channels = 3 + num_history_channels
18 | 
19 | self.backbone.conv1 = nn.Conv2d(
20 | num_in_channels,
21 | self.backbone.conv1.out_channels,
22 | kernel_size=self.backbone.conv1.kernel_size,
23 | stride=self.backbone.conv1.stride,
24 | padding=self.backbone.conv1.padding,
25 | bias=False,
26 | )
27 | 
28 | # This is 512 for resnet18 and resnet34;
29 | # And it is 2048 for the other resnets
30 | backbone_out_features = 512
31 | 
32 | # X, Y coords for the future positions (output shape: Bx50x2)
33 | num_targets = 2 * cfg["model_params"]["future_num_frames"]
34 | 
35 | # You can add more layers here.
36 | self.head = nn.Sequential(
37 | # nn.Dropout(0.2),
38 | nn.Linear(in_features=backbone_out_features, out_features=4096),
39 | )
40 | 
41 | self.logit = nn.Linear(4096, out_features=num_targets)
42 | 
43 | def forward(self, x):
44 | x = self.backbone.conv1(x)
45 | x = self.backbone.bn1(x)
46 | x = self.backbone.relu(x)
47 | x = self.backbone.maxpool(x)
48 | 
49 | x = self.backbone.layer1(x)
50 | x = self.backbone.layer2(x)
51 | x = self.backbone.layer3(x)
52 | x = self.backbone.layer4(x)
53 | 
54 | x = self.backbone.avgpool(x)
55 | x = torch.flatten(x, 1)
56 | 
57 | x = self.head(x)
58 | x = self.logit(x)
59 | 
60 | return x
61 | 
-------------------------------------------------------------------------------- /src/lib/rasterization/combined_rasterizer.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from l5kit.rasterization import Rasterizer
3 | from typing import List, Optional
4 | 
5 | 
6 | class CombinedRasterizer(Rasterizer):
7 | def __init__(self, rasterizer_list: List[Rasterizer]):
8 | super(CombinedRasterizer, self).__init__()
9 | self.rasterizer_list = rasterizer_list
10 | try:
11 | self.raster_channels = sum([r.raster_channels for r in rasterizer_list])
12 | except AttributeError:
13 | self.raster_channels = -1
14 | 
15 | def rasterize(
16 | self,
17 | history_frames: np.ndarray,
18 | history_agents: List[np.ndarray],
19 | history_tl_faces: List[np.ndarray],
20 | agent: Optional[np.ndarray] = None,
21 | ) -> np.ndarray:
22 | image_list = [
23 | rasterizer.rasterize(history_frames, history_agents, history_tl_faces, agent)
24 | for rasterizer in self.rasterizer_list]
25 | return np.concatenate(image_list, -1)
26 | 
27 | def to_rgb(self, in_im: np.ndarray, **kwargs: dict) -> np.ndarray:
28 | try:
29 | ch_list = [rasterizer.raster_channels for rasterizer in self.rasterizer_list]
30 | except Exception as e:
31 | print("Need to set rasterizer.raster_channels attribute to use to_rgb!")
32 | raise e
33 | 
34 | ch_split_indices = np.cumsum(ch_list)[:-1]
35 | image = None
36 | for i, im in enumerate(np.split(in_im, ch_split_indices, axis=-1)):
37 | this_rgb_image = self.rasterizer_list[i].to_rgb(im, **kwargs)
38 | if image is None:
39 | image = this_rgb_image
40 | else:
41 | # Overwrite this_rgb_image on top of image
42 | mask_box = np.any(this_rgb_image > 0, -1)
43 | image[mask_box] = this_rgb_image[mask_box]
44 | return image
45 | 
46 | def __repr__(self):
47 | return "Combined: " + " + ".join([str(rasterizer) for rasterizer in self.rasterizer_list])
48 | 
-------------------------------------------------------------------------------- /src/lib/transforms/augmentation.py: --------------------------------------------------------------------------------
1 | from albumentations import
Compose 2 | import albumentations as A 3 | import numpy as np 4 | from lib.transforms.cross_drop import CrossDrop 5 | 6 | from lib.rasterization.agent_type_box_rasterizer import CAR_LABEL_INDEX, CYCLIST_LABEL_INDEX, PEDESTRIAN_LABEL_INDEX 7 | 8 | label_id_to_index = { 9 | CAR_LABEL_INDEX: 0, 10 | CYCLIST_LABEL_INDEX: 1, 11 | PEDESTRIAN_LABEL_INDEX: 2, 12 | } 13 | 14 | 15 | def _agent_type_onehot(label_probabilities): 16 | label = np.argmax(label_probabilities) 17 | 18 | x_feat = np.array([0, 0, 0], dtype=np.float32) 19 | x_feat[label_id_to_index[label]] = 1.0 20 | return x_feat 21 | 22 | 23 | class ImageAugmentation(object): 24 | 25 | def __init__(self, flags): 26 | self.aug = None 27 | self.set_augmentation(flags) 28 | 29 | self.feat_mode = flags.feat_mode 30 | 31 | def set_augmentation(self, flags): 32 | aug_list = [] 33 | if flags.blur["p"] > 0.0: 34 | aug_list.append(A.Blur(**flags.blur)) 35 | if flags.cutout["p"] > 0.0: 36 | aug_list.append(A.Cutout(**flags.cutout)) 37 | if flags.downscale["p"] > 0.0: 38 | aug_list.append(A.Downscale(**flags.downscale)) 39 | if flags.crossdrop["p"] > 0.0: 40 | aug_list.append(CrossDrop(**flags.crossdrop)) 41 | self.aug = Compose(aug_list) if len(aug_list)!=0 else None 42 | 43 | def transform(self, batch): 44 | if self.aug is not None: 45 | image = batch["image"] 46 | image = self.aug(image=np.moveaxis(image, 0, 2))["image"] 47 | image = np.moveaxis(image, 2, 0) 48 | batch["image"] = image 49 | 50 | if self.feat_mode == "none": 51 | return batch["image"], batch["target_positions"], batch["target_availabilities"] 52 | elif self.feat_mode == "agent_type": 53 | x_feat = _agent_type_onehot(batch["label_probabilities"]) 54 | return batch["image"], batch["target_positions"], batch["target_availabilities"], x_feat 55 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/lyft_multi_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet18 3 | from torch import nn 4 | from typing import Dict 5 | 6 | 7 | class LyftMultiModel(nn.Module): 8 | 9 | def __init__(self, cfg: Dict, num_modes=3, in_channels: int = 0): 10 | super().__init__() 11 | 12 | # TODO: support other than resnet18? 13 | backbone = resnet18(pretrained=True, progress=True) 14 | self.backbone = backbone 15 | 16 | num_in_channels = in_channels 17 | 18 | self.in_channels = num_in_channels 19 | self.backbone.conv1 = nn.Conv2d( 20 | num_in_channels, 21 | self.backbone.conv1.out_channels, 22 | kernel_size=self.backbone.conv1.kernel_size, 23 | stride=self.backbone.conv1.stride, 24 | padding=self.backbone.conv1.padding, 25 | bias=False, 26 | ) 27 | 28 | # This is 512 for resnet18 and resnet34; 29 | # And it is 2048 for the other resnets 30 | backbone_out_features = 512 31 | 32 | # X, Y coords for the future positions (output shape: Bx50x2) 33 | self.future_len = cfg["model_params"]["future_num_frames"] 34 | num_targets = 2 * self.future_len 35 | 36 | # You can add more layers here. 
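# The logit layer defined below emits num_preds trajectory values (num_modes * 2 * future_len)
# followed by num_modes confidence logits; LyftMultiModelPredictor splits and softmaxes them downstream.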
37 | self.head = nn.Sequential( 38 | # nn.Dropout(0.2), 39 | nn.Linear(in_features=backbone_out_features, out_features=4096), 40 | ) 41 | 42 | self.num_preds = num_targets * num_modes 43 | self.num_modes = num_modes 44 | 45 | self.logit = nn.Linear(4096, out_features=self.num_preds + num_modes) 46 | 47 | def forward(self, x): 48 | x = self.backbone.conv1(x) 49 | x = self.backbone.bn1(x) 50 | x = self.backbone.relu(x) 51 | x = self.backbone.maxpool(x) 52 | 53 | x = self.backbone.layer1(x) 54 | x = self.backbone.layer2(x) 55 | x = self.backbone.layer3(x) 56 | x = self.backbone.layer4(x) 57 | 58 | x = self.backbone.avgpool(x) 59 | x = torch.flatten(x, 1) 60 | 61 | x = self.head(x) 62 | x = self.logit(x) 63 | return x 64 | -------------------------------------------------------------------------------- /src/lib/nn/models/cnn_collections/efficient_net_wrapper.py: -------------------------------------------------------------------------------- 1 | try: 2 | from efficientnet_pytorch import EfficientNet 3 | from efficientnet_pytorch.utils import get_same_padding_conv2d, round_filters 4 | 5 | _efficientnet_pytorch_available = True 6 | except ImportError as e: 7 | _efficientnet_pytorch_available = False 8 | 9 | 10 | from torch import nn 11 | 12 | 13 | class EfficientNetWrapper(nn.Module): 14 | 15 | def __init__(self, model_name='efficientnet-b0', use_pretrained=False, in_channels=3): 16 | super(EfficientNetWrapper, self).__init__() 17 | if use_pretrained: 18 | self.model = EfficientNet.from_pretrained(model_name=model_name, in_channels=in_channels) 19 | else: 20 | model = EfficientNet.from_name(model_name, num_classes=1000) 21 | if in_channels != 3: 22 | Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size) 23 | out_channels = round_filters(32, model._global_params) 24 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 25 | self.model = model 26 | 27 | def custom_extract_features(self, model, inputs, block_start_index=0, apply_stem=True): 28 | """ Returns output of the final convolution layer """ 29 | 30 | if apply_stem: 31 | # Stem 32 | x = model._swish(model._bn0(model._conv_stem(inputs))) 33 | else: 34 | x = inputs 35 | 36 | # Blocks 37 | for idx, block in enumerate(model._blocks): 38 | if idx < block_start_index: 39 | continue 40 | drop_connect_rate = model._global_params.drop_connect_rate 41 | if drop_connect_rate: 42 | drop_connect_rate *= float(idx) / len(model._blocks) 43 | x = block(x, drop_connect_rate=drop_connect_rate) 44 | 45 | # Head 46 | x = model._swish(model._bn1(model._conv_head(x))) 47 | return x 48 | 49 | def forward(self, x): 50 | x = self.custom_extract_features(self.model, x) 51 | x = self.model._avg_pooling(x) 52 | bs = x.shape[0] 53 | x = x.view(bs, -1) 54 | return x 55 | -------------------------------------------------------------------------------- /src/ensemble/ensemble_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import yaml 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | from numpy.random import multinomial, multivariate_normal 10 | from sklearn.mixture import GaussianMixture 11 | import warnings 12 | warnings.simplefilter('ignore') 13 | 14 | sys.path.append("./src") 15 | sys.path.append(os.pardir) 16 | sys.path.append(os.path.join(os.pardir, os.pardir)) 17 | from lib.mixture.gmm import GaussianMixtureIdentity 18 | 19 | NUM_TEST = 71122 20 | SEED = 0 21 
| np.random.seed(SEED) 22 | 23 | 24 | def main(flags): 25 | outdir = flags["outdir"] 26 | w = flags["weight"] 27 | sigma = flags["sigma"] 28 | N_sample = flags["N_sample"] 29 | covariance_type = flags["covariance_type"] 30 | file_list = flags["file_list"] 31 | df_list = [pd.read_csv(file) for file in file_list] 32 | I = np.eye(100)*sigma 33 | samples = [multivariate_normal(np.zeros(100), I) for _ in range(N_sample)] 34 | data = [] 35 | 36 | for idx in tqdm(range(NUM_TEST)): 37 | rows = [df_list[i].loc[idx] for i in range(len(df_list))] 38 | mu = [np.vstack([rows[i][5:105].values,rows[i][105:205].values,rows[i][205:305].values]) 39 | for i in range(len(rows))] 40 | mu = np.concatenate(mu) 41 | confidence = [w[i]*rows[i][2:5].values for i in range(len(w))] 42 | confidence = np.concatenate(confidence) 43 | confidence /= confidence.sum() 44 | x = mu[np.random.choice(3*len(w), size=N_sample, p=confidence)]+samples 45 | if covariance_type=="identity": 46 | gauss = GaussianMixtureIdentity(3, "spherical", random_state=SEED) 47 | else: 48 | gauss = GaussianMixture(3, covariance_type, random_state=SEED) 49 | gauss.fit(x) 50 | confidence_fit = gauss.weights_ 51 | mu_fit = gauss.means_ 52 | row = [None, None] 53 | row += confidence_fit.tolist() 54 | row += mu_fit.reshape(-1).tolist() 55 | data.append(row) 56 | 57 | 58 | df = pd.DataFrame(data=data,columns=df_list[0].columns) 59 | df["timestamp"] = df_list[0]["timestamp"] 60 | df["track_id"] = df_list[0]["track_id"] 61 | df.to_csv(outdir,index=False) 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='') 65 | parser.add_argument('--yaml_filepath', '-y', type=str, 66 | help='Flags yaml file path') 67 | args = parser.parse_args() 68 | with open(args.yaml_filepath, 'r') as f: 69 | flags = yaml.safe_load(f) 70 | main(flags) 71 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/efficientnet_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Dict 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch.nn import Sequential 9 | 10 | sys.path.append(os.pardir) 11 | sys.path.append(os.path.join(os.pardir, os.pardir)) 12 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir)) 13 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir, os.pardir)) 14 | from lib.nn.block.linear_block import LinearBlock 15 | from lib.nn.models.cnn_collections.efficient_net_wrapper import EfficientNetWrapper 16 | from lib.nn.models.multi.multi_utils import calc_out_channels 17 | 18 | 19 | class EfficientNetMulti(nn.Module): 20 | def __init__( 21 | self, cfg, num_modes=3, model_name="efficientnet-b0", use_pretrained=True, use_bn=True, hdim: int = 512, 22 | in_channels: int = 0 23 | ): 24 | super(EfficientNetMulti, self).__init__() 25 | out_dim, num_preds, future_len = calc_out_channels(cfg, num_modes=num_modes) 26 | self.in_channels = in_channels 27 | self.out_dim = out_dim 28 | self.num_preds = num_preds 29 | self.future_len = future_len 30 | self.num_modes = num_modes 31 | 32 | # self.conv0 = nn.Conv2d( 33 | # in_channels, 3, kernel_size=3, stride=1, padding=1, bias=True) 34 | self.base_model = EfficientNetWrapper( 35 | model_name=model_name, use_pretrained=use_pretrained, in_channels=in_channels 36 | ) 37 | activation = F.leaky_relu 38 | 39 | inch = None 40 | lin1 = LinearBlock(inch, hdim, use_bn=use_bn, activation=activation, residual=False) 
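# NOTE: inch is None above, which assumes LinearBlock (lib/nn/block/linear_block.py) can
# resolve its input width lazily; for efficientnet-b0 the pooled feature width would be 1280.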
41 | lin2 = LinearBlock(hdim, out_dim, use_bn=use_bn, activation=None, residual=False) 42 | self.lin_layers = Sequential(lin1, lin2) 43 | 44 | def forward(self, x): 45 | # h = self.conv0(x) 46 | h = x 47 | h = self.base_model(h) 48 | 49 | for layer in self.lin_layers: 50 | h = layer(h) 51 | return h 52 | 53 | 54 | if __name__ == "__main__": 55 | # --- test instantiation --- 56 | from lib.utils.yaml_utils import load_yaml 57 | 58 | cfg = load_yaml("../../../../modeling/configs/0905_cfg.yaml") 59 | num_modes = 3 60 | model = EfficientNetMulti(cfg, num_modes=num_modes) 61 | print(type(model)) 62 | print(model) 63 | 64 | bs = 3 65 | in_channels = model.in_channels 66 | height, width = 224, 224 67 | device = "cuda:0" 68 | 69 | x = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device) 70 | model.to(device) 71 | # pred, confidences = model(x) 72 | # print("pred", pred.shape, "confidences", confidences.shape) 73 | h = model(x) 74 | print("h", h.shape) 75 | -------------------------------------------------------------------------------- /src/lib/nn/models/rnn_head_multi/lstm_head_multi_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn, Tensor 6 | import timm 7 | 8 | 9 | class LSTMHeadMultiPredictor(nn.Module): 10 | 11 | def __init__( 12 | self, 13 | backbone: str, 14 | in_channels: int, 15 | encoder_hidden_dim: int = 128, 16 | decoder_hidden_dim: int = 128, 17 | num_modes: int = 3, 18 | future_len: int = 50, 19 | ) -> None: 20 | super().__init__() 21 | self.in_channels = in_channels 22 | self.encoder_hidden_dim = encoder_hidden_dim 23 | self.decoder_hidden_dim = decoder_hidden_dim 24 | self.num_modes = num_modes 25 | self.future_len = future_len 26 | 27 | self.backbone = timm.create_model( 28 | model_name=backbone, 29 | pretrained=True, 30 | num_classes=0, 31 | in_chans=in_channels 32 | ) 33 | self.avg_pool = torch.nn.AdaptiveAvgPool2d(1) 34 | self.feature_dim = self.backbone.num_features + encoder_hidden_dim 35 | self.feat_to_confidence = nn.Linear(self.feature_dim, num_modes) 36 | self.feat_to_dec_hidden = nn.Linear(self.feature_dim, num_modes * decoder_hidden_dim) 37 | self.encoder = nn.LSTM(input_size=3, hidden_size=encoder_hidden_dim, batch_first=True) 38 | self.decoder = nn.LSTM(input_size=1, hidden_size=decoder_hidden_dim, batch_first=True) 39 | self.dec_hidden_to_target = nn.Linear(decoder_hidden_dim, 2) 40 | 41 | def forward(self, image: Tensor, history_positions: Tensor, history_availabilities: Tensor): 42 | batch_size = image.shape[0] 43 | 44 | feat = self.backbone.forward_features(image) 45 | feat = self.avg_pool(feat).reshape(batch_size, -1) 46 | 47 | enc_input = torch.cat([history_positions, history_availabilities[..., np.newaxis]], dim=-1) 48 | enc_input = torch.flip(enc_input, dims=(1,)) 49 | _, (enc_hidden, _) = self.encoder(enc_input) 50 | assert enc_hidden.shape == (1, batch_size, self.encoder_hidden_dim) 51 | feat = torch.cat([feat, enc_hidden[0]], dim=1) 52 | 53 | confidence = torch.softmax(self.feat_to_confidence(feat), dim=1) 54 | dec_hidden = self.feat_to_dec_hidden(feat) 55 | dec_hidden = dec_hidden.reshape(1, batch_size * self.num_modes, self.decoder_hidden_dim) 56 | 57 | dec_input = torch.linspace(0, 1, self.future_len, device=dec_hidden.device) 58 | dec_input = dec_input.reshape(1, self.future_len, 1).expand(batch_size * self.num_modes, -1, -1) 59 | dec_output, _ = self.decoder(dec_input, (dec_hidden, 
torch.zeros_like(dec_hidden))) 60 | dec_output = dec_output.reshape(batch_size * self.num_modes * self.future_len, self.decoder_hidden_dim) 61 | pred = self.dec_hidden_to_target(dec_output) 62 | pred = pred.reshape(batch_size, self.num_modes, self.future_len, 2) 63 | 64 | return pred, confidence 65 | -------------------------------------------------------------------------------- /src/modeling/calc_num_history_vs_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate num history for chopped valid/test data... 3 | """ 4 | import argparse 5 | from distutils.util import strtobool 6 | import numpy as np 7 | import torch 8 | from pathlib import Path 9 | 10 | # from l5kit.dataset import AgentDataset 11 | # from torch.utils.data import DataLoader 12 | # from torch.utils.data.dataset import Subset 13 | # 14 | # from l5kit.data import LocalDataManager, ChunkedDataset 15 | 16 | import sys 17 | import os 18 | 19 | from tqdm import tqdm 20 | sys.path.append(os.pardir) 21 | sys.path.append(os.path.join(os.pardir, os.pardir)) 22 | from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch 23 | # from lib.dataset.faster_agent_dataset import FasterAgentDataset 24 | # from lib.evaluation.mask import load_mask_chopped 25 | # from modeling.load_flag import Flags 26 | # from lib.rasterization.rasterizer_builder import build_custom_rasterizer 27 | # from lib.utils.yaml_utils import save_yaml, load_yaml 28 | 29 | 30 | def parse(): 31 | parser = argparse.ArgumentParser(description='') 32 | parser.add_argument('--pred_npz_path', '-p', default='results/tmp/eval_ema/pred.npz', 33 | help='pred.npz filepath') 34 | parser.add_argument('--debug', '-d', type=strtobool, default='false', 35 | help='Debug mode') 36 | parser.add_argument('--device', type=str, default='cuda:0', 37 | help='Device') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | if __name__ == '__main__': 43 | args = parse() 44 | debug = args.debug 45 | device = args.device 46 | 47 | # --- Load n_availability --- 48 | processed_dir = Path("../../input/processed_data") 49 | npz_path = processed_dir / f"n_history.npz" 50 | print(f"Load from {npz_path}") 51 | n_his = np.load(npz_path) 52 | n_his_avail_valid = n_his["n_his_avail_valid"] 53 | n_his_avail_test = n_his["n_his_avail_test"] 54 | 55 | # --- Load pred --- 56 | preds = np.load(args.pred_npz_path) 57 | coords = preds["coords"] 58 | confs = preds["confs"] 59 | targets = preds["targets"] 60 | target_availabilities = preds["target_availabilities"] 61 | 62 | # Evaluate loss 63 | errors = pytorch_neg_multi_log_likelihood_batch( 64 | torch.as_tensor(targets, device=device), 65 | torch.as_tensor(coords, device=device), 66 | torch.as_tensor(confs, device=device), 67 | torch.as_tensor(target_availabilities, device=device), 68 | reduction="none") 69 | print("errors", errors.shape, torch.mean(errors)) 70 | 71 | n_his_avail_valid = torch.as_tensor(n_his_avail_valid.astype(np.int64), device=device) 72 | for i in range(1, 11): 73 | this_error = errors[n_his_avail_valid == i] 74 | mean_error = torch.mean(this_error) 75 | print(f"i=={i:4.0f}: {mean_error:10.4f}, {len(this_error)}") 76 | 77 | for i in [20, 50, 100]: 78 | this_error = errors[n_his_avail_valid >= i] 79 | mean_error = torch.mean(this_error) 80 | print(f"i>={i:4.0f}: {mean_error:10.4f}, {len(this_error)}") 81 | import IPython; IPython.embed() 82 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | alembic==1.4.3 3 | argon2-cffi==20.1.0 4 | asciitree==0.3.3 5 | async-generator==1.10 6 | attrs==20.2.0 7 | backcall==0.2.0 8 | bleach==3.2.1 9 | catboost==0.24.1 10 | category-encoders==2.2.2 11 | certifi==2020.6.20 12 | cffi==1.14.3 13 | chardet==3.0.4 14 | cliff==3.4.0 15 | cloudpickle==1.3.0 16 | cmaes==0.6.1 17 | cmd2==1.3.10 18 | colorama==0.4.3 19 | colorlog==4.2.1 20 | cupy-cuda101==7.0.0 21 | cycler==0.10.0 22 | decorator==4.4.2 23 | defusedxml==0.6.0 24 | efficientnet-pytorch==0.7.0 25 | entrypoints==0.3 26 | fasteners==0.15 27 | fastrlock==0.5 28 | future==0.18.2 29 | graphviz==0.14.1 30 | idna==2.10 31 | imageio==2.9.0 32 | imgaug==0.2.6 33 | ipykernel==5.3.4 34 | ipython==7.18.1 35 | ipython-genutils==0.2.0 36 | ipywidgets==7.5.1 37 | jedi==0.17.2 38 | Jinja2==2.11.2 39 | joblib==0.14.1 40 | json5==0.9.5 41 | jsonschema==3.2.0 42 | jupyter==1.0.0 43 | jupyter-client==6.1.7 44 | jupyter-console==6.2.0 45 | jupyter-core==4.6.3 46 | jupyterlab==2.2.8 47 | jupyterlab-pygments==0.1.2 48 | jupyterlab-server==1.2.0 49 | kiwisolver==1.2.0 50 | l5kit==1.1.0 51 | lightgbm==3.0.0 52 | llvmlite==0.35.0 53 | Mako==1.1.3 54 | MarkupSafe==1.1.1 55 | matplotlib==3.2.0 56 | mistune==0.8.4 57 | monotonic==1.5 58 | mpi4py==3.0.3 59 | munch==2.5.0 60 | nbclient==0.5.0 61 | nbconvert==6.0.6 62 | nbformat==5.0.7 63 | nest-asyncio==1.4.1 64 | networkx==2.4 65 | nose==1.3.7 66 | notebook==6.1.4 67 | numba==0.52.0 68 | numcodecs==0.7.2 69 | numpy==1.18.1 70 | opencv-contrib-python-headless==4.4.0.44 71 | opencv-python==4.2.0.32 72 | opencv-python-headless==4.4.0.44 73 | optuna==1.3.0 74 | packaging==20.4 75 | pandas==1.0.2 76 | pandocfilters==1.4.2 77 | parso==0.7.1 78 | patsy==0.5.1 79 | pbr==5.5.0 80 | pexpect==4.8.0 81 | pickleshare==0.7.5 82 | Pillow==7.1.2 83 | plotly==4.10.0 84 | pretrainedmodels==0.7.4 85 | prettytable==0.7.2 86 | prometheus-client==0.8.0 87 | prompt-toolkit==3.0.7 88 | protobuf==3.13.0 89 | PTable==0.9.2 90 | ptyprocess==0.6.0 91 | pycparser==2.20 92 | Pygments==2.7.1 93 | pygobject==3.26.1 94 | pymap3d==2.4.3 95 | pyparsing==2.4.7 96 | pyperclip==1.8.0 97 | pyrsistent==0.17.3 98 | python-apt==1.6.5+ubuntu0.3 99 | python-dateutil==2.8.1 100 | python-editor==1.0.4 101 | pytorch-ignite==0.4.1 102 | pytorch-pfn-extras==0.3.1 103 | pytz==2020.1 104 | PyWavelets==1.1.1 105 | PyYAML==5.3.1 106 | pyzmq==19.0.2 107 | qtconsole==4.7.7 108 | QtPy==1.9.0 109 | requests==2.24.0 110 | resnest==0.0.6b20200912 111 | retrying==1.3.3 112 | scikit-image==0.16.2 113 | scikit-learn==0.22.2.post1 114 | scipy==1.4.1 115 | seaborn==0.9.0 116 | segmentation-models-pytorch==0.1.0 117 | Send2Trash==1.5.0 118 | Shapely==1.7.1 119 | six==1.15.0 120 | SQLAlchemy==1.3.19 121 | statsmodels==0.12.0 122 | stevedore==3.2.2 123 | terminado==0.9.1 124 | testpath==0.4.4 125 | timm==0.3.1 126 | toml==0.10.0 127 | torch==1.6.0+cu101 128 | torchvision==0.7.0 129 | tornado==6.0.4 130 | tqdm==4.45.0 131 | traitlets==5.0.4 132 | transforms3d==0.3.1 133 | urllib3==1.25.10 134 | wcwidth==0.2.5 135 | webencodings==0.5.1 136 | widgetsnbextension==3.5.1 137 | xgboost==1.2.0 138 | zarr==2.4.0 -------------------------------------------------------------------------------- /src/lib/nn/models/deep_ensemble/lyft_multi_deep_ensemble_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Tuple, Optional 2 | 3 | import torch 4 | from torch import nn, 
Tensor 5 | 6 | from lib.nn.models.multi.multi_model_predictor import LyftMultiModelPredictor 7 | 8 | 9 | # NOTE: Since there is no initial variation in the trained model, 10 | # the variability is incorporated by converting the input by one of the dihedral group D4 11 | D4 = [ 12 | lambda x: x, # (x, y) 13 | lambda x: x.transpose(2, 3).flip(3), # (y, -x) 14 | lambda x: x.flip(2).flip(3), # (-x, -y) 15 | lambda x: x.transpose(2, 3).flip(2), # (-y, x) 16 | lambda x: x.flip(3), # (x, -y) 17 | lambda x: x.transpose(2, 3), # (y, x) 18 | lambda x: x.flip(2), # (-x, y) 19 | lambda x: x.transpose(2, 3).flip(2).flip(3), # (-y, -x) 20 | ] 21 | D4_inv = [D4[k] for k in [0, 3, 2, 1, 4, 5, 6, 7]] 22 | 23 | 24 | class D4Module(nn.Module): 25 | def __init__(self, k: int) -> None: 26 | super().__init__() 27 | self.k = k 28 | 29 | def forward(self, x: Tensor): 30 | x = D4[self.k](x) 31 | return x 32 | 33 | 34 | class LyftMultiDeepEnsemblePredictor(nn.Module): 35 | 36 | def __init__(self, predictors: Sequence[LyftMultiModelPredictor], names: Sequence[str], use_D4: bool = False): 37 | super().__init__() 38 | 39 | if len(predictors) > 8: 40 | raise ValueError("We only support up to 8 models yet.") 41 | 42 | self.predictors = nn.ModuleList(predictors) 43 | self.names = names 44 | self.use_D4 = use_D4 45 | 46 | @torch.no_grad() 47 | def load_state_dict(self, to_load, strict: bool = True) -> None: 48 | if any([key.startswith("predictors") for key in to_load.keys()]): 49 | super().load_state_dict(to_load, strict=strict) 50 | return 51 | 52 | print("Loading from non ensemble snapshot") 53 | old_snapshot = True 54 | for key in to_load.keys(): 55 | if key.startswith("base_model."): 56 | old_snapshot = False 57 | 58 | for predictor in self.predictors: 59 | if old_snapshot: 60 | predictor.base_model.load_state_dict(to_load, strict=strict) 61 | else: 62 | predictor.load_state_dict(to_load, strict=strict) 63 | 64 | # for k, predictor in enumerate(self.predictors): 65 | # for module in predictor.modules(): 66 | # if isinstance(module, nn.Conv2d): 67 | # module.weight[:] = D4[k](module.weight) 68 | 69 | def forward(self, x: Tensor, x_feat: Optional[Tensor] = None) -> Sequence[Tuple[Tensor, Tensor]]: 70 | ys = [] 71 | for k, predictor in enumerate(self.predictors): 72 | if self.use_D4: 73 | x = D4[k](x) 74 | ys.append(predictor(x, x_feat)) 75 | return ys 76 | 77 | def get_kth_predictor(self, k: int): 78 | if k >= len(self.predictors): 79 | raise ValueError 80 | if self.use_D4: 81 | return nn.Sequential(D4Module(k), self.predictors[k]) 82 | else: 83 | return self.predictors[k] 84 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/timm_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import nn, Tensor 7 | import timm 8 | from timm.models.registry import register_model 9 | from timm.models.helpers import build_model_with_cfg 10 | from timm.models.resnet import BasicBlock, ResNet 11 | from timm.models.resnet import _cfg as timm_resnet_cfg 12 | from segmentation_models_pytorch.base.modules import SCSEModule 13 | 14 | sys.path.append(os.pardir) 15 | sys.path.append(os.path.join(os.pardir, os.pardir)) 16 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir)) 17 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir, os.pardir)) 18 | from lib.nn.models.multi.multi_utils import calc_out_channels 19 | from 
lib.nn.block.feat_module import FeatModule
20 | from lib.nn.block.scea_module import SCEAModule
21 | 
22 | 
23 | @register_model
24 | def scseresnet18(pretrained=False, **kwargs):
25 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer=SCSEModule), **kwargs)
26 | default_cfg = timm_resnet_cfg(url='', interpolation='bicubic')
27 | return build_model_with_cfg(
28 | ResNet, 'scseresnet18', default_cfg=default_cfg, pretrained=pretrained, **model_args
29 | )
30 | 
31 | 
32 | @register_model
33 | def scearesnet18(pretrained=False, **kwargs):
34 | model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer=SCEAModule), **kwargs)
35 | default_cfg = timm_resnet_cfg()
36 | return build_model_with_cfg(
37 | ResNet, 'scearesnet18', default_cfg=default_cfg, pretrained=pretrained, **model_args
38 | )
39 | 
40 | 
41 | class TimmMulti(nn.Module):
42 | def __init__(
43 | self,
44 | cfg,
45 | in_channels: int,
46 | num_modes=3,
47 | backbone: str = "resnet18",
48 | use_pretrained=True,
49 | hdim: int = 4096,
50 | feat_module_type: str = "none",
51 | feat_channels: int = -1,
52 | ) -> None:
53 | super().__init__()
54 | out_dim, num_preds, future_len = calc_out_channels(cfg, num_modes=num_modes)
55 | self.in_channels = in_channels
56 | self.num_modes = num_modes
57 | self.hdim = hdim
58 | self.out_dim = out_dim
59 | self.num_preds = num_preds
60 | self.future_len = future_len
61 | self.backbone = timm.create_model(
62 | model_name=backbone,
63 | pretrained=use_pretrained,
64 | num_classes=0,
65 | in_chans=in_channels
66 | )
67 | self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
68 | self.feature_dim = self.backbone.num_features
69 | self.dense = nn.Sequential(
70 | nn.Linear(self.feature_dim, hdim),
71 | nn.LeakyReLU(),
72 | nn.Linear(hdim, out_dim)
73 | )
74 | 
75 | self.feat_module_type = feat_module_type
76 | self.feat_module = FeatModule(
77 | feat_module_type=feat_module_type,
78 | channels=self.feature_dim,
79 | feat_channels=feat_channels,
80 | )
81 | self.feat_channels = feat_channels
82 | 
83 | def forward(self, x: Tensor, x_feat: Optional[Tensor] = None) -> Tensor:
84 | x = self.backbone.forward_features(x)
85 | x = self.avg_pool(x).reshape(*x.shape[:2])
86 | x = self.feat_module(x, x_feat)
87 | x = self.dense(x)
88 | return x
89 | 
-------------------------------------------------------------------------------- /src/lib/dataset/custom_ego_dataset.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Optional
3 | import numpy as np
4 | 
5 | from l5kit.data import get_frames_slice_from_scenes
6 | 
7 | 
8 | def get_frame_custom(self, scene_index: int, state_index: int, track_id: Optional[int] = None) -> dict:
9 | """Customized `get_frame` function, which returns all `data` entries of `sample_function`.
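(Unlike the stock EgoDataset.get_frame, no keys are dropped from the dict returned by sample_function.)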
10 | A utility function to get the rasterisation and trajectory target for a given agent in a given frame 11 | 12 | Args: 13 | self: Ego Dataset 14 | scene_index (int): the index of the scene in the zarr 15 | state_index (int): a relative frame index in the scene 16 | track_id (Optional[int]): the agent to rasterize or None for the AV 17 | Returns: 18 | dict: the rasterised image, the target trajectory (position and yaw) along with their availability, 19 | the 2D matrix to center that agent, the agent track (-1 if ego) and the timestamp 20 | 21 | """ 22 | frames = self.dataset.frames[get_frames_slice_from_scenes(self.dataset.scenes[scene_index])] 23 | 24 | tl_faces = self.dataset.tl_faces 25 | try: 26 | if self.cfg["raster_params"]["disable_traffic_light_faces"]: 27 | tl_faces = np.empty(0, dtype=self.dataset.tl_faces.dtype) # completely disable traffic light faces 28 | except KeyError: 29 | warnings.warn( 30 | "disable_traffic_light_faces not found in config, this will raise an error in the future", 31 | RuntimeWarning, 32 | stacklevel=2, 33 | ) 34 | data = self.sample_function(state_index, frames, self.dataset.agents, tl_faces, track_id) 35 | # 0,1,C -> C,0,1 36 | image = data["image"].transpose(2, 0, 1) 37 | 38 | target_positions = np.array(data["target_positions"], dtype=np.float32) 39 | target_yaws = np.array(data["target_yaws"], dtype=np.float32) 40 | 41 | history_positions = np.array(data["history_positions"], dtype=np.float32) 42 | history_yaws = np.array(data["history_yaws"], dtype=np.float32) 43 | 44 | timestamp = frames[state_index]["timestamp"] 45 | track_id = np.int64(-1 if track_id is None else track_id) # always a number to avoid crashing torch 46 | 47 | data["image"] = image 48 | data["target_positions"] = target_positions 49 | data["target_yaws"] = target_yaws 50 | data["history_positions"] = history_positions 51 | data["history_yaws"] = history_yaws 52 | data["track_id"] = track_id 53 | data["timestamp"] = timestamp 54 | return data 55 | # return { 56 | # "image": image, 57 | # "target_positions": target_positions, 58 | # "target_yaws": target_yaws, 59 | # "target_availabilities": data["target_availabilities"], 60 | # "history_positions": history_positions, 61 | # "history_yaws": history_yaws, 62 | # "history_availabilities": data["history_availabilities"], 63 | # "world_to_image": data["raster_from_world"], # TODO deprecate 64 | # "raster_from_world": data["raster_from_world"], 65 | # "raster_from_agent": data["raster_from_agent"], 66 | # "agent_from_world": data["agent_from_world"], 67 | # "world_from_agent": data["world_from_agent"], 68 | # "track_id": track_id, 69 | # "timestamp": timestamp, 70 | # "centroid": data["centroid"], 71 | # "yaw": data["yaw"], 72 | # "extent": data["extent"], 73 | # } 74 | -------------------------------------------------------------------------------- /src/lib/rasterization/rasterizer_builder.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from l5kit.data import DataManager 4 | from l5kit.rasterization import Rasterizer, build_rasterizer 5 | 6 | from lib.rasterization.agent_type_box_rasterizer import AgentTypeBoxRasterizer 7 | from lib.rasterization.augmented_box_rasterizer import AugmentedBoxRasterizer 8 | from lib.rasterization.channel_semantic_rasterizer import ChannelSemanticRasterizer 9 | from lib.rasterization.channel_semantic_tl_rasterizer import ChannelSemanticTLRasterizer 10 | from lib.rasterization.combined_rasterizer import CombinedRasterizer 11 | from 
lib.rasterization.tl_semantic_rasterizer import TLSemanticRasterizer 12 | from lib.rasterization.tuned_semantic_rasterizer import TunedSemanticRasterizer 13 | from lib.rasterization.velocity_rasterizer import VelocityBoxRasterizer 14 | from lib.rasterization.tuned_box_rasterizer import TunedBoxRasterizer 15 | 16 | 17 | def build_one_rasterizer(map_type: str, cfg: dict, data_manager: DataManager, eval: bool = False) -> Rasterizer: 18 | cfg_ = deepcopy(cfg) 19 | cfg_["raster_params"]["map_type"] = map_type 20 | if map_type == "channel_semantic": 21 | return ChannelSemanticRasterizer.from_cfg(cfg_, data_manager) 22 | elif map_type == "channel_semantic_tl": 23 | return ChannelSemanticTLRasterizer.from_cfg(cfg_, data_manager) 24 | elif map_type == "velocity_box": 25 | return VelocityBoxRasterizer.from_cfg(cfg_, data_manager) 26 | elif map_type == "agent_type_box": 27 | return AgentTypeBoxRasterizer.from_cfg(cfg_, data_manager) 28 | elif map_type == "augmented_box": 29 | return AugmentedBoxRasterizer.from_cfg(cfg_, data_manager, eval=eval) 30 | elif map_type == "tl_semantic": 31 | return TLSemanticRasterizer.from_cfg(cfg_, data_manager) 32 | elif map_type == "tuned_box": 33 | return TunedBoxRasterizer.from_cfg(cfg_, data_manager) 34 | elif map_type == "tuned_semantic": 35 | return TunedSemanticRasterizer.from_cfg(cfg_, data_manager) 36 | else: 37 | # Use Original l5kit rasterizer 38 | rasterizer = build_rasterizer(cfg_, data_manager) 39 | if map_type == "py_satellite": 40 | rasterizer.raster_channels = (rasterizer.history_num_frames + 1) * 2 + 3 41 | elif map_type == "satellite_debug": 42 | rasterizer.raster_channels = 3 43 | elif map_type == "py_semantic": 44 | rasterizer.raster_channels = (rasterizer.history_num_frames + 1) * 2 + 3 45 | elif map_type == "semantic_debug": 46 | rasterizer.raster_channels = 3 47 | elif map_type == "box_debug": 48 | rasterizer.raster_channels = (rasterizer.history_num_frames + 1) * 2 49 | elif map_type == "stub_debug": 50 | history_num_frames = cfg_["model_params"]["history_num_frames"] 51 | rasterizer.raster_channels = history_num_frames * 2 52 | return rasterizer 53 | 54 | 55 | def build_custom_rasterizer(cfg: dict, data_manager: DataManager, eval: bool = False) -> Rasterizer: 56 | raster_cfg = cfg["raster_params"] 57 | map_type = raster_cfg["map_type"] 58 | 59 | map_type_list = map_type.split("+") 60 | rasterizer_list = [build_one_rasterizer(map_type, cfg, data_manager, eval=eval) 61 | for map_type in map_type_list] 62 | if len(rasterizer_list) == 1: 63 | # Only 1 rasterizer used. 64 | rasterizer = rasterizer_list[0] 65 | else: 66 | # If more than 2, use combined rasterizer. 67 | rasterizer = CombinedRasterizer(rasterizer_list) 68 | return rasterizer 69 | -------------------------------------------------------------------------------- /src/lib/utils/numba_utils.py: -------------------------------------------------------------------------------- 1 | import numba as nb 2 | import numpy as np 3 | 4 | 5 | @nb.jit(nb.float64[:](nb.float64[:], nb.float64[:, :]), nopython=True, nogil=True) 6 | def transform_point_nb(point: np.ndarray, transf_matrix: np.ndarray) -> np.ndarray: 7 | """ Transform a single vector using transformation matrix. 
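(Compiled by numba with an explicit float64 signature in nopython mode, so both inputs must be float64 arrays.)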
8 | 9 | Args: 10 | point (np.ndarray): vector of shape (N) 11 | transf_matrix (np.ndarray): transformation matrix of shape (N+1, N+1) 12 | 13 | Returns: 14 | np.ndarray: vector of same shape as input point 15 | """ 16 | point_ext = np.ascontiguousarray(np.hstack((point, np.ones(1)))) 17 | # (N+1, N+1) @ (N+1) 18 | # p = np.matmul(transf_matrix, point_ext)[: point.shape[0]] 19 | p = np.dot(np.ascontiguousarray(transf_matrix), point_ext)[: point.shape[0]] 20 | return p 21 | 22 | 23 | @nb.jit(nb.float64[:, :](nb.float64[:, :], nb.float64[:, :]), nopython=True, nogil=True) 24 | def transform_points_nb(points: np.ndarray, transf_matrix: np.ndarray) -> np.ndarray: 25 | """ 26 | Transform points using transformation matrix. 27 | Note this function assumes points.shape[1] == matrix.shape[1] - 1, which means that the last row on the matrix 28 | does not influence the final result. 29 | For 2D points only the first 2x3 part of the matrix will be used. 30 | 31 | Args: 32 | points (np.ndarray): Input points (Nx2) or (Nx3). 33 | transf_matrix (np.ndarray): 3x3 or 4x4 transformation matrix for 2D and 3D input respectively 34 | 35 | Returns: 36 | np.ndarray: array of shape (N,2) for 2D input points, or (N,3) points for 3D input points 37 | """ 38 | assert len(points.shape) == len(transf_matrix.shape) == 2 39 | assert transf_matrix.shape[0] == transf_matrix.shape[1] 40 | 41 | assert points.shape[1] in [2, 3] 42 | # if points.shape[1] not in [2, 3]: 43 | # raise AssertionError("Points input should be (N, 2) or (N,3) shape, received {}".format(points.shape)) 44 | # assert points.shape[1] == 2 45 | 46 | assert points.shape[1] == transf_matrix.shape[1] - 1, "points dim should be one less than matrix dim" 47 | 48 | num_dims = len(transf_matrix) - 1 49 | transf_matrix = transf_matrix.T 50 | 51 | # return points @ transf_matrix[:num_dims, :num_dims] + transf_matrix[-1, :num_dims] 52 | return np.dot( 53 | np.ascontiguousarray(points), 54 | np.ascontiguousarray(transf_matrix[:num_dims, :num_dims]) 55 | ) + transf_matrix[-1, :num_dims] 56 | 57 | 58 | # --- For SemanticRasterizer --- 59 | @nb.jit(nb.int64[:](nb.float64[:], nb.float64[:, :, :], nb.float64), nopython=True, nogil=True) 60 | def elements_within_bounds_nb(center: np.ndarray, bounds: np.ndarray, half_extent: float) -> np.ndarray: 61 | """ 62 | Get indices of elements for which the bounding box described by bounds intersects the one defined around 63 | center (square with side 2*half_side) 64 | 65 | Args: 66 | center (float): XY of the center 67 | bounds (np.ndarray): array of shape Nx2x2 [[x_min,y_min],[x_max, y_max]] 68 | half_extent (float): half the side of the bounding box centered around center 69 | 70 | Returns: 71 | np.ndarray: indices of elements inside radius from center 72 | """ 73 | x_center, y_center = center 74 | 75 | x_min_in = x_center > bounds[:, 0, 0] - half_extent 76 | y_min_in = y_center > bounds[:, 0, 1] - half_extent 77 | x_max_in = x_center < bounds[:, 1, 0] + half_extent 78 | y_max_in = y_center < bounds[:, 1, 1] + half_extent 79 | indices = np.nonzero(x_min_in & y_min_in & x_max_in & y_max_in)[0] 80 | return indices 81 | -------------------------------------------------------------------------------- /src/modeling/load_flag.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import deepcopy 3 | from distutils.util import strtobool 4 | 5 | from typing import Dict, Any, Tuple 6 | 7 | from dataclasses import dataclass, field 8 | 9 | from lib.utils.yaml_utils import 
load_yaml
10 | 
11 | 
12 | @dataclass
13 | class Flags:
14 |     # --- Overall configs ---
15 |     debug: bool = False
16 |     cfg_filepath: str = "configs/0905_cfg.yaml"
17 |     # --- Data configs ---
18 |     l5kit_data_folder: str = "../../input/lyft-motion-prediction-autonomous-vehicles"
19 |     min_frame_history: int = 10  # minimum frame history used in AgentDataset
20 |     min_frame_future: int = 1  # minimum frame future used in AgentDataset
21 |     override_sample_function_name: str = ""  # override sample function. Ex. "generate_agent_sample_tl_history"
22 |     # --- Model configs ---
23 |     pred_mode: str = "multi"
24 |     model_name: str = "resnet18"
25 |     model_kwargs: Dict[str, Any] = field(default_factory=dict)
26 |     target_scale_filepath: str = ""
27 |     # --- Training configs ---
28 |     device: str = "cuda:0"
29 |     out_dir: str = "results/multi_train"
30 |     epoch: int = 2
31 |     snapshot_freq: int = 5000
32 |     scheduler_type: str = "exponential"
33 |     scheduler_kwargs: Dict[str, Any] = field(default_factory=lambda: {"gamma": 0.999999})
34 |     scheduler_trigger: Tuple[int, str] = (1, "iteration")
35 |     ema_decay: float = 0.999  # a negative value disables EMA.
36 |     validation_freq: int = 20000  # validation frequency
37 |     validation_chopped: bool = False  # use chopped validation dataset or not.
38 |     n_valid_data: int = 10000  # number of validation samples
39 |     resume_if_possible: bool = True  # Resume when predictor.pt is found in out_dir
40 |     load_predictor_filepath: str = ""  # Start from this pretrained predictor, if specified
41 |     scene_sampler: bool = False  # Generate one example per scene, if specified
42 |     scene_sampler_min_state_index: int = 0  # min_state_index for SceneSampler
43 |     augmentation_in_validation: bool = False  # apply augmentation in validation
44 |     cutout: Dict[str, Any] = field(default_factory=lambda: {"p": 0.0, "scale_min": 0.75, "scale_max": 0.99})  # "p": 0 means no augmentation
45 |     blur: Dict[str, Any] = field(default_factory=lambda: {"p": 0.0, "blur_limit": [3, 5]})  # "p": 0 means no augmentation
46 |     downscale: Dict[str, Any] = field(default_factory=lambda: {"p": 0.0, "num_holes": 5, "max_h_size": 20, "max_w_size": 20, "fill_value": 0})  # "p": 0 means no augmentation
47 |     crossdrop: Dict[str, Any] = field(default_factory=lambda: {"p": 0.0, "max_h_cut": 0.3, "max_w_cut": 0.3, "fill_value": 0})  # "p": 0 means no augmentation
48 |     feat_mode: str = "none"  # Append an `x_feat` feature. Only "agent_type" is supported for now.
49 |     include_valid: bool = False  # Include validation dataset as train data or not.
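    # (Hedged example) These defaults are overridden from a yaml file via
    # `load_flags` / `Flags.update` below; unknown keys raise ValueError.
    # A minimal flags yaml might look like the following (values and model
    # name are illustrative only, not taken from this repo's shipped flags):
    #     cfg_filepath: "configs/0905_cfg.yaml"
    #     model_name: "se_resnext50_32x4d"
    #     epoch: 4
    #     scheduler_type: "cosine"
    #     out_dir: "results/20201104_cosine_aug"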
50 | lossfun: str = "pytorch_neg_multi_log_likelihood_batch" # Loss function 51 | 52 | def update(self, param_dict: Dict): 53 | # Overwrite by `param_dict` 54 | for key, value in param_dict.items(): 55 | if not hasattr(self, key): 56 | raise ValueError(f"[ERROR] Unexpected key for flag = {key}") 57 | setattr(self, key, value) 58 | 59 | 60 | def load_flags(mode="") -> Flags: 61 | parser = argparse.ArgumentParser(description='') 62 | parser.add_argument('--yaml_filepath', '-y', type=str, default="./flags/20200905_sample_flag.yaml", 63 | help='Flags yaml file path') 64 | args = parser.parse_args() 65 | 66 | # --- Default setting --- 67 | 68 | flags = load_yaml(args.yaml_filepath) 69 | # print("yaml flags", flags) 70 | base_flags = Flags() 71 | base_flags.update(flags) 72 | return base_flags 73 | -------------------------------------------------------------------------------- /src/lib/evaluation/mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from zarr import convenience 7 | 8 | from l5kit.data import ChunkedDataset, get_agents_slice_from_frames 9 | from l5kit.dataset.select_agents import TH_DISTANCE_AV, TH_EXTENT_RATIO, TH_YAW_DEGREE, select_agents 10 | 11 | MIN_FUTURE_STEPS = 10 12 | 13 | 14 | def get_mask_chopped_path( 15 | zarr_path: str, th_agent_prob: float, num_frames_to_copy: int, min_frame_future: int 16 | ) -> Path: 17 | zarr_path = Path(zarr_path) 18 | dest_path = zarr_path.parent / f"{zarr_path.stem}_chopped_valid" 19 | os.makedirs(str(dest_path), exist_ok=True) 20 | mask_chopped_path = dest_path / f"mask_{th_agent_prob}_{num_frames_to_copy}_{min_frame_future}.npz" 21 | return mask_chopped_path 22 | 23 | 24 | def create_chopped_mask( 25 | zarr_path: str, th_agent_prob: float, num_frames_to_copy: int, min_frame_future: int 26 | ) -> str: 27 | """Create mask to emulate chopped dataset with gt data. 
28 | 
29 |     Args:
30 |         zarr_path (str): input zarr path to be chopped
31 |         th_agent_prob (float): threshold over agents probabilities used in select_agents function
32 |         num_frames_to_copy (int): number of frames to copy from the beginning of each scene, others will be discarded
33 |         min_frame_future (int): minimum number of frames that must be available in the future for an agent
34 | 
35 |     Returns:
36 |         str: Path to saved mask
37 |     """
38 |     zarr_path = Path(zarr_path)
39 |     mask_chopped_path = get_mask_chopped_path(zarr_path, th_agent_prob, num_frames_to_copy, min_frame_future)
40 | 
41 |     # Create standard mask for the dataset so we can use it to filter out unreliable agents
42 |     zarr_dt = ChunkedDataset(str(zarr_path))
43 |     zarr_dt.open()
44 | 
45 |     agents_mask_path = Path(zarr_path) / f"agents_mask/{th_agent_prob}"
46 |     if not agents_mask_path.exists():  # check for the threshold-specific sub-path, not just the zarr root
47 |         select_agents(
48 |             zarr_dt,
49 |             th_agent_prob=th_agent_prob,
50 |             th_yaw_degree=TH_YAW_DEGREE,
51 |             th_extent_ratio=TH_EXTENT_RATIO,
52 |             th_distance_av=TH_DISTANCE_AV,
53 |         )
54 |     agents_mask_origin = np.asarray(convenience.load(str(agents_mask_path)))
55 | 
56 |     # compute the chopped boolean mask, but also the original one limited to frames of interest for GT csv
57 |     agents_mask_orig_bool = np.zeros(len(zarr_dt.agents), dtype=np.bool)
58 | 
59 |     for idx in range(len(zarr_dt.scenes)):
60 |         scene = zarr_dt.scenes[idx]
61 | 
62 |         frame_original = zarr_dt.frames[scene["frame_index_interval"][0] + num_frames_to_copy - 1]
63 |         slice_agents_original = get_agents_slice_from_frames(frame_original)
64 | 
65 |         mask = agents_mask_origin[slice_agents_original][:, 1] >= min_frame_future
66 |         agents_mask_orig_bool[slice_agents_original] = mask.copy()
67 | 
68 |     # store the mask of the frames of interest
69 |     np.savez(str(mask_chopped_path), agents_mask_orig_bool)
70 |     return str(mask_chopped_path)
71 | 
72 | 
73 | def load_mask_chopped(
74 |     zarr_path: str, th_agent_prob: float, num_frames_to_copy: int, min_frame_future: int
75 | ) -> np.ndarray:
76 |     mask_chopped_path = get_mask_chopped_path(
77 |         zarr_path, th_agent_prob, num_frames_to_copy, min_frame_future)
78 |     if not mask_chopped_path.exists():
79 |         print(f"Cache does not exist, creating {mask_chopped_path}")
80 |         mask_chopped_path2 = create_chopped_mask(zarr_path, th_agent_prob, num_frames_to_copy, min_frame_future)
81 |         assert str(mask_chopped_path) == str(mask_chopped_path2)
82 |     agents_mask_orig_bool = np.load(str(mask_chopped_path))["arr_0"]
83 |     return agents_mask_orig_bool
--------------------------------------------------------------------------------
/src/lib/training/exponential_moving_average.py:
--------------------------------------------------------------------------------
1 | from logging import getLogger
2 | 
3 | from torch import nn
4 | 
5 | 
6 | class EMA(object):
7 |     """Exponential moving average of model parameters.
8 | 
9 |     Ref
10 |     - https://github.com/tensorflow/addons/blob/v0.10.0/tensorflow_addons/optimizers/moving_average.py#L26-L103
11 |     - https://anmoljoshi.com/Pytorch-Dicussions/
12 | 
13 |     Args:
14 |         model (nn.Module): Model with parameters whose EMA will be kept.
15 |         decay (float): Decay rate for exponential moving average.
16 |         strict (bool): Apply strict check for `assign` & `resume`.
17 |         use_dynamic_decay (bool): Dynamically change decay rate. If `True`, a small decay rate is
18 |             used at the beginning of training so the moving average adapts faster.
19 | """ # NOQA 20 | 21 | def __init__( 22 | self, 23 | model: nn.Module, 24 | decay: float, 25 | strict: bool = True, 26 | use_dynamic_decay: bool = True, 27 | ): 28 | self.decay = decay 29 | self.model = model 30 | self.strict = strict 31 | self.use_dynamic_decay = use_dynamic_decay 32 | self.logger = getLogger(__name__) 33 | self.n_step = 0 34 | 35 | self.shadow = {} 36 | self.original = {} 37 | 38 | # Flag to manage which parameter is assigned. 39 | # When `False`, original model's parameter is used. 40 | # When `True` (`assign` method is called), `shadow` parameter (ema param) is used. 41 | self._assigned = False 42 | 43 | # Register model parameters 44 | for name, param in model.named_parameters(): 45 | if param.requires_grad: 46 | self.shadow[name] = param.data.clone() 47 | 48 | def step(self): 49 | self.n_step += 1 50 | if self.use_dynamic_decay: 51 | _n_step = float(self.n_step) 52 | decay = min(self.decay, (1.0 + _n_step) / (10.0 + _n_step)) 53 | else: 54 | decay = self.decay 55 | 56 | for name, param in self.model.named_parameters(): 57 | if param.requires_grad: 58 | assert name in self.shadow 59 | new_average = (1.0 - decay) * param.data + decay * self.shadow[name] 60 | self.shadow[name] = new_average.clone() 61 | 62 | # alias 63 | __call__ = step 64 | 65 | def assign(self): 66 | """Assign exponential moving average of parameter values to the respective parameters.""" 67 | if self._assigned: 68 | if self.strict: 69 | raise ValueError("[ERROR] `assign` is called again before `resume`.") 70 | else: 71 | self.logger.warning( 72 | "`assign` is called again before `resume`." 73 | "shadow parameter is already assigned, skip." 74 | ) 75 | return 76 | 77 | for name, param in self.model.named_parameters(): 78 | if param.requires_grad: 79 | assert name in self.shadow 80 | self.original[name] = param.data.clone() 81 | param.data = self.shadow[name] 82 | self._assigned = True 83 | 84 | def resume(self): 85 | """Restore original parameters to a model. 86 | 87 | That is, put back the values that were in each parameter at the last call to `assign`. 
88 | """ 89 | if not self._assigned: 90 | if self.strict: 91 | raise ValueError("[ERROR] `resume` is called before `assign`.") 92 | else: 93 | self.logger.warning("`resume` is called before `assign`, skip.") 94 | return 95 | 96 | for name, param in self.model.named_parameters(): 97 | if param.requires_grad: 98 | assert name in self.shadow 99 | param.data = self.original[name] 100 | self._assigned = False 101 | -------------------------------------------------------------------------------- /src/lib/training/scene_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data import Sampler 5 | import torch.distributed as dist 6 | import numpy as np 7 | 8 | 9 | def get_valid_starts_and_ends(get_frame_arguments: np.ndarray, min_state_index: int = 0): 10 | get_frame_arguments = get_frame_arguments[:] # put on the memory if the array is zarr 11 | 12 | scene_change_points = np.where(np.diff(get_frame_arguments[:, 1], 1) > 0)[0] + 1 13 | starts = np.r_[0, scene_change_points] 14 | ends = np.r_[scene_change_points, len(get_frame_arguments)] 15 | 16 | valid_starts, valid_ends = [], [] 17 | while len(starts) > 0: 18 | ok = get_frame_arguments[starts, 2] >= min_state_index 19 | valid_starts.append(starts[ok]) 20 | valid_ends.append(ends[ok]) 21 | starts, ends = starts[~ok], ends[~ok] 22 | 23 | starts += 1 24 | ok = starts < ends 25 | starts, ends = starts[ok], ends[ok] 26 | 27 | return np.concatenate(valid_starts), np.concatenate(valid_ends) 28 | 29 | 30 | class SceneSampler(Sampler): 31 | 32 | def __init__(self, get_frame_arguments: np.ndarray, min_state_index: int = 0) -> None: 33 | self.starts, self.ends = get_valid_starts_and_ends(get_frame_arguments, min_state_index) 34 | 35 | def __len__(self) -> int: 36 | return len(self.starts) 37 | 38 | def __iter__(self): 39 | indices = np.random.permutation(len(self.starts)) 40 | return iter(np.random.randint(self.starts[indices], self.ends[indices])) 41 | 42 | 43 | class DistributedSceneSampler(Sampler): 44 | 45 | def __init__( 46 | self, 47 | get_frame_arguments: np.ndarray, 48 | min_state_index: int = 0, 49 | num_replicas=None, 50 | rank=None, 51 | shuffle=True, 52 | seed=0 53 | ) -> None: 54 | if num_replicas is None: 55 | if not dist.is_available(): 56 | raise RuntimeError("Requires distributed package to be available") 57 | num_replicas = dist.get_world_size() 58 | if rank is None: 59 | if not dist.is_available(): 60 | raise RuntimeError("Requires distributed package to be available") 61 | rank = dist.get_rank() 62 | self.starts, self.ends = get_valid_starts_and_ends(get_frame_arguments, min_state_index) 63 | self.num_replicas = num_replicas 64 | self.rank = rank 65 | self.epoch = 0 66 | self.num_samples = int(math.ceil(len(self.starts) * 1.0 / self.num_replicas)) 67 | self.total_size = self.num_samples * self.num_replicas 68 | self.shuffle = shuffle 69 | self.seed = seed 70 | 71 | def __iter__(self): 72 | if self.shuffle: 73 | # deterministically shuffle based on epoch and seed 74 | g = torch.Generator() 75 | g.manual_seed(self.seed + self.epoch) 76 | indices = torch.randperm(len(self.starts), generator=g).tolist() 77 | else: 78 | indices = list(range(len(self.starts))) 79 | 80 | # add extra samples to make it evenly divisible 81 | indices += indices[:(self.total_size - len(indices))] 82 | assert len(indices) == self.total_size 83 | 84 | # subsample 85 | indices = indices[self.rank:self.total_size:self.num_replicas] 86 | assert len(indices) == 
self.num_samples 87 | 88 | return iter(np.random.randint(self.starts[indices], self.ends[indices])) 89 | 90 | def __len__(self): 91 | return self.num_samples 92 | 93 | def set_epoch(self, epoch): 94 | r""" 95 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 96 | use a different random ordering for each epoch. Otherwise, the next iteration of this 97 | sampler will yield the same ordering. 98 | 99 | Arguments: 100 | epoch (int): Epoch number. 101 | """ 102 | self.epoch = epoch 103 | -------------------------------------------------------------------------------- /src/lib/dataset/multi_agent_dataset.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional, cast, Tuple 3 | import numpy as np 4 | 5 | from l5kit.data import ChunkedDataset, get_frames_slice_from_scenes 6 | from l5kit.dataset import EgoDataset 7 | from l5kit.kinematic import Perturbation 8 | from l5kit.rasterization import Rasterizer, RenderContext 9 | 10 | import sys 11 | import os 12 | sys.path.append(os.pardir) 13 | sys.path.append(os.path.join(os.pardir, os.pardir)) 14 | from lib.sampling.multi_agent_sampling import generate_multi_agent_sample 15 | 16 | 17 | class MultiAgentDataset(EgoDataset): 18 | def __init__( 19 | self, 20 | cfg: dict, 21 | zarr_dataset: ChunkedDataset, 22 | rasterizer: Rasterizer, 23 | perturbation: Optional[Perturbation] = None, 24 | min_frame_history: int = 1, 25 | min_frame_future: int = 10, 26 | ): 27 | super(MultiAgentDataset, self).__init__(cfg, zarr_dataset, rasterizer, perturbation) 28 | 29 | render_context = RenderContext( 30 | raster_size_px=np.array(cfg["raster_params"]["raster_size"]), 31 | pixel_size_m=np.array(cfg["raster_params"]["pixel_size"]), 32 | center_in_raster_ratio=np.array(cfg["raster_params"]["ego_center"]), 33 | ) 34 | 35 | self.sample_function = partial( 36 | generate_multi_agent_sample, 37 | render_context=render_context, 38 | history_num_frames=cfg["model_params"]["history_num_frames"], 39 | history_step_size=cfg["model_params"]["history_step_size"], 40 | future_num_frames=cfg["model_params"]["future_num_frames"], 41 | future_step_size=cfg["model_params"]["future_step_size"], 42 | filter_agents_threshold=cfg["raster_params"]["filter_agents_threshold"], 43 | rasterizer=rasterizer, 44 | perturbation=perturbation, 45 | min_frame_history=min_frame_history, 46 | min_frame_future=min_frame_future, 47 | ) 48 | 49 | def get_frame(self, scene_index: int, state_index: int, track_id: Optional[int] = None) -> dict: 50 | """ 51 | A utility function to get the rasterisation and trajectory target for a given agent in a given frame 52 | 53 | Args: 54 | scene_index (int): the index of the scene in the zarr 55 | state_index (int): a relative frame index in the scene 56 | track_id (Optional[int]): the agent to rasterize or None for the AV 57 | Returns: 58 | dict: the rasterised image, the target trajectory (position and yaw) along with their availability, 59 | the 2D matrix to center that agent, the agent track (-1 if ego) and the timestamp 60 | 61 | """ 62 | frames = self.dataset.frames[get_frames_slice_from_scenes(self.dataset.scenes[scene_index])] 63 | data = self.sample_function(state_index, frames, self.dataset.agents, self.dataset.tl_faces, track_id) 64 | # 0,1,C -> C,0,1 65 | image = data["image"].transpose(2, 0, 1) 66 | 67 | target_positions = np.array(data["target_positions"], dtype=np.float32) 68 | target_yaws = np.array(data["target_yaws"], 
dtype=np.float32) 69 | 70 | history_positions = np.array(data["history_positions"], dtype=np.float32) 71 | history_yaws = np.array(data["history_yaws"], dtype=np.float32) 72 | 73 | timestamp = frames[state_index]["timestamp"] 74 | track_id = np.int64(-1 if track_id is None else track_id) # always a number to avoid crashing torch 75 | 76 | return { 77 | "image": image, 78 | "target_positions": target_positions, 79 | "target_yaws": target_yaws, 80 | "target_availabilities": data["target_availabilities"], 81 | "history_positions": history_positions, 82 | "history_yaws": history_yaws, 83 | "history_availabilities": data["history_availabilities"], 84 | # "world_to_image": data["raster_from_world"], 85 | "raster_from_world": data["raster_from_world"], # (3, 3) 86 | "track_id": track_id, 87 | "timestamp": timestamp, 88 | "centroid": data["centroid"], 89 | "yaw": data["yaw"], 90 | "extent": data["extent"], 91 | "track_ids": data["track_ids"], 92 | "centroid_pixel": data["centroid_pixel"].astype(np.int64), 93 | } 94 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/resnest_multi.py: -------------------------------------------------------------------------------- 1 | # using ResNeSt-50 as an example 2 | from resnest.torch import resnest50, resnest101, resnest200, resnest269 3 | 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from typing import Dict 9 | 10 | import sys 11 | import os 12 | 13 | from torch.nn import Sequential 14 | 15 | sys.path.append(os.pardir) 16 | sys.path.append(os.path.join(os.pardir, os.pardir)) 17 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir)) 18 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir, os.pardir)) 19 | from lib.nn.block.linear_block import LinearBlock 20 | from lib.nn.models.multi.multi_utils import calc_in_out_channels, calc_out_channels 21 | 22 | model_name_dict = { 23 | "resnest50": resnest50, 24 | "resnest101": resnest101, 25 | "resnest200": resnest200, 26 | "resnest269": resnest269, 27 | } 28 | 29 | 30 | class ResNeStMulti(nn.Module): 31 | 32 | def __init__(self, cfg, num_modes=3, model_name='resnest50', 33 | use_bn: bool = True, 34 | hdim: int = 512, 35 | pretrained=True, 36 | in_channels: int = 0): 37 | super(ResNeStMulti, self).__init__() 38 | out_dim, num_preds, future_len = calc_out_channels(cfg, num_modes=num_modes) 39 | self.in_channels = in_channels 40 | self.out_dim = out_dim 41 | self.num_preds = num_preds 42 | self.future_len = future_len 43 | self.num_modes = num_modes 44 | 45 | # self.conv0 = nn.Conv2d( 46 | # in_channels, 3, kernel_size=3, stride=1, padding=1, bias=True) 47 | 48 | self.base_model = model_name_dict[model_name](pretrained=pretrained) 49 | # --- Replace first conv block, instead of preparing conv0 ---- 50 | if isinstance(self.base_model.conv1, Sequential): 51 | conv = self.base_model.conv1[0] 52 | self.base_model.conv1[0] = nn.Conv2d( 53 | in_channels, conv.out_channels, kernel_size=conv.kernel_size, stride=conv.stride, 54 | padding=conv.padding, bias=True) 55 | else: 56 | conv = self.base_model.conv1 57 | self.base_model.conv1 = nn.Conv2d( 58 | in_channels, conv.out_channels, kernel_size=conv.kernel_size, stride=conv.stride, 59 | padding=conv.padding, bias=True) 60 | 61 | activation = F.leaky_relu 62 | self.do_pooling = True 63 | # if self.do_pooling: 64 | # inch = self.base_model.last_linear.in_features 65 | # else: 66 | # inch = None 67 | inch = None 68 | lin1 = LinearBlock(inch, hdim, use_bn=use_bn, 
activation=activation, residual=False) 69 | lin2 = LinearBlock(hdim, out_dim, use_bn=use_bn, activation=None, residual=False) 70 | self.lin_layers = Sequential(lin1, lin2) 71 | 72 | def calc_features(self, x): 73 | x = self.base_model.conv1(x) 74 | x = self.base_model.bn1(x) 75 | x = self.base_model.relu(x) 76 | x = self.base_model.maxpool(x) 77 | 78 | x = self.base_model.layer1(x) 79 | x = self.base_model.layer2(x) 80 | x = self.base_model.layer3(x) 81 | x = self.base_model.layer4(x) 82 | return x 83 | 84 | def forward(self, x): 85 | h = self.calc_features(x) 86 | 87 | if self.do_pooling: 88 | # h = torch.sum(h, dim=(-1, -2)) 89 | h = torch.mean(h, dim=(-1, -2)) 90 | else: 91 | # [128, 2048, 4, 4] when input is (128, 128) 92 | bs, ch, height, width = h.shape 93 | h = h.view(bs, ch*height*width) 94 | for layer in self.lin_layers: 95 | h = layer(h) 96 | return h 97 | 98 | 99 | if __name__ == '__main__': 100 | # --- test instantiation --- 101 | from lib.utils.yaml_utils import load_yaml 102 | 103 | cfg = load_yaml("../../../../modeling/configs/0905_cfg.yaml") 104 | num_modes = 3 105 | model = ResNeStMulti(cfg, num_modes=num_modes) 106 | print(type(model)) 107 | print(model) 108 | 109 | bs = 3 110 | in_channels = model.in_channels 111 | height, width = 224, 224 112 | device = "cuda:0" 113 | 114 | x = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device) 115 | model.to(device) 116 | # pred, confidences = model(x) 117 | # print("pred", pred.shape, "confidences", confidences.shape) 118 | h = model(x) 119 | print("h", h.shape) 120 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi_agent/smp_multi_agent_model.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | import sys 7 | import os 8 | 9 | from torch.nn import Sequential 10 | from torchvision.ops.roi_pool import roi_pool 11 | 12 | sys.path.append(os.pardir) 13 | sys.path.append(os.path.join(os.pardir, os.pardir)) 14 | from lib.nn.block.linear_block import LinearBlock 15 | from lib.nn.models.multi.multi_utils import calc_out_channels 16 | 17 | 18 | class SMPMultiAgentModel(nn.Module): 19 | def __init__( 20 | self, 21 | cfg, 22 | num_modes=3, 23 | in_channels: int = 0, 24 | hdim: int = 512, 25 | use_bn: bool = False, 26 | model_name: str = "smp_fpn", 27 | encoder_name: str = "resnet18", 28 | roi_kernel_size: float = 1.0, 29 | ): 30 | super(SMPMultiAgentModel, self).__init__() 31 | out_dim, num_preds, future_len = calc_out_channels(cfg, num_modes=num_modes) 32 | 33 | if model_name == "smp_unet": 34 | self.base_model = smp.Unet(encoder_name, in_channels=in_channels) 35 | elif model_name == "smp_fpn": 36 | self.base_model = smp.FPN(encoder_name, in_channels=in_channels) 37 | else: 38 | raise NotImplementedError(f"model_name {model_name} not supported in SMPMultiAgentModel") 39 | 40 | # HACKING, skip conv2d to get 41 | decoder_channels = self.base_model.segmentation_head[0].in_channels 42 | print("decoder_channels", decoder_channels) 43 | self.base_model.segmentation_head[0] = nn.Identity() 44 | 45 | activation = F.leaky_relu 46 | lin_head1 = LinearBlock(decoder_channels, hdim, use_bn=use_bn, activation=activation, residual=False) 47 | lin_head2 = LinearBlock(hdim, out_dim, use_bn=use_bn, activation=None, residual=False) 48 | self.lin_layers = Sequential(lin_head1, lin_head2) 49 | # self.lin_head = 
nn.Linear(decoder_channels, out_channels) 50 | 51 | self.in_channels = in_channels 52 | self.model_name = model_name 53 | self.num_preds = num_preds 54 | self.out_dim = out_dim 55 | self.future_len = future_len 56 | self.num_modes = num_modes 57 | self.roi_kernel_size = roi_kernel_size 58 | 59 | def forward(self, image, centroid_pixel, batch_agents): 60 | # (bs, ch, height, width) 61 | h_image = self.base_model(image) 62 | 63 | # (n_agents, ch) 64 | roi_kernel_size = self.roi_kernel_size 65 | if roi_kernel_size == 1.0: 66 | # Kernel size is 1, simply take the position of feature as agent feature. 67 | # TODO: how to handle agents outside of image...?? 68 | x_pixel = torch.clamp(centroid_pixel[:, 0], 0, h_image.shape[3] - 1) 69 | y_pixel = torch.clamp(centroid_pixel[:, 1], 0, h_image.shape[2] - 1) 70 | h_points = h_image[batch_agents, :, y_pixel, x_pixel] 71 | else: 72 | # Take ROI pooling around the position to get the agent feature. 73 | # TODO: how to handle agents outside of image...?? 74 | k = roi_kernel_size / 2 75 | x1 = torch.clamp(centroid_pixel[:, 0] - k, 0, h_image.shape[3] - 1).type(h_image.dtype) 76 | x2 = torch.clamp(centroid_pixel[:, 0] + k, 0, h_image.shape[3] - 1).type(h_image.dtype) 77 | y1 = torch.clamp(centroid_pixel[:, 1] - k, 0, h_image.shape[2] - 1).type(h_image.dtype) 78 | y2 = torch.clamp(centroid_pixel[:, 1] + k, 0, h_image.shape[2] - 1).type(h_image.dtype) 79 | boxes = torch.stack([batch_agents.type(h_image.dtype), x1, y1, x2, y2], dim=1) 80 | h_points = roi_pool(h_image, boxes, output_size=1)[:, :, 0, 0] 81 | 82 | for layer in self.lin_layers: 83 | h_points = layer(h_points) 84 | 85 | h = h_points 86 | # pred (n_agents)x(modes)x(time)x(2D coords) 87 | # confidences (n_agents)x(modes) 88 | n_agents, _ = h.shape 89 | pred, confidences = torch.split(h, self.num_preds, dim=1) 90 | pred = pred.view(n_agents, self.num_modes, self.future_len, 2) 91 | assert confidences.shape == (n_agents, self.num_modes) 92 | confidences = torch.softmax(confidences, dim=1) 93 | return pred, confidences 94 | 95 | 96 | if __name__ == '__main__': 97 | bs = 2 98 | in_channels = 4 99 | out_channels = 3 100 | image = torch.ones([bs, in_channels, 128, 128]) 101 | model = smp.Unet("resnet18", in_channels=in_channels, classes=out_channels) # 16 102 | model = smp.FPN("resnet18", in_channels=in_channels, classes=out_channels) # 128 103 | decoder_channels = model.segmentation_head[0].in_channels 104 | print("decoder_channels", decoder_channels) 105 | out = model(image) 106 | print("out", out.shape) 107 | -------------------------------------------------------------------------------- /src/modeling/calc_target_scale.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import dataclasses 4 | import numpy as np 5 | import torch 6 | from pathlib import Path 7 | 8 | from l5kit.data import LocalDataManager, ChunkedDataset 9 | import sys 10 | import os 11 | 12 | from tqdm import tqdm 13 | 14 | sys.path.append(os.pardir) 15 | sys.path.append(os.path.join(os.pardir, os.pardir)) 16 | from lib.evaluation.mask import load_mask_chopped 17 | from lib.rasterization.rasterizer_builder import build_custom_rasterizer 18 | from lib.dataset.faster_agent_dataset import FasterAgentDataset 19 | from lib.utils.yaml_utils import save_yaml, load_yaml 20 | from modeling.load_flag import load_flags, Flags 21 | 22 | 23 | def calc_target_scale(agent_dataset, n_sample: int = 10000) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 24 | sub_indices = np.linspace(0, 
len(agent_dataset) - 1, num=n_sample, dtype=np.int64) 25 | pos_list = [] 26 | for i in tqdm(sub_indices): 27 | d = agent_dataset[i] 28 | pos = d["target_positions"] 29 | pos[~d["target_availabilities"].astype(bool)] = np.nan 30 | pos_list.append(pos) 31 | agents_pos = np.array(pos_list) 32 | target_scale_abs_mean = np.nanmean(np.abs(agents_pos), axis=0) 33 | target_scale_abs_max = np.nanmax(np.abs(agents_pos), axis=0) 34 | target_scale_std = np.nanstd(agents_pos, axis=0) 35 | return target_scale_abs_mean, target_scale_abs_max, target_scale_std 36 | 37 | 38 | if __name__ == '__main__': 39 | mode = "" 40 | flags: Flags = load_flags(mode=mode) 41 | flags_dict = dataclasses.asdict(flags) 42 | cfg = load_yaml(flags.cfg_filepath) 43 | out_dir = Path(flags.out_dir) 44 | print(f"cfg {cfg}") 45 | os.makedirs(str(out_dir), exist_ok=True) 46 | print(f"flags: {flags_dict}") 47 | save_yaml(out_dir / 'flags.yaml', flags_dict) 48 | save_yaml(out_dir / 'cfg.yaml', cfg) 49 | debug = flags.debug 50 | 51 | # set env variable for data 52 | os.environ["L5KIT_DATA_FOLDER"] = flags.l5kit_data_folder 53 | dm = LocalDataManager(None) 54 | 55 | print("init dataset") 56 | train_cfg = cfg["train_data_loader"] 57 | valid_cfg = cfg["valid_data_loader"] 58 | 59 | # Build StubRasterizer for fast dataset access 60 | cfg["raster_params"]["map_type"] = "stub_debug" 61 | rasterizer = build_custom_rasterizer(cfg, dm) 62 | print("rasterizer", rasterizer) 63 | 64 | train_path = "scenes/sample.zarr" if debug else train_cfg["key"] 65 | 66 | train_agents_mask = None 67 | if flags.validation_chopped: 68 | # Use chopped dataset to calc statistics... 69 | num_frames_to_chop = 100 70 | th_agent_prob = cfg["raster_params"]["filter_agents_threshold"] 71 | min_frame_future = 1 72 | num_frames_to_copy = num_frames_to_chop 73 | train_agents_mask = load_mask_chopped( 74 | dm.require(train_path), th_agent_prob, num_frames_to_copy, min_frame_future) 75 | print("train_path", train_path, "train_agents_mask", train_agents_mask.shape) 76 | 77 | train_zarr = ChunkedDataset(dm.require(train_path)).open(cached=False) 78 | print("train_zarr", type(train_zarr)) 79 | print(f"Open Dataset {flags.pred_mode}...") 80 | 81 | train_agent_dataset = FasterAgentDataset( 82 | cfg, train_zarr, rasterizer, min_frame_history=flags.min_frame_history, 83 | min_frame_future=flags.min_frame_future, agents_mask=train_agents_mask 84 | ) 85 | print("train_agent_dataset", len(train_agent_dataset)) 86 | n_sample = 1_000_000 # Take 1M sample. 
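    # (Hedged note) Each saved array below has shape (future_len, 2) after the
    # nan-aware reduction over sampled agents; a consumer could normalize
    # regression targets roughly as:
    #     scale = np.load(cache_path)["target_scale"]       # (future_len, 2)
    #     normalized = data["target_positions"] / scale     # broadcasts over time
    # The actual consumer lives outside this file (cf. target_scale_wrapper.py).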
87 | target_scale_abs_mean, target_scale_abs_max, target_scale_std = calc_target_scale(train_agent_dataset, n_sample) 88 | 89 | chopped_str = "_chopped" if flags.validation_chopped else "" 90 | agent_prob = cfg["raster_params"]["filter_agents_threshold"] 91 | filename = f"target_scale_abs_mean_{agent_prob}_{flags.min_frame_history}_{flags.min_frame_future}{chopped_str}.npz" 92 | cache_path = Path(train_zarr.path) / filename 93 | np.savez_compressed(cache_path, target_scale=target_scale_abs_mean) 94 | print("Saving to ", cache_path) 95 | 96 | filename = f"target_scale_abs_max_{agent_prob}_{flags.min_frame_history}_{flags.min_frame_future}{chopped_str}.npz" 97 | cache_path = Path(train_zarr.path) / filename 98 | np.savez_compressed(cache_path, target_scale=target_scale_abs_max) 99 | print("Saving to ", cache_path) 100 | 101 | filename = f"target_scale_std_{agent_prob}_{flags.min_frame_history}_{flags.min_frame_future}{chopped_str}.npz" 102 | cache_path = Path(train_zarr.path) / filename 103 | np.savez_compressed(cache_path, target_scale=target_scale_std) 104 | print("Saving to ", cache_path) 105 | 106 | print("target_scale_abs_mean", target_scale_abs_mean) 107 | print("target_scale_abs_max", target_scale_abs_max) 108 | print("target_scale_std", target_scale_std) 109 | import IPython; IPython.embed() 110 | -------------------------------------------------------------------------------- /src/ensemble/ensemble_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import yaml 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from pathlib import Path 9 | from numpy.random import multinomial, multivariate_normal 10 | from sklearn.mixture import GaussianMixture 11 | # from sklearn.mixture._gaussian_mixture import _estimate_gaussian_parameters 12 | import warnings 13 | import torch 14 | warnings.simplefilter('ignore') 15 | 16 | sys.path.append("./src") 17 | sys.path.append(os.pardir) 18 | sys.path.append(os.path.join(os.pardir, os.pardir)) 19 | # from lib.utils.timer_utils import timer_ms 20 | from lib.mixture.gmm import GaussianMixtureIdentity 21 | from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch 22 | 23 | NUM_VAL = 190327 24 | # NUM_VAL = 5000 25 | SEED = 0 26 | np.random.seed(SEED) 27 | 28 | 29 | def main(flags): 30 | outdir = flags["outdir"] 31 | if not Path(outdir).exists(): 32 | # Calculate ensemble... 
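        # (Sketch of the idea) Each of the n_models inputs contributes 3
        # trajectory modes. Below we draw N_sample points from the weighted
        # mixture of all n_models * 3 mode means (each flattened to 50*2=100
        # dims), add isotropic Gaussian noise of variance sigma (drawn once and
        # reused for every agent), and re-fit a 3-component GMM per agent so
        # the ensemble again outputs exactly 3 modes with confidences.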
33 | w = flags["weight"] 34 | sigma = flags["sigma"] 35 | N_sample = flags["N_sample"] 36 | covariance_type = flags["covariance_type"] 37 | file_list = flags["file_list"] 38 | 39 | n_models: int = len(file_list) 40 | preds_list = [np.load(filepath) for filepath in file_list] 41 | # coords_array: (n_models, n_valid_data, num_modes=3, future_len=50, coords=2) 42 | coords_array = np.array([preds["coords"] for preds in preds_list]) 43 | # confs_array: (n_models, n_valid_data, num_modes=3) 44 | confs_array = np.array([preds["confs"] for preds in preds_list]) 45 | weighted_confs_array = np.array(w)[:, None, None] * confs_array 46 | weighted_confs_array = weighted_confs_array / np.sum(weighted_confs_array, axis=(0, 2), keepdims=True) 47 | print("coords_array", coords_array.shape, "confs_array", confs_array.shape) 48 | 49 | # 9164 ms 50 | # I = np.eye(100) * sigma 51 | # samples = [multivariate_normal(np.zeros(100), I) for _ in range(N_sample)] 52 | # print("samples", len(samples), "samples0", samples[0].shape, samples[0]) 53 | # samples_list = samples 54 | 55 | # 12 ms 56 | if sigma > 0: 57 | samples = np.random.normal(0, np.sqrt(sigma), (N_sample, 100)) 58 | else: 59 | samples = np.zeros((N_sample, 100), dtype=np.float64) 60 | print("samples", samples.shape, samples) 61 | 62 | coords_out = [] 63 | confs_out = [] 64 | 65 | for idx in tqdm(range(NUM_VAL)): 66 | # mu = [coords_array[i][idx].reshape(3, 100) for i in range(len(coords_array))] 67 | # mu = np.concatenate(mu) 68 | mu = coords_array[:, idx].reshape(n_models * 3, 100) 69 | # assert np.allclose(mu, mu2) 70 | # confidence = [w[i]*confs_array[i][idx] for i in range(len(confs_array))] 71 | # confidence = np.concatenate(confidence) 72 | # confidence /= confidence.sum() 73 | confidence = weighted_confs_array[:, idx].ravel() 74 | # assert np.allclose(confidence, confidence2) 75 | x = mu[np.random.choice(3*len(w), size=N_sample, p=confidence)]+samples 76 | if covariance_type == "identity": 77 | gauss = GaussianMixtureIdentity(3, "spherical", random_state=SEED) 78 | else: 79 | gauss = GaussianMixture(3, covariance_type, random_state=SEED) 80 | # if idx == 0: 81 | # print("random") 82 | # gauss = GaussianMixture(3, covariance_type, random_state=SEED, init_params="random") 83 | gauss.fit(x) 84 | confidence_fit = gauss.weights_ 85 | mu_fit = gauss.means_ 86 | coords_out.append(mu_fit.reshape(3, 50, 2)) 87 | confs_out.append(confidence_fit) 88 | 89 | preds0 = preds_list[0] 90 | timestamps = preds0["timestamps"] 91 | track_ids = preds0["track_ids"] 92 | targets = preds0["targets"] 93 | target_availabilities = preds0["target_availabilities"] 94 | coords = np.array(coords_out) 95 | confs = np.array(confs_out) 96 | np.savez_compressed( 97 | outdir, 98 | timestamps=timestamps, 99 | track_ids=track_ids, 100 | coords=coords, 101 | confs=confs, 102 | targets=targets, 103 | target_availabilities=target_availabilities, 104 | ) 105 | print(f"Saved to {outdir}") 106 | else: 107 | # Just use already calculated results... 
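        # (Note) The ensemble output is cached; skip refitting and just score
        # it against ground truth with the competition NLL metric below.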
108 | preds = np.load(outdir) 109 | timestamps = preds["timestamps"] 110 | track_ids = preds["track_ids"] 111 | targets = preds["targets"] 112 | target_availabilities = preds["target_availabilities"] 113 | coords = preds["coords"] 114 | confs = preds["confs"] 115 | errors = pytorch_neg_multi_log_likelihood_batch( 116 | torch.as_tensor(targets)[:NUM_VAL], 117 | torch.as_tensor(coords), 118 | torch.as_tensor(confs), 119 | torch.as_tensor(target_availabilities)[:NUM_VAL], 120 | ) 121 | print("errors", errors.shape, torch.mean(errors)) 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser(description='') 126 | parser.add_argument('--yaml_filepath', '-y', type=str, 127 | help='Flags yaml file path') 128 | args = parser.parse_args() 129 | with open(args.yaml_filepath, 'r') as f: 130 | flags = yaml.safe_load(f) 131 | main(flags) 132 | -------------------------------------------------------------------------------- /src/lib/nn/models/multi/pretrained_cnn_multi.py: -------------------------------------------------------------------------------- 1 | import pretrainedmodels 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from typing import Dict 6 | 7 | import sys 8 | import os 9 | 10 | from torch.nn import Sequential 11 | 12 | sys.path.append(os.pardir) 13 | sys.path.append(os.path.join(os.pardir, os.pardir)) 14 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir)) 15 | sys.path.append(os.path.join(os.pardir, os.pardir, os.pardir, os.pardir)) 16 | from lib.nn.block.linear_block import LinearBlock 17 | from lib.nn.models.multi.multi_utils import calc_out_channels 18 | from lib.nn.block.feat_module import FeatModule 19 | 20 | 21 | class PretrainedCNNMulti(nn.Module): 22 | 23 | def __init__( 24 | self, cfg, num_modes=3, model_name='se_resnext101_32x4d', 25 | use_bn: bool = True, 26 | hdim: int = 512, 27 | pretrained='imagenet', 28 | in_channels: int = 0, 29 | feat_module_type: str = "none", 30 | feat_channels: int = -1, 31 | ): 32 | super(PretrainedCNNMulti, self).__init__() 33 | out_dim, num_preds, future_len = calc_out_channels(cfg, num_modes=num_modes) 34 | self.in_channels = in_channels 35 | self.out_dim = out_dim 36 | self.num_preds = num_preds 37 | self.future_len = future_len 38 | self.num_modes = num_modes 39 | 40 | self.base_model = pretrainedmodels.__dict__[model_name](pretrained=pretrained) 41 | 42 | # --- Replace first conv --- 43 | try: 44 | if hasattr(self.base_model, "layer0") and isinstance(self.base_model.layer0[0], nn.Conv2d): 45 | print("Replace self.base_model.layer0[0]...") 46 | # This works with SeResNeXt, but not tested with other network... 47 | conv = self.base_model.layer0[0] 48 | self.base_model.layer0[0] = nn.Conv2d( 49 | in_channels, conv.out_channels, kernel_size=conv.kernel_size, stride=conv.stride, 50 | padding=conv.padding, bias=True) 51 | self.conv0 = None 52 | # elif hasattr(self.base_model, "conv1") and isinstance(self.base_model.conv1, nn.Conv2d): 53 | elif model_name in ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]: 54 | # torchvision resnet is follows... 
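            # (Note) Rasterized inputs have many more channels than RGB
            # (history frames + map), so the pretrained stem conv must be
            # rebuilt with `in_channels` inputs while all other pretrained
            # weights are kept as-is.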
55 | print("Replace base_model.conv1...") 56 | self.base_model.conv1 = nn.Conv2d( 57 | in_channels, 58 | self.base_model.conv1.out_channels, 59 | kernel_size=self.base_model.conv1.kernel_size, 60 | stride=self.base_model.conv1.stride, 61 | padding=self.base_model.conv1.padding, 62 | bias=False, 63 | ) 64 | self.conv0 = None 65 | else: 66 | raise ValueError("Cannot extract first conv layer") 67 | except Exception as e: 68 | # TODO: Better to replace `base_model`'s first conv block! 69 | self.conv0 = nn.Conv2d( 70 | in_channels, 3, kernel_size=3, stride=1, padding=1, bias=True) 71 | print(f'[WARNING ]Cannot extract first conv layer for {model_name}, use conv0 to align channel size') 72 | 73 | activation = F.leaky_relu 74 | self.do_pooling = True 75 | if self.do_pooling: 76 | inch = self.base_model.last_linear.in_features 77 | else: 78 | inch = None 79 | lin1 = LinearBlock(inch, hdim, use_bn=use_bn, activation=activation, residual=False) 80 | lin2 = LinearBlock(hdim, out_dim, use_bn=use_bn, activation=None, residual=False) 81 | self.lin_layers = Sequential(lin1, lin2) 82 | 83 | self.feat_module_type = feat_module_type 84 | self.feat_module = FeatModule( 85 | feat_module_type=feat_module_type, 86 | channels=inch, 87 | feat_channels=feat_channels, 88 | ) 89 | self.feat_channels = feat_channels 90 | 91 | def forward(self, x, x_feat=None): 92 | """ 93 | 94 | Args: 95 | x: image feature (bs, ch, h, w) 96 | x_feat: (bs, ch). Additional feature (Ex. Agent type, timestamp...) 97 | 98 | Returns: 99 | h: (bs, ch) 100 | """ 101 | if self.conv0 is None: 102 | h = x 103 | else: 104 | h = self.conv0(x) 105 | h = self.base_model.features(h) 106 | 107 | if self.do_pooling: 108 | # h = torch.sum(h, dim=(-1, -2)) 109 | h = torch.mean(h, dim=(-1, -2)) 110 | else: 111 | # [128, 2048, 4, 4] when input is (128, 128) 112 | bs, ch, height, width = h.shape 113 | h = h.view(bs, ch*height*width) 114 | 115 | h = self.feat_module(h, x_feat) 116 | 117 | for layer in self.lin_layers: 118 | h = layer(h) 119 | return h 120 | 121 | 122 | if __name__ == '__main__': 123 | # --- test instantiation --- 124 | from lib.utils.yaml_utils import load_yaml 125 | 126 | cfg = load_yaml("../../../../modeling/configs/0905_cfg.yaml") 127 | num_modes = 3 128 | model = PretrainedCNNMulti(cfg, num_modes=num_modes) 129 | print(type(model)) 130 | print(model) 131 | 132 | bs = 3 133 | in_channels = model.in_channels 134 | height, width = 224, 224 135 | device = "cuda:0" 136 | 137 | x = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device) 138 | model.to(device) 139 | # pred, confidences = model(x) 140 | # print("pred", pred.shape, "confidences", confidences.shape) 141 | h = model(x) 142 | print("h", h.shape) 143 | -------------------------------------------------------------------------------- /src/lib/dataset/fast_agent_dataset.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import numpy as np 6 | from l5kit.data import ChunkedDataset 7 | from l5kit.dataset import AgentDataset 8 | from l5kit.kinematic import Perturbation 9 | from l5kit.rasterization import Rasterizer 10 | 11 | from lib.utils.timer_utils import timer 12 | 13 | # WARNING: changing these values impact the number of instances selected for both train and inference! 
14 | 15 | MIN_FRAME_HISTORY = 10 # minimum number of frames an agents must have in the past to be picked 16 | MIN_FRAME_FUTURE = 1 # minimum number of frames an agents must have in the future to be picked 17 | 18 | 19 | class FastAgentDataset(AgentDataset): 20 | def __init__( 21 | self, 22 | cfg: dict, 23 | zarr_dataset: ChunkedDataset, 24 | rasterizer: Rasterizer, 25 | perturbation: Optional[Perturbation] = None, 26 | agents_mask: Optional[np.ndarray] = None, 27 | min_frame_history: int = MIN_FRAME_HISTORY, 28 | min_frame_future: int = MIN_FRAME_FUTURE, 29 | ): 30 | assert perturbation is None, "AgentDataset does not support perturbation (yet)" 31 | 32 | super(AgentDataset, self).__init__(cfg, zarr_dataset, rasterizer, perturbation) 33 | 34 | # store the valid agents indexes 35 | with timer("agents_indices"): 36 | self.agents_indices = self.load_agents_indices(agents_mask, min_frame_history, min_frame_future) 37 | print("self.agents_indices", self.agents_indices.shape) 38 | # this will be used to get the frame idx from the agent idx 39 | with timer("cumulative_sizes_agents"): 40 | self.cumulative_sizes_agents = self.load_cumulative_sizes_agents(agents_mask) 41 | print("self.cumulative_sizes_agents", self.cumulative_sizes_agents.shape) 42 | 43 | # agents_mask may be `None` here. 44 | # Because this is not used in typical training & takes time to load... 45 | self.agents_mask = agents_mask 46 | 47 | # --- Below 2 methods are for "Fast" __init__. Caching time consuming array loading. 48 | def load_agents_indices( 49 | self, 50 | agents_mask: Optional[np.ndarray] = None, 51 | min_frame_history: int = MIN_FRAME_HISTORY, 52 | min_frame_future: int = MIN_FRAME_FUTURE, 53 | ) -> np.ndarray: 54 | agent_prob = self.cfg["raster_params"]["filter_agents_threshold"] 55 | agents_mask_str = "" if agents_mask is None else f"_mask{np.sum(agents_mask)}" 56 | filename = f"agents_indices_{agent_prob}_{min_frame_history}_{min_frame_future}{agents_mask_str}.npz" 57 | agents_indices_path = Path(self.dataset.path) / filename 58 | if not agents_indices_path.exists(): 59 | print(f"Cache {agents_indices_path} does not exist, creating...") 60 | if agents_mask is None: # if not provided try to load it from the zarr 61 | with timer("load_agents_mask"): 62 | agents_mask = self.load_agents_mask() 63 | print("agents_mask", agents_mask.shape) 64 | past_mask = agents_mask[:, 0] >= min_frame_history 65 | future_mask = agents_mask[:, 1] >= min_frame_future 66 | agents_mask = past_mask * future_mask 67 | 68 | if min_frame_history != MIN_FRAME_HISTORY: 69 | warnings.warn( 70 | f"you're running with custom min_frame_history of {min_frame_history}", 71 | RuntimeWarning, 72 | stacklevel=2, 73 | ) 74 | if min_frame_future != MIN_FRAME_FUTURE: 75 | warnings.warn( 76 | f"you're running with custom min_frame_future of {min_frame_future}", RuntimeWarning, stacklevel=2 77 | ) 78 | else: 79 | warnings.warn("you're running with a custom agents_mask", RuntimeWarning, stacklevel=2) 80 | agents_indices = np.nonzero(agents_mask)[0] 81 | np.savez_compressed(str(agents_indices_path), agents_indices=agents_indices) 82 | 83 | # --- Load from cache --- 84 | assert agents_indices_path.exists() 85 | agents_indices = np.load(agents_indices_path)["agents_indices"] 86 | print(f"Loaded from {agents_indices_path}") 87 | return agents_indices 88 | 89 | def load_cumulative_sizes_agents(self, agents_mask: Optional[np.ndarray] = None) -> np.ndarray: 90 | agents_mask_str = "" if agents_mask is None else f"_mask{np.sum(agents_mask)}" 91 | filename = 
f"cumulative_sizes_agents{agents_mask_str}.npz" 92 | cumulative_sizes_agents_path = Path(self.dataset.path) / filename 93 | if not cumulative_sizes_agents_path.exists(): 94 | print(f"Cache {cumulative_sizes_agents_path} does not exist, creating...") 95 | cumulative_sizes_agents = self.dataset.frames["agent_index_interval"][:, 1] 96 | np.savez_compressed(str(cumulative_sizes_agents_path), cumulative_sizes_agents=cumulative_sizes_agents) 97 | 98 | # --- Load from cache --- 99 | assert cumulative_sizes_agents_path.exists() 100 | cumulative_sizes_agents = np.load(cumulative_sizes_agents_path)["cumulative_sizes_agents"] 101 | print(f"Loaded from {cumulative_sizes_agents_path}") 102 | return cumulative_sizes_agents 103 | 104 | def get_scene_dataset(self, scene_index: int) -> "AgentDataset": 105 | """ 106 | Differs from parent only in the return type. 107 | Instead of doing everything from scratch, we rely on super call and fix the agents_mask 108 | """ 109 | if self.agents_mask is None: 110 | print("Loading agents_mask...") 111 | self.agents_mask = self.load_agents_mask() 112 | return super(FastAgentDataset, self).get_scene_dataset(scene_index) 113 | -------------------------------------------------------------------------------- /src/lib/dataset/faster_agent_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from pathlib import Path 3 | from typing import Optional 4 | from concurrent.futures import ProcessPoolExecutor 5 | 6 | import numpy as np 7 | from l5kit.data import ChunkedDataset 8 | from l5kit.dataset import EgoDataset 9 | from l5kit.kinematic import Perturbation 10 | from l5kit.rasterization import Rasterizer 11 | from l5kit.dataset.agent import MIN_FRAME_HISTORY, MIN_FRAME_FUTURE 12 | from numcodecs import Blosc 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | import zarr 16 | 17 | from lib.dataset.custom_ego_dataset import get_frame_custom 18 | from lib.dataset.fast_agent_dataset import FastAgentDataset 19 | from lib.sampling.agent_sampling_tl_history import create_generate_agent_sample_tl_history_partial 20 | from lib.sampling.agent_sampling_fixing_yaw import create_generate_agent_sample_fixing_yaw_partial 21 | 22 | 23 | def _job(index): 24 | global agent_dataset 25 | track_id = agent_dataset.dataset.agents[index]["track_id"] 26 | frame_index = bisect.bisect_right(agent_dataset.cumulative_sizes_agents, index) 27 | scene_index = bisect.bisect_right(agent_dataset.cumulative_sizes, frame_index) 28 | 29 | if scene_index == 0: 30 | state_index = frame_index 31 | else: 32 | state_index = frame_index - agent_dataset.cumulative_sizes[scene_index - 1] 33 | 34 | return track_id, scene_index, state_index 35 | 36 | 37 | class FasterAgentDataset(Dataset): 38 | def __init__( 39 | self, 40 | cfg: dict, 41 | zarr_dataset: ChunkedDataset, 42 | rasterizer: Rasterizer, 43 | perturbation: Optional[Perturbation] = None, 44 | agents_mask: Optional[np.ndarray] = None, 45 | min_frame_history: int = MIN_FRAME_HISTORY, 46 | min_frame_future: int = MIN_FRAME_FUTURE, 47 | override_sample_function_name: str = "", 48 | ): 49 | assert perturbation is None, "AgentDataset does not support perturbation (yet)" 50 | self.cfg = cfg 51 | self.ego_dataset = EgoDataset(cfg, zarr_dataset, rasterizer, perturbation) 52 | self.get_frame_arguments = self.load_get_frame_arguments(agents_mask, min_frame_history, min_frame_future) 53 | 54 | if override_sample_function_name != "": 55 | print("override_sample_function_name", 
override_sample_function_name) 56 | if override_sample_function_name == "generate_agent_sample_tl_history": 57 | self.ego_dataset.sample_function = create_generate_agent_sample_tl_history_partial(cfg, rasterizer) 58 | elif override_sample_function_name == "generate_agent_sample_fixing_yaw": 59 | self.ego_dataset.sample_function = create_generate_agent_sample_fixing_yaw_partial(cfg, rasterizer) 60 | 61 | def load_get_frame_arguments( 62 | self, 63 | agents_mask: Optional[np.ndarray] = None, 64 | min_frame_history: int = MIN_FRAME_HISTORY, 65 | min_frame_future: int = MIN_FRAME_FUTURE, 66 | ) -> zarr.core.Array: 67 | """ 68 | Returns: 69 | zarr.core.Array: int64 array of (track_id, scene_index, state_index) 70 | """ 71 | agent_prob = self.cfg["raster_params"]["filter_agents_threshold"] 72 | agents_mask_str = "" if agents_mask is None else f"_mask{np.sum(agents_mask)}" 73 | filename = f"get_frame_arguments_{agent_prob}_{min_frame_history}_{min_frame_future}{agents_mask_str}.zarr" 74 | cache_path = Path(self.ego_dataset.dataset.path) / filename 75 | 76 | if not cache_path.exists(): 77 | global agent_dataset 78 | print(f"Cache {cache_path} does not exist, creating...") 79 | 80 | # Use FastAgentDataset to build agent_indices. 81 | agent_dataset = FastAgentDataset( 82 | cfg=self.cfg, 83 | zarr_dataset=self.ego_dataset.dataset, 84 | rasterizer=self.ego_dataset.rasterizer, 85 | perturbation=self.ego_dataset.perturbation, 86 | agents_mask=agents_mask, 87 | min_frame_history=min_frame_history, 88 | min_frame_future=min_frame_future 89 | ) 90 | 91 | indices = agent_dataset.agents_indices 92 | with ProcessPoolExecutor(max_workers=16) as executor: 93 | get_frame_arguments = list(tqdm(executor.map(_job, indices, chunksize=10000), total=len(indices))) 94 | 95 | del agent_dataset 96 | 97 | get_frame_arguments = np.asarray(get_frame_arguments, dtype=np.int64) 98 | compressor = Blosc(cname="lz4", clevel=5, shuffle=Blosc.SHUFFLE, blocksize=0) 99 | z = zarr.open(str(cache_path), mode="w", shape=get_frame_arguments.shape, chunks=(20000, 3), dtype="i8", compressor=compressor) 100 | z[:] = get_frame_arguments 101 | 102 | z = zarr.open(str(cache_path), mode="r") 103 | return z 104 | 105 | def __len__(self) -> int: 106 | return len(self.get_frame_arguments) 107 | 108 | def __getitem__(self, index: int) -> dict: 109 | track_id, scene_index, state_index = self.get_frame_arguments[index] 110 | # return self.ego_dataset.get_frame(scene_index, state_index, track_id=track_id) 111 | return get_frame_custom(self.ego_dataset, scene_index, state_index, track_id=track_id) 112 | 113 | 114 | if __name__ == "__main__": 115 | from l5kit.data import LocalDataManager 116 | from l5kit.rasterization import build_rasterizer 117 | from lib.utils.yaml_utils import load_yaml 118 | 119 | repo_root = Path(__file__).parent.parent.parent.parent 120 | 121 | dm = LocalDataManager(local_data_folder=str(repo_root / "input" / "lyft-motion-prediction-autonomous-vehicles")) 122 | dataset = ChunkedDataset(dm.require("scenes/sample.zarr")).open(cached=False) 123 | cfg = load_yaml(repo_root / "src" / "modeling" / "configs" / "0905_cfg.yaml") 124 | rasterizer = build_rasterizer(cfg, dm) 125 | 126 | faster_agent_dataset = FasterAgentDataset(cfg, dataset, rasterizer, None) 127 | fast_agent_dataset = FastAgentDataset(cfg, dataset, rasterizer, None) 128 | 129 | assert len(faster_agent_dataset) == len(fast_agent_dataset) 130 | keys = ["image", "target_positions", "target_availabilities"] 131 | for index in tqdm(range(min(1000, len(faster_agent_dataset)))): 
132 |         actual = faster_agent_dataset[index]
133 |         expected = fast_agent_dataset[index]
134 |         for key in keys:
135 |             assert (actual[key] == expected[key]).all()
136 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lyft Motion Prediction for Autonomous Vehicles
2 | 
3 | Code for the 4th place solution of [Lyft Motion Prediction for Autonomous Vehicles](https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles)
4 | on Kaggle.
5 | 
6 | - Discussion [4th place solution: Ensemble with GMM](https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles/discussion/199657)
7 | 
8 | ## Directory structure
9 | 
10 | ```text
11 | input --- Please locate data here
12 | src
13 | |-ensemble --- For 4. Ensemble scripts
14 | |-lib --- Library codes
15 | |-modeling --- For 1. training, 2. prediction and 3. evaluation scripts
16 | |-results --- Training, prediction and evaluation results will be stored here
17 | README.md --- This instruction file
18 | requirements.txt --- For python library versions
19 | ```
20 | 
21 | ## Hardware (The following specs were used to create the original solution)
22 | 
23 | - Ubuntu 18.04 LTS
24 | - 32 CPUs
25 | - 128GB RAM
26 | - 8 x NVIDIA Tesla V100 GPUs
27 | 
28 | ## Software (python packages are detailed separately in `requirements.txt`):
29 | 
30 | Python 3.8.5
31 | CUDA 10.1.243
32 | cudnn 7.6.5
33 | nvidia drivers v.55.23.0
34 | -- Equivalent Dockerfile for the GPU installs: Use `nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04` as the base image
35 | 
36 | Also, we installed OpenMPI==4.0.4 for running pytorch distributed training.
37 | 
38 | ### Python Library
39 | 
40 | Deep learning framework, base library
41 | - torch==1.6.0+cu101
42 | - torchvision==0.7.0
43 | - l5kit==1.1.0
44 | - cupy-cuda101==7.0.0
45 | - pytorch-ignite==0.4.1
46 | - pytorch-pfn-extras==0.3.1
47 | 
48 | CNN models
49 | - [pretrainedmodels](https://github.com/Cadene/pretrained-models.pytorch)==0.7.4
50 | - [efficientnet_pytorch](https://github.com/lukemelas/EfficientNet-PyTorch)==0.7.0
51 | - [resnest](https://github.com/zhanghang1989/ResNeSt)==0.0.6b20200912
52 | - [segmentation-models-pytorch](https://github.com/qubvel/segmentation_models.pytorch)==0.1.2
53 | - timm==0.3.1
54 | - Shapely==1.7.1
55 | 
56 | Data processing/augmentation
57 | - albumentations==0.4.3
58 | - scikit-learn==0.22.2.post1
59 | 
60 | We also installed `apex`: https://github.com/nvidia/apex
61 | 
62 | Please refer to `requirements.txt` for more details.
63 | 
64 | ### Environment Variable
65 | We recommend setting the following environment variables for better performance.
66 | 
67 | ```bash
68 | export MKL_NUM_THREADS=1
69 | export OMP_NUM_THREADS=1
70 | export NUMEXPR_NUM_THREADS=1
71 | ```
72 | 
73 | ## Data setup
74 | 
75 | Please download the competition data:
76 | - [lyft-motion-prediction-autonomous-vehicles](https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles/data)
77 | - [lyft-full-training-set](https://www.kaggle.com/philculliton/lyft-full-training-set)
78 | 
79 | For the `lyft-motion-prediction-autonomous-vehicles` dataset,
80 | extract it under the `input/lyft-motion-prediction-autonomous-vehicles` directory.
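
For reference, one way to fetch and unpack this data is with the Kaggle CLI
(a hedged example: it assumes the `kaggle` CLI is installed and configured with
your API credentials; adjust paths as needed):

```bash
$ kaggle competitions download -c lyft-motion-prediction-autonomous-vehicles -p input/
$ unzip input/lyft-motion-prediction-autonomous-vehicles.zip \
    -d input/lyft-motion-prediction-autonomous-vehicles
```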
81 | 
82 | For the `lyft-full-training-set` data, which only contains `train_full.zarr`,
83 | please place it under `input/lyft-motion-prediction-autonomous-vehicles/scenes` as follows:
84 | ```text
85 | input
86 | |-lyft-motion-prediction-autonomous-vehicles
87 |   |-scenes
88 |     |-train_full.zarr (Place here!)
89 |     |-train.zarr
90 |     |-validate.zarr
91 |     |-test.zarr
92 |     |-... (other data)
93 |     |-... (other data)
94 | 
95 | ```
96 | 
97 | ## Pipeline
98 | Our submission pipeline consists of 1. Training, 2. Prediction, 3. Ensemble.
99 | 
100 | ### Training with training/validation dataset
101 | The training script is located under `src/modeling`.
102 | 
103 | `train_lyft.py` is the training script, and
104 | the training configuration is specified by a `flags` yaml file.
105 | 
106 | [Note] If you want to run training from scratch, please **remove the `results` folder first**.
107 | The training script tries to resume from the `results` folder when `resume_if_possible=True` is set.
108 | 
109 | [Note] The first time training is run, it creates a cache so that training can run efficiently.
110 | This cache creation should be done in a single process,
111 | so please start with single-GPU training until the training loop starts.
112 | The cache is created directly under the `input` directory.
113 | 
114 | Once the cache is created, we can run multi-GPU training using the same `train_lyft.py` script
115 | with the `mpiexec` command.
116 | 
117 | ```bash
118 | $ cd src/modeling
119 | 
120 | # Single GPU training (Please run this first, to create the input data cache)
121 | $ python train_lyft.py --yaml_filepath ./flags/20201104_cosine_aug.yaml
122 | 
123 | # Multi GPU training (-n 8 for 8 GPU training)
124 | $ mpiexec -x MASTER_ADDR=localhost -x MASTER_PORT=8899 -n 8 \
125 |     python train_lyft.py --yaml_filepath ./flags/20201104_cosine_aug.yaml
126 | ```
127 | 
128 | We trained 9 different models for the final submission.
129 | Each training configuration can be found in `src/modeling/flags`,
130 | and the training results are located in `src/modeling/results`.
131 | 
132 | ### Prediction for test dataset
133 | 
134 | `predict_lyft.py` under `src/modeling` executes the prediction for test data.
135 | 
136 | Specify `--out` as the trained model directory; the script uses the trained model in this directory for inference.
137 | Please set `--convert_world_from_agent true` when using `l5kit==1.1.0` or later.
138 | 
139 | ```bash
140 | $ cd src/modeling
141 | $ python predict_lyft.py --out results/20201104_cosine_aug --use_ema true --convert_world_from_agent true
142 | ```
143 | 
144 | Predicted results are stored under the `out` directory.
145 | For example, `results/20201104_cosine_aug/prediction_ema/submission.csv` is created with the above setting.
146 | 
147 | We executed this prediction for all 9 trained models.
148 | This `submission.csv` file can be submitted as a single-model prediction.
149 | 
150 | ### (Optional) Evaluation with validation dataset
151 | 
152 | `eval_lyft.py` under `src/modeling` executes the evaluation on validation data (chopped data).
153 | 
154 | ```bash
155 | python eval_lyft.py --out results/20201104_cosine_aug --use_ema true
156 | ```
157 | 
158 | The script reports the validation error, which is useful for local evaluation of model performance.
159 | 
160 | ### Ensemble
161 | Finally, all trained models' predictions are ensembled using GMM fitting.
162 | 
163 | The ensemble script is located under `src/ensemble`.
164 | 
165 | ```bash
166 | # Please execute from the root of this repository.
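# (Hedged note) The flags yaml is expected to list each model's prediction
# file and its ensemble weight (cf. the "file_list" / "weight" keys read in
# src/ensemble/ensemble_val.py).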
167 | $ python src/ensemble/ensemble_test.py --yaml_filepath src/ensemble/flags/20201126_ensemble.yaml
168 | ```
169 |
170 | The location of the final ensembled `submission.csv` is specified in the yaml file.
171 | You can submit this `submission.csv` by uploading it as a dataset and submitting via a Kaggle kernel.
172 | Please follow [Save your time, submit without kernel inference](https://www.kaggle.com/corochann/save-your-time-submit-without-kernel-inference)
173 | for the submission procedure.
174 | -------------------------------------------------------------------------------- /src/modeling/calc_history_avail.py: --------------------------------------------------------------------------------
1 | """
2 | Calculate history availability masks for chopped valid/test data.
3 | """
4 | import argparse
5 | from distutils.util import strtobool
6 | import numpy as np
7 | import torch
8 | from pathlib import Path
9 |
10 | from l5kit.dataset import AgentDataset
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.dataset import Subset
13 |
14 | from l5kit.data import LocalDataManager, ChunkedDataset
15 |
16 | import sys
17 | import os
18 |
19 | from tqdm import tqdm
20 | sys.path.append(os.pardir)
21 | sys.path.append(os.path.join(os.pardir, os.pardir))
22 | from lib.dataset.faster_agent_dataset import FasterAgentDataset
23 | from lib.evaluation.mask import load_mask_chopped
24 | from modeling.load_flag import Flags
25 | from lib.rasterization.rasterizer_builder import build_custom_rasterizer
26 | from lib.utils.yaml_utils import save_yaml, load_yaml
27 |
28 |
29 | def calc_history_avail(data_loader) -> np.ndarray:
30 | his_avail_list = []
31 |
32 | with torch.no_grad():
33 | dataiter = tqdm(data_loader)
34 | for data in dataiter:
35 | his_avail = data["history_availabilities"].numpy()
36 | # To reduce memory usage, convert to bool dtype.
37 | his_avail_list.append(his_avail.astype(bool))
38 | his_avail_array = np.concatenate(his_avail_list)
39 | return his_avail_array
40 |
41 |
42 | def parse():
43 | parser = argparse.ArgumentParser(description='')
44 | parser.add_argument('--out', '-o', default='results/tmp',
45 | help='Directory to output the result')
46 | parser.add_argument('--debug', '-d', type=strtobool, default='false',
47 | help='Debug mode')
48 | args = parser.parse_args()
49 | return args
50 |
51 |
52 | if __name__ == '__main__':
53 | args = parse()
54 | out_dir = Path(args.out)
55 | debug = args.debug
56 |
57 | flags_dict = load_yaml(out_dir / 'flags.yaml')
58 | cfg = load_yaml(out_dir / 'cfg.yaml')
59 | # flags = DotDict(flags_dict)
60 | flags = Flags()
61 | flags.update(flags_dict)
62 | print(f"flags: {flags_dict}")
63 |
64 | # set env variable for data
65 | # Do not use flags.l5kit_data_folder; use the fixed competition data folder.
66 | l5kit_data_folder = "../../input/lyft-motion-prediction-autonomous-vehicles"
67 | os.environ["L5KIT_DATA_FOLDER"] = l5kit_data_folder
68 | dm = LocalDataManager(None)
69 |
70 | print("Load dataset...")
71 | default_test_cfg = {
72 | 'key': 'scenes/test.zarr',
73 | 'batch_size': 32,
74 | 'shuffle': False,
75 | 'num_workers': 4
76 | }
77 | test_cfg = cfg.get("test_data_loader", default_test_cfg)
78 |
79 | # from copy import deepcopy
80 | # cfg2 = deepcopy(cfg)
81 | # cfg2["model_params"]["history_num_frames"] = 50
82 | cfg["model_params"]["history_num_frames"] = 100
83 | cfg["raster_params"]["map_type"] = "stub_debug" # For faster calculation...
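# history_num_frames is raised to 100 so the full chopped-history availability
# mask is returned, and the stub rasterizer is used because this script only
# consumes "history_availabilities"; no semantic map rendering is required.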
84 |
85 | # Rasterizer
86 | rasterizer = build_custom_rasterizer(cfg, dm)
87 |
88 | valid_cfg = cfg["valid_data_loader"]
89 | # valid_path = "scenes/sample.zarr" if debug else valid_cfg["key"]
90 | valid_path = valid_cfg["key"]
91 | valid_agents_mask = None
92 | if flags.validation_chopped:
93 | num_frames_to_chop = 100
94 | th_agent_prob = cfg["raster_params"]["filter_agents_threshold"]
95 | min_frame_future = 1
96 | num_frames_to_copy = num_frames_to_chop
97 | valid_agents_mask = load_mask_chopped(
98 | dm.require(valid_path), th_agent_prob, num_frames_to_copy, min_frame_future)
99 | print("valid_path", valid_path, "valid_agents_mask", valid_agents_mask.shape)
100 | valid_zarr = ChunkedDataset(dm.require(valid_path)).open(cached=False)
101 | valid_agent_dataset = FasterAgentDataset(
102 | cfg, valid_zarr, rasterizer, agents_mask=valid_agents_mask,
103 | min_frame_history=flags.min_frame_history, min_frame_future=flags.min_frame_future,
104 | override_sample_function_name=flags.override_sample_function_name,
105 | )
106 | # valid_dataset = TransformDataset(valid_agent_dataset, transform)
107 | valid_dataset = valid_agent_dataset
108 |
109 | # Only use `n_valid_data` samples for a fast check.
110 | # Sample from the dataset at regular intervals, to increase variety/coverage.
111 | n_valid_data = 150 if debug else -1
112 | print("n_valid_data", n_valid_data)
113 | if n_valid_data > 0 and n_valid_data < len(valid_dataset):
114 | valid_sub_indices = np.linspace(0, len(valid_dataset)-1, num=n_valid_data, dtype=np.int64)
115 | valid_dataset = Subset(valid_dataset, valid_sub_indices)
116 | valid_batchsize = valid_cfg["batch_size"]
117 |
118 | local_valid_dataset = valid_dataset
119 | collate_fn = None
120 | valid_loader = DataLoader(
121 | local_valid_dataset, valid_batchsize, shuffle=False,
122 | pin_memory=True, num_workers=valid_cfg["num_workers"],
123 | collate_fn=collate_fn)
124 |
125 | print(valid_agent_dataset)
126 | print("# AgentDataset valid:", len(valid_agent_dataset))
127 | print("# ActualDataset valid:", len(valid_dataset))
128 | in_channels, height, width = valid_agent_dataset[0]["image"].shape # get input image shape
129 | print("in_channels", in_channels, "height", height, "width", width)
130 |
131 | # --- Calc valid data ---
132 | his_avail_valid = calc_history_avail(valid_loader)
133 | print("his_avail_valid", his_avail_valid.shape, his_avail_valid)
134 |
135 | # --- Calc test data ---
136 | test_path = test_cfg["key"]
137 | print(f"Loading from {test_path}")
138 | test_zarr = ChunkedDataset(dm.require(test_path)).open(cached=False)
139 | print("test_zarr", type(test_zarr))
140 | test_mask = np.load(f"{l5kit_data_folder}/scenes/mask.npz")["arr_0"]
141 | test_agent_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)
142 | test_dataset = test_agent_dataset
143 | if debug:
144 | # Only use 100 samples for a fast check...
145 | test_dataset = Subset(test_dataset, np.arange(100))
146 | test_loader = DataLoader(
147 | test_dataset,
148 | shuffle=test_cfg["shuffle"],
149 | batch_size=test_cfg["batch_size"],
150 | num_workers=test_cfg["num_workers"],
151 | pin_memory=True,
152 | )
153 | his_avail_test = calc_history_avail(test_loader)
154 | print("his_avail_test", his_avail_test.shape, his_avail_test)
155 |
156 | # --- Save in npz format, for future analysis purposes ---
157 | debug_str = "_debug" if debug else ""
158 | processed_dir = Path("../../input/processed_data")
159 | os.makedirs(str(processed_dir), exist_ok=True)
160 |
161 | npz_path = processed_dir / f"history_avail{debug_str}.npz"
162 | np.savez_compressed(
163 | npz_path,
164 | his_avail_valid=his_avail_valid,
165 | his_avail_test=his_avail_test,
166 | )
167 | print(f"Saved to {npz_path}")
168 | import IPython; IPython.embed()
-------------------------------------------------------------------------------- /src/modeling/calc_num_history.py: --------------------------------------------------------------------------------
1 | """
2 | Calculate the number of available history frames for chopped valid/test data.
3 | """
4 | import argparse
5 | from distutils.util import strtobool
6 | import numpy as np
7 | import torch
8 | from pathlib import Path
9 |
10 | from l5kit.dataset import AgentDataset
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.dataset import Subset
13 |
14 | from l5kit.data import LocalDataManager, ChunkedDataset
15 |
16 | import sys
17 | import os
18 |
19 | from tqdm import tqdm
20 | sys.path.append(os.pardir)
21 | sys.path.append(os.path.join(os.pardir, os.pardir))
22 | from lib.dataset.faster_agent_dataset import FasterAgentDataset
23 | from lib.evaluation.mask import load_mask_chopped
24 | from modeling.load_flag import Flags
25 | from lib.rasterization.rasterizer_builder import build_custom_rasterizer
26 | from lib.utils.yaml_utils import save_yaml, load_yaml
27 |
28 |
29 | def calc_num_history(data_loader) -> np.ndarray:
30 | n_history_availabilities_list = []
31 |
32 | with torch.no_grad():
33 | dataiter = tqdm(data_loader)
34 | for data in dataiter:
35 | n_his_avail = np.sum(data["history_availabilities"].numpy(), axis=1)
36 | n_history_availabilities_list.append(n_his_avail)
37 | n_his_availabilities = np.concatenate(n_history_availabilities_list)
38 | return n_his_availabilities
39 |
40 |
41 | def parse():
42 | parser = argparse.ArgumentParser(description='')
43 | parser.add_argument('--out', '-o', default='results/tmp',
44 | help='Directory to output the result')
45 | parser.add_argument('--debug', '-d', type=strtobool, default='false',
46 | help='Debug mode')
47 | args = parser.parse_args()
48 | return args
49 |
50 |
51 | if __name__ == '__main__':
52 | args = parse()
53 | out_dir = Path(args.out)
54 | debug = args.debug
55 |
56 | flags_dict = load_yaml(out_dir / 'flags.yaml')
57 | cfg = load_yaml(out_dir / 'cfg.yaml')
58 | # flags = DotDict(flags_dict)
59 | flags = Flags()
60 | flags.update(flags_dict)
61 | print(f"flags: {flags_dict}")
62 |
63 | # set env variable for data
64 | # Do not use flags.l5kit_data_folder; use the fixed competition data folder.
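# LocalDataManager(None) resolves relative dataset keys against the
# L5KIT_DATA_FOLDER environment variable set below.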
65 | l5kit_data_folder = "../../input/lyft-motion-prediction-autonomous-vehicles"
66 | os.environ["L5KIT_DATA_FOLDER"] = l5kit_data_folder
67 | dm = LocalDataManager(None)
68 |
69 | print("Load dataset...")
70 | default_test_cfg = {
71 | 'key': 'scenes/test.zarr',
72 | 'batch_size': 32,
73 | 'shuffle': False,
74 | 'num_workers': 4
75 | }
76 | test_cfg = cfg.get("test_data_loader", default_test_cfg)
77 |
78 | # from copy import deepcopy
79 | # cfg2 = deepcopy(cfg)
80 | # cfg2["model_params"]["history_num_frames"] = 50
81 | cfg["model_params"]["history_num_frames"] = 100
82 | cfg["raster_params"]["map_type"] = "stub_debug" # For faster calculation...
83 |
84 | # Rasterizer
85 | rasterizer = build_custom_rasterizer(cfg, dm)
86 |
87 | valid_cfg = cfg["valid_data_loader"]
88 | # valid_path = "scenes/sample.zarr" if debug else valid_cfg["key"]
89 | valid_path = valid_cfg["key"]
90 | valid_agents_mask = None
91 | if flags.validation_chopped:
92 | num_frames_to_chop = 100
93 | th_agent_prob = cfg["raster_params"]["filter_agents_threshold"]
94 | min_frame_future = 1
95 | num_frames_to_copy = num_frames_to_chop
96 | valid_agents_mask = load_mask_chopped(
97 | dm.require(valid_path), th_agent_prob, num_frames_to_copy, min_frame_future)
98 | print("valid_path", valid_path, "valid_agents_mask", valid_agents_mask.shape)
99 | valid_zarr = ChunkedDataset(dm.require(valid_path)).open(cached=False)
100 | valid_agent_dataset = FasterAgentDataset(
101 | cfg, valid_zarr, rasterizer, agents_mask=valid_agents_mask,
102 | min_frame_history=flags.min_frame_history, min_frame_future=flags.min_frame_future,
103 | override_sample_function_name=flags.override_sample_function_name,
104 | )
105 | # valid_dataset = TransformDataset(valid_agent_dataset, transform)
106 | valid_dataset = valid_agent_dataset
107 |
108 | # Only use `n_valid_data` samples for a fast check.
109 | # Sample from the dataset at regular intervals, to increase variety/coverage.
110 | n_valid_data = 150 if debug else -1
111 | print("n_valid_data", n_valid_data)
112 | if n_valid_data > 0 and n_valid_data < len(valid_dataset):
113 | valid_sub_indices = np.linspace(0, len(valid_dataset)-1, num=n_valid_data, dtype=np.int64)
114 | valid_dataset = Subset(valid_dataset, valid_sub_indices)
115 | valid_batchsize = valid_cfg["batch_size"]
116 |
117 | local_valid_dataset = valid_dataset
118 | collate_fn = None
119 | valid_loader = DataLoader(
120 | local_valid_dataset, valid_batchsize, shuffle=False,
121 | pin_memory=True, num_workers=valid_cfg["num_workers"],
122 | collate_fn=collate_fn)
123 |
124 | print(valid_agent_dataset)
125 | print("# AgentDataset valid:", len(valid_agent_dataset))
126 | print("# ActualDataset valid:", len(valid_dataset))
127 | in_channels, height, width = valid_agent_dataset[0]["image"].shape # get input image shape
128 | print("in_channels", in_channels, "height", height, "width", width)
129 |
130 | # --- Calc valid data ---
131 | n_his_avail_valid = calc_num_history(valid_loader)
132 | print("n_his_avail_valid", n_his_avail_valid.shape, n_his_avail_valid)
133 |
134 | # --- Calc test data ---
135 | test_path = test_cfg["key"]
136 | print(f"Loading from {test_path}")
137 | test_zarr = ChunkedDataset(dm.require(test_path)).open(cached=False)
138 | print("test_zarr", type(test_zarr))
139 | test_mask = np.load(f"{l5kit_data_folder}/scenes/mask.npz")["arr_0"]
140 | test_agent_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)
141 | test_dataset = test_agent_dataset
142 | if debug:
143 | # Only use 100 samples for a fast check...
144 | test_dataset = Subset(test_dataset, np.arange(100)) 145 | test_loader = DataLoader( 146 | test_dataset, 147 | shuffle=test_cfg["shuffle"], 148 | batch_size=test_cfg["batch_size"], 149 | num_workers=test_cfg["num_workers"], 150 | pin_memory=True, 151 | ) 152 | n_his_avail_test = calc_num_history(test_loader) 153 | print("n_his_avail_test", n_his_avail_test.shape, n_his_avail_test) 154 | 155 | # --- Save to npz format, for future analysis purpose --- 156 | debug_str = "_debug" if debug else "" 157 | processed_dir = Path("../../input/processed_data") 158 | os.makedirs(str(processed_dir), exist_ok=True) 159 | 160 | npz_path = processed_dir / f"n_history{debug_str}.npz" 161 | np.savez_compressed( 162 | npz_path, 163 | n_his_avail_valid=n_his_avail_valid, 164 | n_his_avail_test=n_his_avail_test, 165 | ) 166 | print(f"Saved to {npz_path}") 167 | import IPython; IPython.embed() 168 | -------------------------------------------------------------------------------- /src/ensemble/ensemble_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from dataclasses import dataclass 5 | from typing import Sequence 6 | from contextlib import contextmanager 7 | import time 8 | from pathlib import Path 9 | from typing import Dict, Mapping, Any, Optional 10 | # import warnings 11 | # warnings.simplefilter('ignore') 12 | 13 | import yaml 14 | import numpy as np 15 | import pandas as pd 16 | from tqdm import tqdm 17 | import torch 18 | from l5kit.evaluation import write_pred_csv 19 | 20 | sys.path.append("./src") 21 | sys.path.append(os.pardir) 22 | sys.path.append(os.path.join(os.pardir, os.pardir)) 23 | from lib.mixture.batch_spherical_gmm import BatchSphericalGMM 24 | from lib.functions.nll import pytorch_neg_multi_log_likelihood_batch 25 | 26 | 27 | NUM_TEST = 71122 28 | SEED = 0 29 | np.random.seed(SEED) 30 | 31 | 32 | @dataclass 33 | class EnsembleFlags: 34 | output_path: str 35 | weight: Sequence[float] 36 | sigma: float 37 | N_sample: int 38 | file_list: Sequence[str] 39 | batch_size: int 40 | device: str 41 | gmm_kwargs: Mapping[str, Any] 42 | give_centroids_init: bool 43 | precisions_init_sigma: Optional[float] 44 | 45 | 46 | @contextmanager 47 | def add_time(name: str, store: Dict[str, float]): 48 | t = time.time() 49 | yield 50 | if name not in store: 51 | store[name] = 0.0 52 | store[name] += time.time() - t 53 | 54 | 55 | def load_predictions(file_list, weights): 56 | coords_list, confs_list = [], [] 57 | for path, w in zip(file_list, weights): 58 | path = Path(path) 59 | 60 | if path.name.endswith(".npz"): 61 | npz = np.load(str(path)) 62 | confs = npz["confs"] 63 | coords = npz["coords"] 64 | n_example, n_modes = confs.shape 65 | coords = coords.reshape(n_example, n_modes, 50, 2) 66 | confs = w * confs 67 | elif path.name.endswith(".csv"): 68 | df = pd.read_csv(path) 69 | columns = [f"coord_{xy}{mode}{step}" for mode in range(3) for step in range(50) for xy in ["x", "y"]] 70 | coords = df.loc[:, columns].to_numpy().reshape(-1, 3, 50, 2) 71 | confs = w * df.loc[:, ["conf_0", "conf_1", "conf_2"]].to_numpy().reshape(-1, 3) 72 | else: 73 | raise ValueError 74 | 75 | coords_list.append(coords) 76 | confs_list.append(confs) 77 | 78 | coords = np.concatenate(coords_list, axis=1) 79 | confs = np.concatenate(confs_list, axis=1) 80 | confs = confs / confs.sum(axis=1, keepdims=True) 81 | return coords, confs 82 | 83 | 84 | def load_metadata(file_list): 85 | path = file_list[0] 86 | 87 | if 
str(path).endswith(".npz"): 88 | npz = np.load(path) 89 | timestamps = npz["timestamps"] 90 | track_ids = npz["track_ids"] 91 | targets = npz.get("targets") 92 | target_availabilities = npz.get("target_availabilities") 93 | return timestamps, track_ids, targets, target_availabilities 94 | else: 95 | df = pd.read_csv(path) 96 | return df["timestamp"], df["track_id"], None, None 97 | 98 | 99 | def ensemble_batch_core(flags: EnsembleFlags): 100 | assert len(flags.weight) == len(flags.file_list) 101 | 102 | coords_all, confs_all = load_predictions(flags.file_list, flags.weight) 103 | 104 | num_modes_total = confs_all.shape[1] 105 | n_example = len(coords_all) 106 | 107 | assert coords_all.shape == (n_example, num_modes_total, 50, 2) 108 | 109 | np_random = np.random.RandomState(SEED) 110 | 111 | ens_confs = np.zeros((n_example, 3)) 112 | ens_coords = np.zeros((n_example, 3, 50, 2)) 113 | ens_log_probs = np.zeros((n_example,)) 114 | time_store = {} 115 | 116 | noise = np_random.normal(0, scale=flags.sigma, size=(flags.batch_size, flags.N_sample, 100)) 117 | 118 | for idx in tqdm(range(0, n_example, flags.batch_size)): 119 | confidences = confs_all[idx:idx + flags.batch_size] 120 | confidences = confidences / confidences.sum(axis=1, keepdims=True) 121 | size = confidences.shape[0] 122 | assert confidences.shape == (size, num_modes_total) 123 | 124 | coords = coords_all[idx:idx + flags.batch_size] 125 | coords = coords.reshape(size * num_modes_total, 50 * 2) 126 | 127 | # TODO: remove for-loop 128 | with add_time("choice", time_store): 129 | indices = np.stack([np_random.choice(num_modes_total, size=flags.N_sample, p=confidences[j]) for j in range(size)], axis=0) 130 | indices = (np.arange(size)[:, np.newaxis] * num_modes_total + indices).reshape(-1) 131 | 132 | X = coords[indices] 133 | X = X.reshape(size, flags.N_sample, 100) 134 | X = X + noise[:len(X)] 135 | 136 | if flags.give_centroids_init: 137 | flags.gmm_kwargs["centroids_init"] = coords.reshape(size, num_modes_total, 100)[:, :3, :] 138 | 139 | if flags.precisions_init_sigma is not None: 140 | flags.gmm_kwargs["precisions_init"] = np.full((size, 3), flags.precisions_init_sigma) 141 | 142 | with add_time("fit", time_store): 143 | gmm = BatchSphericalGMM(n_components=3, device=flags.device, seed=SEED, **flags.gmm_kwargs) 144 | weights, means, _, log_probs = gmm.fit(X) 145 | 146 | ens_confs[idx:idx + flags.batch_size] = weights 147 | ens_coords[idx:idx + flags.batch_size] = means.reshape(means.shape[0], 3, 50, 2) 148 | ens_log_probs[idx:idx + flags.batch_size] = log_probs 149 | 150 | print(time_store) 151 | 152 | output_path = Path(flags.output_path) 153 | output_path.parent.mkdir(exist_ok=True, parents=True) 154 | 155 | if output_path.name.endswith(".csv"): 156 | save_format = "csv" 157 | else: 158 | save_format = "npz" 159 | 160 | timestamps, track_ids, targets, target_availabilities = load_metadata(flags.file_list) 161 | 162 | if targets is not None: 163 | assert target_availabilities is not None 164 | errors = pytorch_neg_multi_log_likelihood_batch( 165 | torch.as_tensor(targets), 166 | torch.as_tensor(ens_coords), 167 | torch.as_tensor(ens_confs), 168 | torch.as_tensor(target_availabilities), 169 | ) 170 | print("errors", errors.shape, torch.mean(errors)) 171 | 172 | if save_format == "csv": 173 | write_pred_csv(str(output_path), timestamps, track_ids, ens_coords, ens_confs) 174 | else: 175 | np.savez_compressed( 176 | str(output_path), 177 | timestamps=timestamps, 178 | track_ids=track_ids, 179 | coords=ens_coords, 180 | 
confs=ens_confs, 181 | targets=targets, 182 | target_availabilities=target_availabilities, 183 | log_probs=ens_log_probs, 184 | ) 185 | print(f"Saved to {output_path}") 186 | 187 | 188 | if __name__ == '__main__': 189 | parser = argparse.ArgumentParser(description='') 190 | parser.add_argument('--yaml_filepath', '-y', type=str, 191 | help='Flags yaml file path') 192 | args = parser.parse_args() 193 | 194 | with open(args.yaml_filepath, 'r') as f: 195 | flags = EnsembleFlags(**yaml.safe_load(f)) 196 | 197 | ensemble_batch_core(flags) 198 | -------------------------------------------------------------------------------- /src/lib/data/tuned_map_api.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Iterator, Sequence, Union, no_type_check, Tuple 3 | 4 | import numba as nb 5 | import numpy as np 6 | from numpy import radians, cos, sin 7 | import pymap3d as pm 8 | from l5kit.data import MapAPI 9 | 10 | from l5kit.geometry import transform_points 11 | from l5kit.data.proto.road_network_pb2 import GeoFrame, GlobalId, MapElement, MapFragment 12 | 13 | from lib.utils.numba_utils import transform_points_nb 14 | 15 | CACHE_SIZE = int(1e5) 16 | ENCODING = "utf-8" 17 | 18 | 19 | @nb.jit(nb.types.Tuple((nb.float64, nb.float64, nb.float64))( 20 | nb.float64, nb.float64, nb.int64 21 | ), nopython=True, nogil=True) 22 | def geodetic2ecef(lat, lon, alt): 23 | # Assumes ell model = "wgs84" 24 | semimajor_axis = 6378137.0 25 | semiminor_axis = 6356752.31424518 26 | 27 | # radius of curvature of the prime vertical section 28 | N = semimajor_axis ** 2 / np.sqrt(semimajor_axis ** 2 * cos(lat) ** 2 + semiminor_axis ** 2 * sin(lat) ** 2) 29 | # Compute cartesian (geocentric) coordinates given (curvilinear) geodetic 30 | # coordinates. 
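# This is a nopython/nogil reimplementation of pymap3d's geodetic2ecef for the
# WGS-84 ellipsoid, so the whole delta-unpacking path can stay inside numba
# (see _unpack_deltas_cm_nb below, which replaces the pm.enu2ecef call).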
31 | x = (N + alt) * cos(lat) * cos(lon) 32 | y = (N + alt) * cos(lat) * sin(lon) 33 | z = (N * (semiminor_axis / semimajor_axis) ** 2 + alt) * sin(lat) 34 | return x, y, z 35 | 36 | 37 | @nb.jit(nb.types.Tuple((nb.float64[:], nb.float64[:], nb.float64[:]))( 38 | nb.float64[:], nb.float64[:], nb.float64[:], nb.float64, nb.float64, 39 | ), nopython=True, nogil=True) 40 | def enu2uvw( 41 | east: np.ndarray, north: np.ndarray, up: np.ndarray, lat0: np.ndarray, lon0: np.ndarray 42 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 43 | t = cos(lat0) * up - sin(lat0) * north 44 | w = sin(lat0) * up + cos(lat0) * north 45 | 46 | u = cos(lon0) * t - sin(lon0) * east 47 | v = sin(lon0) * t + cos(lon0) * east 48 | return u, v, w 49 | 50 | 51 | @nb.jit(nb.types.Tuple((nb.float64[:], nb.float64[:], nb.float64[:]))( 52 | nb.float64[:], nb.float64[:], nb.float64[:], nb.float64, nb.float64, nb.int64 53 | ), nopython=True, nogil=True) 54 | def enu2ecef(e1, n1, u1, lat0, lon0, h0): 55 | # Assuming `ell = None, deg = True` 56 | lat0 = radians(lat0) 57 | lon0 = radians(lon0) 58 | 59 | x0, y0, z0 = geodetic2ecef(lat0, lon0, h0) 60 | dx, dy, dz = enu2uvw(e1, n1, u1, lat0, lon0) 61 | 62 | return x0 + dx, y0 + dy, z0 + dz 63 | 64 | 65 | @nb.jit(nb.float64[:, :]( 66 | nb.int64[:], nb.int64[:], nb.int64[:], nb.int64, nb.int64, nb.float64[:, :] 67 | ), nopython=True, nogil=True) 68 | def _unpack_deltas_cm_nb(dx, dy, dz, lat, lng, ecef_to_world): 69 | x = np.cumsum(dx / 100) 70 | y = np.cumsum(dy / 100) 71 | z = np.cumsum(dz / 100) 72 | frame_lat, frame_lng = lat / 1e7, lng / 1e7 73 | # xyz = np.stack(pm.enu2ecef(x, y, z, frame_lat, frame_lng, 0), axis=-1) 74 | xyz = np.stack(enu2ecef(x, y, z, frame_lat, frame_lng, 0), axis=-1) 75 | xyz = transform_points_nb(xyz, ecef_to_world) 76 | return xyz 77 | 78 | 79 | class TunedMapAPI(MapAPI): 80 | def __init__(self, protobuf_map_path: str, world_to_ecef: np.ndarray): 81 | """ 82 | Interface to the raw protobuf map file with the following features: 83 | - access to element using ID is O(1); 84 | - access to coordinates in world ref system for a set of elements is O(1) after first access (lru cache) 85 | - object support iteration using __getitem__ protocol 86 | 87 | Args: 88 | protobuf_map_path (str): path to the protobuf file 89 | world_to_ecef (np.ndarray): transformation matrix from world coordinates to ECEF (dataset dependent) 90 | """ 91 | super(TunedMapAPI, self).__init__(protobuf_map_path, world_to_ecef) 92 | # self.protobuf_map_path = protobuf_map_path 93 | # self.ecef_to_world = np.linalg.inv(world_to_ecef) 94 | # 95 | # with open(protobuf_map_path, "rb") as infile: 96 | # mf = MapFragment() 97 | # mf.ParseFromString(infile.read()) 98 | # 99 | # self.elements = mf.elements 100 | # self.ids_to_el = {self.id_as_str(el.id): idx for idx, el in enumerate(self.elements)} # store a look-up table 101 | 102 | @no_type_check 103 | def unpack_deltas_cm(self, dx: Sequence[int], dy: Sequence[int], dz: Sequence[int], frame: GeoFrame) -> np.ndarray: 104 | """ 105 | Get coords in world reference system (local ENU->ECEF->world). 106 | See the protobuf annotations for additional information about how coordinates are stored 107 | 108 | Args: 109 | dx (Sequence[int]): X displacement in centimeters in local ENU 110 | dy (Sequence[int]): Y displacement in centimeters in local ENU 111 | dz (Sequence[int]): Z displacement in centimeters in local ENU 112 | frame (GeoFrame): geo-location information for the local ENU. 
It contains lat and long origin of the frame 113 | 114 | Returns: 115 | np.ndarray: array of shape (Nx3) with XYZ coordinates in world ref system 116 | 117 | """ 118 | xyz = _unpack_deltas_cm_nb( 119 | np.asarray(dx), np.asarray(dy), np.asarray(dz), 120 | frame.origin.lat_e7, frame.origin.lng_e7, self.ecef_to_world) 121 | return xyz 122 | 123 | @lru_cache(maxsize=CACHE_SIZE) 124 | def get_lane_coords(self, element_id: str) -> dict: 125 | """ 126 | Get XYZ coordinates in world ref system for a lane given its id 127 | lru_cached for O(1) access 128 | 129 | Args: 130 | element_id (str): lane element id 131 | 132 | Returns: 133 | dict: a dict with the two boundaries coordinates as (Nx3) XYZ arrays 134 | """ 135 | element = self[element_id] 136 | assert self.is_lane(element) 137 | 138 | lane = element.element.lane 139 | left_boundary = lane.left_boundary 140 | right_boundary = lane.right_boundary 141 | 142 | xyz_left = self.unpack_deltas_cm( 143 | left_boundary.vertex_deltas_x_cm, 144 | left_boundary.vertex_deltas_y_cm, 145 | left_boundary.vertex_deltas_z_cm, 146 | lane.geo_frame, 147 | ) 148 | xyz_right = self.unpack_deltas_cm( 149 | right_boundary.vertex_deltas_x_cm, 150 | right_boundary.vertex_deltas_y_cm, 151 | right_boundary.vertex_deltas_z_cm, 152 | lane.geo_frame, 153 | ) 154 | 155 | return {"xyz_left": xyz_left, "xyz_right": xyz_right} 156 | 157 | @lru_cache(maxsize=CACHE_SIZE) 158 | def get_crosswalk_coords(self, element_id: str) -> dict: 159 | """ 160 | Get XYZ coordinates in world ref system for a crosswalk given its id 161 | lru_cached for O(1) access 162 | 163 | Args: 164 | element_id (str): crosswalk element id 165 | 166 | Returns: 167 | dict: a dict with the polygon coordinates as an (Nx3) XYZ array 168 | """ 169 | element = self[element_id] 170 | assert self.is_crosswalk(element) 171 | traffic_element = element.element.traffic_control_element 172 | 173 | xyz = self.unpack_deltas_cm( 174 | traffic_element.points_x_deltas_cm, 175 | traffic_element.points_y_deltas_cm, 176 | traffic_element.points_z_deltas_cm, 177 | traffic_element.geo_frame, 178 | ) 179 | 180 | return {"xyz": xyz} 181 | -------------------------------------------------------------------------------- /src/modeling/builder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from typing import Optional 4 | from collections import defaultdict 5 | import re 6 | 7 | import pretrainedmodels 8 | import torch 9 | from torch import nn 10 | import timm 11 | 12 | sys.path.append(os.pardir) 13 | sys.path.append(os.path.join(os.pardir, os.pardir)) 14 | from lib.nn.models.multi.multi_model_predictor import LyftMultiModelPredictor 15 | from lib.nn.models.multi.pretrained_cnn_multi import PretrainedCNNMulti 16 | from lib.nn.models.multi.resnest_multi import ResNeStMulti 17 | from lib.nn.models.multi.lyft_multi_model import LyftMultiModel 18 | from lib.nn.models.multi.efficientnet_multi import EfficientNetMulti 19 | try: 20 | from lib.nn.models.multi.timm_multi import TimmMulti 21 | except Exception as e: 22 | print("[WARNING] TimmMulti import failed!") 23 | from lib.nn.models.deep_ensemble.lyft_multi_deep_ensemble_predictor import LyftMultiDeepEnsemblePredictor 24 | from lib.nn.models.rnn_head_multi.lstm_head_multi_predictor import LSTMHeadMultiPredictor 25 | from lib.nn.models.rnn_head_multi.target_scale_wrapper import TargetScaleWrapper 26 | from modeling.load_flag import Flags 27 | 28 | 29 | def build_rnn_head_multi_predictor( 30 | cfg, flags: Flags, device: 
torch.device, in_channels: int, target_scale: Optional[torch.Tensor] = None
31 | ) -> nn.Module:
32 | num_modes = 3
33 | future_len = cfg["model_params"]["future_num_frames"]
34 | model_name = flags.model_name
35 | print("model_name", model_name, "model_kwargs", flags.model_kwargs)
36 |
37 | if model_name == "LSTMHeadMultiPredictor":
38 | predictor = LSTMHeadMultiPredictor(
39 | num_modes=num_modes, future_len=future_len, in_channels=in_channels, **flags.model_kwargs
40 | )
41 | else:
42 | raise ValueError(f"[ERROR] Unexpected value model_name={model_name}")
43 |
44 | if target_scale is not None:
45 | assert target_scale.shape == (future_len, 2)
46 | predictor = TargetScaleWrapper(predictor, target_scale)
47 |
48 | # --- Forward once to initialize lazy params ---
49 | bs = 2
50 | height, width = cfg["raster_params"]["raster_size"]
51 | in_channels = predictor.in_channels
52 | image = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device)
53 | history_positions = torch.rand((bs, 10, 2), dtype=torch.float32).to(device)
54 | history_availabilities = torch.ones((bs, 10), dtype=torch.float32).to(device)
55 | predictor.to(device)
56 | pred, confidences = predictor(image, history_positions, history_availabilities)
57 | assert pred.shape == (bs, num_modes, future_len, 2)
58 | assert confidences.shape == (bs, num_modes)
59 | # --- Done ---
60 |
61 | return predictor
62 |
63 |
64 | def build_multi_predictor(
65 | cfg,
66 | flags: Flags,
67 | device: torch.device,
68 | in_channels: int,
69 | target_scale: Optional[torch.Tensor] = None,
70 | num_modes: int = 3,
71 | ) -> nn.Module:
72 | model_name = flags.model_name
73 | print("model_name", model_name, "model_kwargs", flags.model_kwargs, "num_modes", num_modes)
74 | if model_name == "resnet18":
75 | print("Building LyftMultiModel")
76 | base_model = LyftMultiModel(cfg, num_modes=num_modes, in_channels=in_channels)
77 | elif "efficientnet" in model_name:
78 | print("Building EfficientNetMulti")
79 | base_model = EfficientNetMulti(
80 | cfg, num_modes=num_modes, model_name=model_name, in_channels=in_channels, **flags.model_kwargs)
81 | elif "resnest" in model_name:
82 | print("Building ResNeStMulti")
83 | base_model = ResNeStMulti(
84 | cfg, num_modes=num_modes, model_name=model_name, in_channels=in_channels, **flags.model_kwargs)
85 | elif model_name in pretrainedmodels.__dict__.keys():
86 | print("Building PretrainedCNNMulti")
87 | base_model = PretrainedCNNMulti(
88 | cfg, num_modes=num_modes, model_name=model_name, in_channels=in_channels, **flags.model_kwargs)
89 | elif model_name in timm.list_models():
90 | print("Building TimmMulti")
91 | base_model = TimmMulti(
92 | cfg, in_channels=in_channels, num_modes=num_modes, backbone=model_name, **flags.model_kwargs
93 | )
94 | else:
95 | raise ValueError(f"[ERROR] Unexpected value model_name={model_name}")
96 | predictor = LyftMultiModelPredictor(base_model, cfg, num_modes=num_modes, target_scale=target_scale)
97 |
98 | # --- Forward once to initialize lazy params ---
99 | bs = 2
100 | height, width = cfg["raster_params"]["raster_size"]
101 | in_channels = predictor.in_channels
102 | x = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device)
103 | predictor.to(device)
104 | if flags.feat_mode == "agent_type":
105 | feat_channels = flags.model_kwargs["feat_channels"]
106 | x_feat = torch.rand((bs, feat_channels), dtype=torch.float32).to(device)
107 | predictor(x, x_feat)
108 | else:
109 | predictor(x)
110 | # --- Done ---
111 |
112 | return predictor
113 |
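# A minimal usage sketch (illustrative only; `cfg` and `flags` are assumed to be
# loaded from their yaml files, as in the training script):
#
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     flags.model_name = "resnet18"  # dispatched to LyftMultiModel above
#     predictor = build_multi_predictor(cfg, flags, device, in_channels=in_channels)
#
# where `in_channels` must match the rasterizer output, e.g. taken from
# `dataset[0]["image"].shape[0]`.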
114 | 115 | def build_multi_mode_deep_ensemble( 116 | cfg, 117 | flags: Flags, 118 | device: torch.device, 119 | in_channels: int, 120 | target_scale: Optional[torch.Tensor] = None, 121 | ) -> nn.Module: 122 | ensemble_name = flags.model_name 123 | model_names = ensemble_name.split("+") 124 | use_D4 = flags.model_kwargs.pop("use_D4", False) 125 | print("use_D4:", use_D4) 126 | 127 | predictors = [] 128 | names = [] 129 | cnts = defaultdict(int) 130 | for model_name in model_names: 131 | match_obj = re.match(r"(.+)_(\d+)modes", model_name) 132 | if match_obj is None: 133 | actual_model_name = model_name 134 | num_modes = 3 135 | else: 136 | actual_model_name = match_obj.group(1) 137 | num_modes = int(match_obj.group(2)) 138 | 139 | flags.model_name = actual_model_name 140 | predictor = build_multi_predictor(cfg, flags, device, in_channels, target_scale, num_modes) 141 | flags.model_name = ensemble_name 142 | 143 | predictors.append(predictor) 144 | names.append(f"{model_name}_{cnts[model_name]}") 145 | cnts[model_name] += 1 146 | 147 | predictor = LyftMultiDeepEnsemblePredictor(predictors, names, use_D4) 148 | flags.model_kwargs["use_D4"] = use_D4 149 | return predictor 150 | 151 | 152 | def build_multi_agent_predictor(cfg, flags: Flags, device: torch.device, in_channels: int) -> nn.Module: 153 | num_modes = 3 154 | model_name = flags.model_name 155 | print("model_name", model_name, "model_kwargs", flags.model_kwargs) 156 | # TODO: now model_name is skipped... 157 | print("Building LyftMultiModel") 158 | if model_name.startswith("smp_"): 159 | from lib.nn.models.multi_agent.smp_multi_agent_model import SMPMultiAgentModel 160 | predictor = SMPMultiAgentModel(cfg, num_modes=num_modes, in_channels=in_channels, **flags.model_kwargs) 161 | else: 162 | raise NotImplementedError(f"model_name {model_name} is not supported for multi_agent model") 163 | 164 | # TODO: 165 | # --- Forward once to initialize lazy params --- 166 | bs = 2 167 | n_agents = 2 168 | height, width = cfg["raster_params"]["raster_size"] 169 | in_channels = predictor.in_channels 170 | x = torch.rand((bs, in_channels, height, width), dtype=torch.float32).to(device) 171 | centroid_pixel = torch.randint(3, 10, (n_agents, 2), dtype=torch.long).to(device) 172 | batch_agents = torch.tensor([0, 0], dtype=torch.long).to(device) 173 | predictor.to(device) 174 | predictor(x, centroid_pixel, batch_agents) 175 | # --- Done --- 176 | return predictor 177 | -------------------------------------------------------------------------------- /src/lib/rasterization/channel_semantic_rasterizer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List, Optional, Tuple, cast 3 | 4 | import cv2 5 | import numpy as np 6 | from l5kit.data import MapAPI, filter_tl_faces_by_status 7 | from l5kit.geometry import rotation33_as_yaw, transform_point, transform_points 8 | from l5kit.rasterization import SemanticRasterizer, RenderContext 9 | from l5kit.rasterization.rasterizer_builder import _load_metadata, get_hardcoded_world_to_ecef 10 | from l5kit.rasterization.semantic_rasterizer import CV2_SHIFT, cv2_subpixel, elements_within_bounds 11 | 12 | 13 | class ChannelSemanticRasterizer(SemanticRasterizer): 14 | @staticmethod 15 | def from_cfg(cfg, data_manager): 16 | raster_cfg = cfg["raster_params"] 17 | # map_type = raster_cfg["map_type"] 18 | dataset_meta_key = raster_cfg["dataset_meta_key"] 19 | 20 | render_context = RenderContext( 21 | 
raster_size_px=np.array(raster_cfg["raster_size"]), 22 | pixel_size_m=np.array(raster_cfg["pixel_size"]), 23 | center_in_raster_ratio=np.array(raster_cfg["ego_center"]), 24 | ) 25 | # filter_agents_threshold = raster_cfg["filter_agents_threshold"] 26 | # history_num_frames = cfg["model_params"]["history_num_frames"] 27 | 28 | semantic_map_filepath = data_manager.require(raster_cfg["semantic_map_key"]) 29 | try: 30 | dataset_meta = _load_metadata(dataset_meta_key, data_manager) 31 | world_to_ecef = np.array(dataset_meta["world_to_ecef"], dtype=np.float64) 32 | except (KeyError, FileNotFoundError): # TODO remove when new dataset version is available 33 | world_to_ecef = get_hardcoded_world_to_ecef() 34 | 35 | return ChannelSemanticRasterizer(render_context, semantic_map_filepath, world_to_ecef) 36 | 37 | def __init__( 38 | self, 39 | render_context: RenderContext, semantic_map_path: str, world_to_ecef: np.ndarray, 40 | ): 41 | super(ChannelSemanticRasterizer, self).__init__(render_context, semantic_map_path, world_to_ecef) 42 | self.raster_channels = 6 43 | 44 | def rasterize( 45 | self, 46 | history_frames: np.ndarray, 47 | history_agents: List[np.ndarray], 48 | history_tl_faces: List[np.ndarray], 49 | agent: Optional[np.ndarray] = None, 50 | ) -> np.ndarray: 51 | if agent is None: 52 | ego_translation_m = history_frames[0]["ego_translation"] 53 | ego_yaw_rad = rotation33_as_yaw(history_frames[0]["ego_rotation"]) 54 | else: 55 | ego_translation_m = np.append(agent["centroid"], history_frames[0]["ego_translation"][-1]) 56 | ego_yaw_rad = agent["yaw"] 57 | 58 | raster_from_world = self.render_context.raster_from_world(ego_translation_m, ego_yaw_rad) 59 | world_from_raster = np.linalg.inv(raster_from_world) 60 | 61 | # get XY of center pixel in world coordinates 62 | center_in_raster_px = np.asarray(self.raster_size) * (0.5, 0.5) 63 | center_in_world_m = transform_point(center_in_raster_px, world_from_raster) 64 | 65 | sem_im = self.render_semantic_map(center_in_world_m, raster_from_world, history_tl_faces[0]) 66 | return sem_im.astype(np.float32) / 255 67 | 68 | def render_semantic_map( 69 | self, center_world: np.ndarray, raster_from_world: np.ndarray, tl_faces: np.ndarray 70 | ) -> np.ndarray: 71 | """Renders the semantic map at given x,y coordinates. 
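Unlike the base `SemanticRasterizer`, which draws a single RGB image, each element
type gets its own uint8 channel: lane area (channel 0), lane lines colored by
traffic-light state (channels 1-4: default, green, yellow, red), and crosswalks
(channel 5).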
72 |
73 | Args:
74 | center_world (np.ndarray): XY of the image center in world ref system
75 | raster_from_world (np.ndarray): transformation matrix from world to raster coordinates
76 | tl_faces (np.ndarray): traffic light faces for the current frame
77 |
78 | Returns:
79 | np.ndarray: semantic raster of shape (height, width, 6), one channel per element type
80 |
81 | """
82 |
83 | # img = 255 * np.ones(shape=(self.raster_size[1], self.raster_size[0], 3), dtype=np.uint8)
84 | img = np.zeros(shape=(self.raster_channels, self.raster_size[1], self.raster_size[0]), dtype=np.uint8)
85 |
86 | # filter using half a radius from the center
87 | raster_radius = float(np.linalg.norm(self.raster_size * self.pixel_size)) / 2
88 |
89 | # get active traffic light faces
90 | active_tl_ids = set(filter_tl_faces_by_status(tl_faces, "ACTIVE")["face_id"].tolist())
91 |
92 | # plot lanes
93 | lanes_lines = defaultdict(list)
94 |
95 | for idx in elements_within_bounds(center_world, self.bounds_info["lanes"]["bounds"], raster_radius):
96 | lane = self.proto_API[self.bounds_info["lanes"]["ids"][idx]].element.lane
97 |
98 | # get image coords
99 | lane_coords = self.proto_API.get_lane_coords(self.bounds_info["lanes"]["ids"][idx])
100 | xy_left = cv2_subpixel(transform_points(lane_coords["xyz_left"][:, :2], raster_from_world))
101 | xy_right = cv2_subpixel(transform_points(lane_coords["xyz_right"][:, :2], raster_from_world))
102 | lanes_area = np.vstack((xy_left, np.flip(xy_right, 0))) # start->end left then end->start right
103 |
104 | # --- lanes ---
105 | # Note(lberg): this called on all polygons skips some of them, don't know why
106 | # cv2.fillPoly(img, [lanes_area], (17, 17, 31), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
107 | cv2.fillPoly(img[0], [lanes_area], 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
108 |
109 | lane_type = "default" # no traffic light face is controlling this lane
110 | lane_tl_ids = set([MapAPI.id_as_str(la_tc) for la_tc in lane.traffic_controls])
111 | for tl_id in lane_tl_ids.intersection(active_tl_ids):
112 | if self.proto_API.is_traffic_face_colour(tl_id, "red"):
113 | lane_type = "red"
114 | elif self.proto_API.is_traffic_face_colour(tl_id, "green"):
115 | lane_type = "green"
116 | elif self.proto_API.is_traffic_face_colour(tl_id, "yellow"):
117 | lane_type = "yellow"
118 |
119 | lanes_lines[lane_type].extend([xy_left, xy_right])
120 |
121 | # --- Traffic lights ---
122 | # cv2.polylines(img, lanes_lines["default"], False, (255, 217, 82), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
123 | # cv2.polylines(img, lanes_lines["green"], False, (0, 255, 0), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
124 | # cv2.polylines(img, lanes_lines["yellow"], False, (255, 255, 0), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
125 | # cv2.polylines(img, lanes_lines["red"], False, (255, 0, 0), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
126 | cv2.polylines(img[1], lanes_lines["default"], False, 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
127 | cv2.polylines(img[2], lanes_lines["green"], False, 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
128 | cv2.polylines(img[3], lanes_lines["yellow"], False, 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
129 | cv2.polylines(img[4], lanes_lines["red"], False, 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
130 |
131 | # plot crosswalks
132 | crosswalks = []
133 | for idx in elements_within_bounds(center_world, self.bounds_info["crosswalks"]["bounds"], raster_radius):
134 | crosswalk = self.proto_API.get_crosswalk_coords(self.bounds_info["crosswalks"]["ids"][idx])
135 |
136 | xy_cross = cv2_subpixel(transform_points(crosswalk["xyz"][:, :2], raster_from_world))
137 | crosswalks.append(xy_cross)
138 |
139 | # --- Cross Walks ---
140 | # cv2.polylines(img, crosswalks, True, (255, 117, 69), lineType=cv2.LINE_AA, shift=CV2_SHIFT)
141 | cv2.polylines(img[5], crosswalks, True, 255, lineType=cv2.LINE_AA, shift=CV2_SHIFT)
142 | # ch, h, w --> h, w, ch
143 | img = img.transpose((1, 2, 0))
144 | return img
145 |
146 | def to_rgb(self, in_im: np.ndarray, **kwargs: dict) -> np.ndarray:
147 | # return (in_im * 255).astype(np.uint8)
148 | raise NotImplementedError
149 | --------------------------------------------------------------------------------