├── pyrightconfig.json
├── .gitignore
├── losses
│   ├── diff_BCE.py
│   ├── L2.py
│   └── laplace_nll_loss.py
├── README.md
├── metrics
│   ├── fde_t.py
│   ├── ade_t.py
│   └── mr_t.py
├── test.py
├── models
│   ├── decoders
│   │   ├── dec_hivt_nusargo_grid.py
│   │   └── dec_hivt_nusargo_sde.py
│   ├── utils
│   │   ├── embedding.py
│   │   ├── dec_utils.py
│   │   ├── sde_utils.py
│   │   ├── ode_utils.py
│   │   └── util.py
│   ├── aggregators
│   │   └── agg_hivt.py
│   ├── model_base_mix_sde.py
│   ├── model_base_mix.py
│   └── encoders
│       ├── enc_hivt_nusargo_grid.py
│       └── enc_hivt_nusargo_sde_sep2.py
├── train.py
├── dataset
│   ├── Datamodule_nuargo_mix.py
│   ├── nuScenes_Argoverse
│   │   └── nuScenes_Argoverse.py
│   └── Argoverse
│       └── Argoverse_abs.py
├── configs
│   └── nusargo
│       ├── hivt_nuSArgo_trmenc_mlpdec.yml
│       └── hivt_nuSArgo_sdesepenc_sdedec.yml
└── env.yml
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 |     "exclude": ["preprocessed", "checkpoints", "data"]
3 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *tmp*
2 | *.pyc
3 | data/**
4 | preprocessed/**
5 | checkpoints*/**
6 | .vscode
7 | ray_results*
8 | **2304**
9 | dataset/Waymo/idcs
--------------------------------------------------------------------------------
/losses/diff_BCE.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class DiffBCE(nn.Module):
7 |     def __init__(self, reduction: str='mean') -> None:
8 |         super(DiffBCE, self).__init__()
9 |         self.bce = nn.BCELoss(reduction=reduction)
10 | 
11 |     def forward(self, data, output):
12 | 
13 |         diff_in, diff_out, label_in, label_out = output['diff_in'], output['diff_out'], output['label_in'], output['label_out']
14 |         loss_in = self.bce(diff_in, label_in)
15 |         loss_out = self.bce(diff_out, label_out)
16 | 
17 |         return loss_in + loss_out
--------------------------------------------------------------------------------
/losses/L2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class L2(nn.Module):
6 |     def __init__(self, reduction: str='mean') -> None:
7 |         super(L2, self).__init__()
8 |         self.reduction = reduction
9 | 
10 |     def forward(self, data, output):
11 |         target = data['y']
12 |         loc, scale = output['loc'].chunk(2, dim=-1)
13 |         reg_mask = output['reg_mask']
14 | 
15 |         l2 = torch.norm(target.unsqueeze(0) - loc, p=2, dim=-1)
16 | 
17 |         ade = l2.clone()
18 |         ade[:,~reg_mask] = 0
19 |         made_idcs = torch.argmin(ade.mean(-1), dim=0)
20 |         minl2 = l2[made_idcs, torch.arange(l2.size(1))]
21 | 
22 |         if reg_mask.sum() > 0:
23 |             if self.reduction == 'mean':
24 |                 return minl2[reg_mask].mean()
25 |             else:
26 |                 raise ValueError(f'{self.reduction} is not a valid value for reduction')
27 |         else:
28 |             return 0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## This is a code implementation of "Improving Transferability for Cross-domain Trajectory Prediction via Neural Stochastic Differential Equation", AAAI'24
2 | Please follow the steps below to run our code.
3 | 
4 | ## 1. Create a virtual environment in Anaconda with env.yml
5 | 
6 | ```
7 | conda env create --file env.yml -n trajsde
8 | conda activate trajsde
9 | ```
10 | 
11 | ## 2. Prepare the raw nuScenes and Argoverse datasets
12 | Download the metadata of the nuScenes trainval set from "https://www.nuscenes.org/nuscenes#download".
13 | 
14 | Download the Training/Validation/Testing motion forecasting datasets from "https://www.argoverse.org/av1.html#download-link".
15 | 
16 | Place them in the 'data' directory as follows:
17 | ```
18 | .
19 | ├── configs
20 | ├── ...
21 | ├── data
22 | │   ├── nuScenes
23 | │   │   ├── maps
24 | │   │   ├── samples
25 | │   │   ├── ...
26 | │   │   └── v1.0-trainval
27 | │   └── argodataset
28 | │       ├── map_files
29 | │       ├── train
30 | │       └── val
31 | └── train.py
32 | ```
33 | 
34 | ## 3. Run preprocessing for nuScenes and Argoverse, respectively
35 | ```
36 | mkdir preprocessed
37 | # Argoverse
38 | python dataset/Argoverse/Argoverse_abs.py
39 | # nuScenes
40 | python dataset/nuScenes/nuScenes_hivt.py
41 | ```
42 | 
43 | The preprocessed data files are then saved in 'preprocessed/Argoverse' for Argoverse and in 'preprocessed/nuScenes' for nuScenes.
44 | 
45 | ## 4. Make a checkpoints dir and run the training code
46 | ```
47 | mkdir checkpoints
48 | # Vanilla HiVT
49 | python train.py -n baseline -c configs/nusargo/hivt_nuSArgo_trmenc_mlpdec.yml
50 | # Ours
51 | python train.py -n nsde -c configs/nusargo/hivt_nuSArgo_sdesepenc_sdedec.yml
52 | ```
--------------------------------------------------------------------------------
/losses/laplace_nll_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import torch.nn as nn
16 | 
17 | 
18 | class LaplaceNLLLoss(nn.Module):
19 | 
20 |     def __init__(self,
21 |                  eps: float = 1e-6,
22 |                  reduction: str = 'mean') -> None:
23 |         super(LaplaceNLLLoss, self).__init__()
24 |         self.eps = float(eps)
25 |         self.reduction = reduction
26 | 
27 |     def forward(self,
28 |                 data, output) -> torch.Tensor:
29 |         target = data['y']
30 |         loc, scale = output['loc'].chunk(2, dim=-1)
31 |         reg_mask = output['reg_mask']
32 | 
33 |         diff = torch.norm(target.unsqueeze(0) - loc, dim=-1)
34 |         diff_ = diff.clone()
35 |         diff_[:,~reg_mask] = 0
36 |         best_mode = torch.argmin(diff_.mean(-1), dim=0)
37 |         node_nums = best_mode.size(0)
38 | 
39 |         loc, scale = loc[best_mode, torch.arange(node_nums)], scale[best_mode, torch.arange(node_nums)]
40 |         scale = scale.clone()
41 |         with torch.no_grad():
42 |             scale.clamp_(min=self.eps)
43 |         nll = torch.log(2 * scale) + torch.abs(target - loc) / scale
44 |         if self.reduction == 'mean':
45 |             return nll[reg_mask].mean()
46 |         else:
47 |             raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
48 | 
--------------------------------------------------------------------------------
/metrics/fde_t.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Callable, Optional
15 | 
16 | import torch
17 | from torchmetrics import Metric
18 | 
19 | 
20 | class FDE_T(Metric):
21 | 
22 |     def __init__(self,
23 |                  dataset,
24 |                  end_idcs,
25 |                  sources=[0,1],
26 |                  compute_on_step: bool = True,
27 |                  dist_sync_on_step: bool = False,
28 |                  process_group: Optional[Any] = None,
29 |                  dist_sync_fn: Callable = None,
30 |                  **kwargs) -> None:
31 |         super(FDE_T, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
32 |                                     process_group=process_group, dist_sync_fn=dist_sync_fn)
33 |         self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
34 |         self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
35 |         self.dataset = dataset
36 |         self.end_idcs = end_idcs
37 |         self.target_sources = sources
38 | 
39 |     def update(self,
40 |                pred: torch.Tensor,
41 |                target: torch.Tensor,
42 |                reg_mask: torch.Tensor,
43 |                source) -> None:
44 | 
45 |         K, NA, TS, _ = pred.shape
46 | 
47 |         count_0, count_1 = (source==self.target_sources[0]).sum(), (source==self.target_sources[1]).sum()
48 |         end_idcs_ = torch.repeat_interleave(torch.tensor(self.end_idcs), torch.tensor([count_0, count_1]))
49 | 
50 |         l2 = torch.norm(pred[:,torch.arange(NA),end_idcs_,:] - target[torch.arange(NA),end_idcs_].unsqueeze(0), p=2, dim=-1)
51 |         reg_mask_any = reg_mask[torch.arange(NA),end_idcs_]
52 | 
53 |         l2 = l2[:, reg_mask_any]
54 |         best_idx = torch.argmin(l2, dim=0)
55 |         fde_best = l2[best_idx, torch.arange(reg_mask_any.sum())]
56 |         self.sum += fde_best.sum().item()
57 |         self.count += reg_mask_any.sum().item()
58 | 
59 |     def compute(self) -> torch.Tensor:
60 |         return self.sum / self.count
61 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import yaml
4 | from importlib.machinery import SourceFileLoader
5 | from argparse import ArgumentParser
6 | 
7 | import pytorch_lightning as pl
8 | 
9 | import torch.multiprocessing
10 | torch.multiprocessing.set_sharing_strategy('file_system')
11 | 
12 | if __name__ == '__main__':
13 |     pl.seed_everything(0)
14 | 
15 |     parser = ArgumentParser()
16 |     parser.add_argument('-c', '--config', type=str, required=True)
17 |     parser.add_argument('--viz', action="store_true", default=False)
18 |     parser.add_argument('--viz_goalpred', action="store_true", default=False)
19 |     parser.add_argument('--submit', action="store_true", default=False)
20 |     parser.add_argument('--ood', action="store_true", default=False)
21 |     parser.add_argument('--viz_ood', action="store_true", default=False)
22 | 
23 |     parser.add_argument('-s', '--save_dir', type=str, default='checkpoints')
24 |     parser.add_argument('--root', type=str, default=None)
25 |     parser.add_argument('--ckpt_dir', type=str, default='checkpoints')
26 |     parser.add_argument('--ckpt', type=str, default=None)
27 |     parser.add_argument('--save_top_k', type=int, default=-1)
28 |     parser.add_argument('--gpus', type=int, default=1)
29 | 
30 |     parser.add_argument('--pin_memory', type=bool, default=True)
31 |     parser.add_argument('--persistent_workers', type=bool, default=True)
32 |     parser.add_argument('--num_workers', type=int, default=8)
33 | 
34 |     parser.add_argument('--shuffle', type=bool, default=True)
35 |     parser.add_argument('--max_epochs', type=int, default=100)
36 |     parser.add_argument('--monitor', type=str, default='val_minFDE', choices=['val_minADE', 'val_minFDE', 'val_minMR'])
37 | 
38 |     args = parser.parse_args()
39 | 
40 |     with open(args.config, 'r') as yaml_file:
41 |         cfg = yaml.safe_load(yaml_file)
42 |     cfg['model_specific']['kwargs']['viz'] = args.viz
43 |     cfg['model_specific']['kwargs']['viz_goalpred'] = args.viz_goalpred
44 |     cfg['model_specific']['kwargs']['submit'] = args.submit
45 |     cfg['model_specific']['kwargs']['ood'] = args.ood
46 |     cfg['model_specific']['kwargs']['viz_ood'] = args.viz_ood
47 | 
48 |     model = getattr(SourceFileLoader(cfg['model_specific']['module_name'], cfg['model_specific']['file_path']).load_module(cfg['model_specific']['module_name']), cfg['model_specific']['module_name'])
49 |     model = model(**dict(cfg))
50 | 
51 |     trainer = pl.Trainer.from_argparse_args(args, limit_test_batches=1.)
52 |     trainer.logger = False
53 | 
54 |     dmodulecfg = cfg['datamodule_specific']
55 |     datamodule = getattr(SourceFileLoader(dmodulecfg['module_name'], dmodulecfg['file_path']).load_module(dmodulecfg['module_name']), dmodulecfg['module_name'])
56 |     datamodule = datamodule(**dict(dmodulecfg['kwargs']))
57 | 
58 |     trainer.test(model, dataloaders=datamodule, ckpt_path=args.ckpt)
59 | 
--------------------------------------------------------------------------------
/models/decoders/dec_hivt_nusargo_grid.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from models.utils.util import init_weights
8 | 
9 | 
10 | class MLPDecoder(nn.Module):
11 | 
12 |     def __init__(self,
13 |                  **kwargs) -> None:
14 |         super(MLPDecoder, self).__init__()
15 | 
16 |         for key, value in kwargs.items():
17 |             setattr(self, key, value)
18 | 
19 |         self.input_size = self.global_channels
20 |         self.hidden_size = self.local_channels
21 | 
22 |         self.aggr_embed = nn.Sequential(
23 |             nn.Linear(self.input_size + self.hidden_size, self.hidden_size),
24 |             nn.LayerNorm(self.hidden_size),
25 |             nn.ReLU(inplace=True))
26 |         self.loc = nn.Sequential(
27 |             nn.Linear(self.hidden_size, self.hidden_size),
28 |             nn.LayerNorm(self.hidden_size),
29 |             nn.ReLU(inplace=True),
30 |             nn.Linear(self.hidden_size, self.future_steps * 2))
31 |         if self.uncertain:
32 |             self.scale = nn.Sequential(
33 |                 nn.Linear(self.hidden_size, self.hidden_size),
34 |                 nn.LayerNorm(self.hidden_size),
35 |                 nn.ReLU(inplace=True),
36 |                 nn.Linear(self.hidden_size, self.future_steps * 2))
37 |         self.pi = nn.Sequential(
38 |             nn.Linear(self.hidden_size + self.input_size, self.hidden_size),
39 |             nn.LayerNorm(self.hidden_size),
40 |             nn.ReLU(inplace=True),
41 |             nn.Linear(self.hidden_size, self.hidden_size),
42 |             nn.LayerNorm(self.hidden_size),
43 |             nn.ReLU(inplace=True),
44 |             nn.Linear(self.hidden_size, 1))
45 |         self.apply(init_weights)
46 | 
47 |     def forward(self,
48 |                 data,
49 |                 local_embed: torch.Tensor,
50 |                 global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
51 |         pi = self.pi(torch.cat((local_embed.expand(self.num_modes, *local_embed.shape),
52 |                                 global_embed), dim=-1)).squeeze(-1).t()
53 |         out = self.aggr_embed(torch.cat((global_embed, local_embed.expand(self.num_modes, *local_embed.shape)), dim=-1))
54 |         loc = self.loc(out).view(self.num_modes, -1, self.future_steps, 2)  # [F, N, H, 2]
55 |         if self.uncertain:
56 |             scale = F.elu_(self.scale(out), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2) + 1.0
57 |             scale = scale + self.min_scale  # [F, N, H, 2]
58 |             out = {'loc': torch.cat((loc, scale), dim=-1), 'pi': pi, 'local_embed': local_embed, 'global_embed': global_embed}  # [F, N, H, 4], [N, F]
59 | 
60 |         else:
61 |             out = {'loc': loc, 'pi': pi}  # [F, N, H, 2], [N, F]
62 | 
63 |         out['reg_mask'] = ~data['padding_mask'][:,-self.future_steps:]
64 |         return out
--------------------------------------------------------------------------------
/metrics/ade_t.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Callable, Optional
15 | 
16 | import torch
17 | from torchmetrics import Metric
18 | 
19 | 
20 | class ADE_T(Metric):
21 | 
22 |     def __init__(self,
23 |                  dataset,
24 |                  end_idcs,
25 |                  sources=[0,1],
26 |                  compute_on_step: bool = True,
27 |                  dist_sync_on_step: bool = False,
28 |                  process_group: Optional[Any] = None,
29 |                  dist_sync_fn: Callable = None,
30 |                  **kwargs) -> None:
31 |         super(ADE_T, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
32 |                                     process_group=process_group, dist_sync_fn=dist_sync_fn)
33 |         self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
34 |         self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
35 |         self.dataset = dataset
36 |         self.end_idcs = end_idcs
37 |         self.target_sources = sources
38 | 
39 |     def update(self,
40 |                pred: torch.Tensor,
41 |                target: torch.Tensor,
42 |                reg_mask: torch.Tensor,
43 |                source) -> None:
44 |         l2 = torch.norm(pred - target.unsqueeze(0), p=2, dim=-1)
45 |         reg_mask_any = reg_mask.any(-1)
46 |         l2, reg_mask = l2[:,reg_mask_any], reg_mask[reg_mask_any]
47 |         l2[:, ~reg_mask] = 0
48 |         source = source[reg_mask_any]
49 | 
50 |         ade = l2.sum(-1) / reg_mask.sum(-1).unsqueeze(0)
51 | 
52 |         if self.dataset == 'nuScenes':
53 | 
54 |             best_idx = torch.argmin(ade, dim=0)
55 |         elif self.dataset == 'Argoverse':
56 |             count_0, count_1 = (source==self.target_sources[0]).sum(), (source==self.target_sources[1]).sum()
57 |             end_idcs_ = torch.repeat_interleave(torch.tensor(self.end_idcs), torch.tensor([count_0, count_1]))
58 | 
59 |             fde = l2[:,torch.arange(l2.size(1)),end_idcs_]
60 |             best_idx = torch.argmin(fde, dim=0)
61 |         else:
62 |             raise NotImplementedError('other dataset is not implemented')
63 | 
64 |         ade_best = ade[best_idx, torch.arange(reg_mask_any.sum())]
65 |         self.sum += ade_best.sum().item()
66 |         self.count += reg_mask_any.sum().item()
67 | 
68 |     def compute(self) -> torch.Tensor:
69 |         return self.sum / self.count
70 | 
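71 | if __name__ == '__main__':
72 |     # Hedged usage sketch, not part of the original file: a quick smoke test of the
73 |     # update()/compute() contract, mirroring the __main__ demo style used elsewhere
74 |     # in this repo. Shapes follow the code above: pred [K, N, T, 2], target [N, T, 2],
75 |     # reg_mask [N, T]. The end_idcs/sources values copy this repo's configs, but the
76 |     # random tensors below are purely illustrative data.
77 |     metric = ADE_T(dataset='Argoverse', end_idcs=[59, 29], sources=[0, 1])
78 |     pred = torch.randn(10, 4, 60, 2)                    # K=10 modes, N=4 agents, T=60 steps
79 |     target = torch.randn(4, 60, 2)
80 |     reg_mask = torch.ones(4, 60, dtype=torch.bool)      # all future steps valid
81 |     source = torch.tensor([0, 0, 1, 1])                 # two agents per source dataset
82 |     metric.update(pred, target, reg_mask, source)
83 |     print(metric.compute())
84 | 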
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import importlib
3 | from importlib.machinery import SourceFileLoader
4 | from argparse import ArgumentParser
5 | import os
6 | 
7 | import pytorch_lightning as pl
8 | from pytorch_lightning.callbacks import ModelCheckpoint
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | 
11 | from debug_util import save_modules
12 | 
13 | import torch.multiprocessing
14 | torch.multiprocessing.set_sharing_strategy('file_system')
15 | 
16 | import warnings
17 | warnings.filterwarnings(action='ignore')
18 | 
19 | if __name__ == '__main__':
20 |     pl.seed_everything(0)
21 | 
22 |     parser = ArgumentParser()
23 |     parser.add_argument('-c', '--config', type=str, required=True)
24 |     parser.add_argument('-n', '--name', type=str, required=True)
25 |     parser.add_argument('-d', '--description', type=str, help='description of the experiment', default='')
26 |     parser.add_argument('-s', '--save_dir', type=str, default='checkpoints/nuSArgo')
27 | 
28 |     parser.add_argument('--ckpt', type=str, default=None)
29 |     parser.add_argument('--wonly', action='store_true', default=False)
30 | 
31 |     parser.add_argument('--root', type=str, default=None)
32 |     parser.add_argument('--save_top_k', type=int, default=-1)
33 |     parser.add_argument('--viz', action='store_true', default=False)
34 |     parser.add_argument('--viz_goalpred', action='store_true', default=False)
35 |     parser.add_argument('--gpus', type=int, default=1)
36 | 
37 |     parser.add_argument('--max_epochs', type=int, default=100)
38 |     parser.add_argument('--monitor', type=str, default='val/ADE_T')
39 | 
40 |     args = parser.parse_args()
41 | 
42 |     with open(args.config, 'r') as yaml_file:
43 |         cfg = yaml.safe_load(yaml_file)
44 |     cfg['model_specific']['kwargs']['viz'] = args.viz
45 |     cfg['model_specific']['kwargs']['viz_goalpred'] = args.viz_goalpred
46 |     cfg['description'] = args.description
47 |     args.max_epochs = cfg['training_specific']['max_epochs']
48 | 
49 |     model = getattr(SourceFileLoader(cfg['model_specific']['module_name'], cfg['model_specific']['file_path']).load_module(cfg['model_specific']['module_name']), cfg['model_specific']['module_name'])
50 |     model = model(**dict(cfg))
51 | 
52 |     model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min')
53 |     logger = TensorBoardLogger(save_dir=args.save_dir, name=args.name)
54 |     trainer = pl.Trainer.from_argparse_args(args, callbacks=[model_checkpoint], logger=logger, num_sanity_val_steps = 0)
55 | 
56 |     dmodulecfg = cfg['datamodule_specific']
57 |     datamodule = getattr(SourceFileLoader(dmodulecfg['module_name'], dmodulecfg['file_path']).load_module(dmodulecfg['module_name']), dmodulecfg['module_name'])
58 |     datamodule = datamodule(**dict(dmodulecfg['kwargs']))
59 | 
60 |     save_modules(logger.log_dir, args.config, cfg)
61 | 
62 |     if args.wonly:
63 |         model = model.load_from_checkpoint(checkpoint_path=args.ckpt)
64 |         trainer.fit(model, datamodule)
65 |     else:
66 |         trainer.fit(model, datamodule, ckpt_path=args.ckpt)
67 | 
--------------------------------------------------------------------------------
/dataset/Datamodule_nuargo_mix.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 | import importlib
3 | from importlib.machinery import SourceFileLoader
4 | 
5 | from pytorch_lightning import LightningDataModule
6 | from torch_geometric.data import DataLoader
7 | 
8 | 
9 | SPLIT_NAME = {'nuScenes': {'train': 'train', 'val': 'train_val', 'test': 'val', 'mini_train': 'mini_train', 'mini_val': 'mini_val'},
10 |               'Argoverse': {'train': 'train', 'val': 'val', 'test': 'test_obs', 'sample': 'forecasting_sample'}}
11 | 
12 | 
13 | 
14 | class DataModuleNuArgoMix(LightningDataModule):
15 | 
16 |     def __init__(self,
17 |                  dataset_file_path,
18 |                  dataset_module_name,
19 |                  **kwargs) -> None:
20 |         super(DataModuleNuArgoMix, self).__init__()
21 |         for k,v in kwargs.items():
22 |             self.__setattr__(k, v)
23 | 
24 |         self.dataset_module = getattr(SourceFileLoader(dataset_module_name, dataset_file_path).load_module(dataset_module_name), dataset_module_name)
25 | 
26 |     def setup(self, stage: Optional[str] = None) -> None:
27 | 
28 |         self.train_dataset = self.dataset_module('train', self.nu_root, self.Argo_root, self.nu_dir, self.Argo_dir, spec_args=self.tr_dataset_args)
29 |         self.val_dataset = self.dataset_module('val', self.nu_root, self.Argo_root, self.nu_dir, self.Argo_dir, spec_args=self.val_dataset_args)
30 |         # self.test_dataset = self.dataset_module('test', self.nu_root, self.Argo2_root, self.nu_dir, self.Argo2_dir, spec_args=self.test_dataset_args)
31 |         self.test_dataset = self.dataset_module('val', self.nu_root, self.Argo_root, self.nu_dir, self.Argo_dir, spec_args=self.test_dataset_args)
32 | 
33 |     def train_dataloader(self):
34 |         return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
35 |                           num_workers=self.num_workers, pin_memory=self.pin_memory,
36 |                           persistent_workers=self.persistent_workers)
37 | 
38 |     def val_dataloader(self):
39 |         return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers,
40 |                           pin_memory=self.pin_memory, persistent_workers=self.persistent_workers)
41 | 
42 |     def test_dataloader(self):
43 |         return DataLoader(self.test_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers,
44 |                           pin_memory=self.pin_memory, persistent_workers=self.persistent_workers)
45 | 
46 | if __name__ == '__main__':
47 |     import yaml
48 |     import importlib
49 | 
50 |     with open('/home/user/ssd4tb/frm_lightning/configs/hivt_LanesegGoal.yml', 'r') as yaml_file:
51 |         cfg = yaml.safe_load(yaml_file)
52 | 
53 |     dataconfig = cfg['datamodule_specific']
54 |     datamodule = getattr(importlib.import_module(dataconfig['file_path']), dataconfig['module_name'])
55 |     datamodule = datamodule(**dict(dataconfig['kwargs']))
56 |     datamodule.setup()
57 | 
58 |     for idx, batch in enumerate(datamodule.train_dataloader()):
59 |         print(batch)
60 |         if idx>10:
61 |             break
62 | 
--------------------------------------------------------------------------------
/metrics/mr_t.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Callable, Optional
15 | 
16 | import torch
17 | from torchmetrics import Metric
18 | 
19 | 
20 | class MR_T(Metric):
21 | 
22 |     def __init__(self,
23 |                  dataset,
24 |                  end_idcs,
25 |                  sources=[0,1],
26 |                  miss_threshold: float = 2.0,
27 |                  compute_on_step: bool = True,
28 |                  dist_sync_on_step: bool = False,
29 |                  process_group: Optional[Any] = None,
30 |                  dist_sync_fn: Callable = None,
31 |                  **kwargs) -> None:
32 |         super(MR_T, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
33 |                                    process_group=process_group, dist_sync_fn=dist_sync_fn)
34 |         self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
35 |         self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
36 |         self.miss_threshold = miss_threshold
37 |         self.dataset = dataset
38 |         self.end_idcs = end_idcs
39 |         self.target_sources = sources
40 | 
41 |     def update(self,
42 |                pred: torch.Tensor,
43 |                target: torch.Tensor,
44 |                reg_mask: torch.Tensor,
45 |                source) -> None:
46 | 
47 |         if self.dataset == 'nuScenes':
48 |             l2 = torch.norm(pred - target.unsqueeze(0), p=2, dim=-1)
49 |             reg_mask_any = reg_mask.any(-1)
50 |             l2, reg_mask = l2[:,reg_mask_any], reg_mask[reg_mask_any]
51 |             l2[:, ~reg_mask] = 0
52 |             maxmin_l2 = l2.max(-1)[0].min(0)[0]
53 | 
54 |             missed = maxmin_l2 > self.miss_threshold
55 |         elif self.dataset == 'Argoverse':
56 |             K, NA, TS, _ = pred.shape
57 | 
58 |             count_0, count_1 = (source==self.target_sources[0]).sum(), (source==self.target_sources[1]).sum()
59 |             end_idcs_ = torch.repeat_interleave(torch.tensor(self.end_idcs), torch.tensor([count_0, count_1]))
60 | 
61 |             l2 = torch.norm(pred[:,torch.arange(NA),end_idcs_,:] - target[torch.arange(NA),end_idcs_].unsqueeze(0), p=2, dim=-1)
62 |             reg_mask_any = reg_mask[torch.arange(NA),end_idcs_]
63 | 
64 |             l2 = l2[:, reg_mask_any]
65 |             best_idx = torch.argmin(l2, dim=0)
66 |             fde_best = l2[best_idx, torch.arange(reg_mask_any.sum())]
67 |             missed = fde_best > self.miss_threshold
68 | 
69 |         self.sum += missed.sum().item()
70 |         self.count += reg_mask_any.sum().item()
71 | 
72 |     def compute(self) -> torch.Tensor:
73 |         return self.sum / self.count
74 | 
--------------------------------------------------------------------------------
/configs/nusargo/hivt_nuSArgo_trmenc_mlpdec.yml:
--------------------------------------------------------------------------------
1 | training_specific:
2 |   lr: 0.0005
3 |   weight_decay: 0.0001
4 |   T_max: 64
5 |   hivt_optimizer: true
6 |   nodecay: true
7 |   max_epochs: 64
8 | 
9 | model_specific:
10 |   file_path: models/model_base_mix.py
11 |   module_name: PredictionModel
12 |   kwargs:
13 |     dataset: &dataset nuScenes
14 |     ref_time: &ref_time 20
15 |     historical_steps: &historical_steps 21
16 |     future_steps: &future_steps 60
17 |     num_modes: &num_modes 10
18 |     rotate: &rotate true
19 |     parallel: &parallel true
20 |     only_agent: false
21 |     is_gtabs: &is_gtabs true
22 |     ts_drop: false
23 | 
24 |     encoder:
25 |       file_path: models/encoders/enc_hivt_nusargo_grid.py
26 |       module_name: LocalEncoder
27 |       kwargs:
28 |         historical_steps: *historical_steps
29 |         node_dim: &node_dim 2
30 |         edge_dim: &edge_dim 2
31 |         embed_dim: &embed_dim 64
32 |         num_heads: &num_heads 4
33 |         dropout: &dropout 0.1
34 |         num_temporal_layers: &num_temporal_layers 4
35 |         local_radius: &local_radius 50
36 |         parallel: *parallel
37 |         input_diff: &input_diff true
38 | 
39 | 
40 |     aggregator:
41 |       file_path: models/aggregators/agg_hivt.py
42 |       module_name: GlobalInteractor
43 |       kwargs:
44 |         historical_steps: *historical_steps
45 |         embed_dim: *embed_dim
46 |         edge_dim: *edge_dim
47 |         num_modes: *num_modes
48 |         num_heads: *num_heads
49 |         num_layers: &num_global_layers 3
50 |         dropout: *dropout
51 |         rotate: *rotate
52 | 
53 |     decoder:
54 |       file_path: models/decoders/dec_hivt_nusargo_grid.py
55 |       module_name: MLPDecoder
56 |       kwargs:
57 |         local_channels: *embed_dim
58 |         global_channels: *embed_dim
59 |         future_steps: *future_steps
60 |         num_modes: *num_modes
61 |         uncertain: True
62 |         min_scale: 0.001
63 | 
64 |     losses: ['losses/L2.py']
65 |     losses_module: ['L2']
66 |     loss_weights: [1]
67 |     loss_args:
68 |       - reduction: mean
69 | 
70 |     metrics: ['metrics/ade_t.py', 'metrics/fde_t.py', 'metrics/mr_t.py']
71 |     metrics_module: ['ADE_T', 'FDE_T', 'MR_T']
72 |     metric_args:
73 |       - dataset: *dataset
74 |         end_idcs: [59, 29]
75 |         sources: [0,1]
76 |       - dataset: *dataset
77 |         end_idcs: [59, 29]
78 |         sources: [0,1]
79 |       - dataset: *dataset
80 |         end_idcs: [59, 29]
81 |         sources: [0,1]
82 | 
83 | datamodule_specific:
84 |   file_path: dataset/Datamodule_nuargo_mix.py
85 |   module_name: DataModuleNuArgoMix
86 |   kwargs:
87 |     nu_root: data/nuScenes
88 |     Argo_root: data/argodataset
89 |     nu_dir: preprocessed/nuScenes_hivt
90 |     Argo_dir: preprocessed/Argoverse
91 |     train_batch_size: 512
92 |     val_batch_size: 512
93 |     num_workers: 32
94 |     pin_memory: true
95 |     persistent_workers: true
96 |     dataset_file_path: dataset/nuScenes_Argoverse/nuScenes_Argoverse.py
97 |     dataset_module_name: nuArgoDataset
98 |     shuffle: true
99 |     tr_dataset_args:
100 |       type: grid
101 |       nus: true
102 |       Argo: true
103 |       ref_time: *ref_time
104 |       random_flip: true
105 |       random_rotate: false
106 |       is_gtabs: *is_gtabs
107 |     val_dataset_args:
108 |       type: grid
109 |       nus: true
110 |       Argo: false
111 |       ref_time: *ref_time
112 |       random_flip: false
113 |       random_rotate: false
114 |       is_gtabs: *is_gtabs
115 |     test_dataset_args:
116 |       type: grid
117 |       nus: true
118 |       Argo: false
119 |       ref_time: *ref_time
120 |       random_flip: false
121 |       random_rotate: false
122 |       is_gtabs: *is_gtabs
123 | 
124 | 
--------------------------------------------------------------------------------
/configs/nusargo/hivt_nuSArgo_sdesepenc_sdedec.yml:
--------------------------------------------------------------------------------
1 | training_specific:
2 |   lr: 0.001
3 |   weight_decay: 0.0007
4 |   T_max: 100
5 |   hivt_optimizer: true
6 |   nodecay: false
7 |   max_epochs: 100
8 | 
9 | model_specific:
10 |   file_path: models/model_base_mix_sde.py
11 |   module_name: PredictionModelSDENet
12 |   kwargs:
13 |     dataset: &dataset nuScenes
14 |     ref_time: &ref_time 20
15 |     historical_steps: &historical_steps 21
16 |     future_steps: &future_steps 60
17 |     num_modes: &num_modes 10
18 |     rotate: &rotate true
19 |     parallel: &parallel true
20 |     only_agent: false
21 |     is_gtabs: &is_gtabs true
22 |     # (duplicate 'is_gtabs' key removed; the anchored value above already sets it)
23 | 
24 |     encoder:
25 |       file_path: models/encoders/enc_hivt_nusargo_sde_sep2.py
26 |       module_name: LocalEncoderSDESepPara2
27 |       kwargs:
28 |         max_past_t: 2
29 |         historical_steps: *historical_steps
30 |         node_dim: &node_dim 2
31 |         edge_dim: &edge_dim 2
32 |         embed_dim: &embed_dim 64
33 |         num_heads: &num_heads 8
34 |         dropout: &dropout 0.1
35 |         local_radius: &local_radius 50
36 |         parallel: *parallel
37 |         input_diff: &input_diff true
38 |         minimum_step: &step_size 0.1
39 |         ref_time: *ref_time
40 |         run_backwards: true
41 |         adjoint: false
42 |         rtol: &rtol 0.001
43 |         atol: &atol 0.001
44 |         method: euler
45 |         adaptive: false
46 |         sde_layers: 2
47 | 
48 |     aggregator:
49 |       file_path: models/aggregators/agg_hivt.py
50 |       module_name: GlobalInteractor
51 |       kwargs:
52 |         historical_steps: *historical_steps
53 |         embed_dim: *embed_dim
54 |         edge_dim: *edge_dim
55 |         num_modes: *num_modes
56 |         num_heads: *num_heads
57 |         num_layers: &num_global_layers 3
58 |         dropout: *dropout
59 |         rotate: *rotate
60 | 
61 |     decoder:
62 |       file_path: models/decoders/dec_hivt_nusargo_sde.py
63 |       module_name: SDEDecoder
64 |       kwargs:
65 |         local_channels: *embed_dim
66 |         global_channels: *embed_dim
67 |         future_steps: *future_steps
68 |         num_modes: *num_modes
69 |         max_fut_t: 6
70 |         ode_func_layers: 3
71 |         uncertain: True
72 |         min_scale: 0.001
73 |         rtol: *rtol
74 |         atol: *atol
75 |         min_stepsize: *step_size
76 |         method: euler
77 | 
78 |     losses: ['losses/L2.py', 'losses/diff_BCE.py']
79 |     losses_module: ['L2', 'DiffBCE']
80 |     loss_weights: [1, 1]
81 |     loss_args:
82 |       - reduction: mean
83 |       - reduction: mean
84 | 
85 |     metrics: ['metrics/ade_t.py', 'metrics/fde_t.py', 'metrics/mr_t.py']
86 |     metrics_module: ['ADE_T', 'FDE_T', 'MR_T']
87 |     metric_args:
88 |       - dataset: *dataset
89 |         end_idcs: [59, 29]
90 |         sources: [0,1]
91 |       - dataset: *dataset
92 |         end_idcs: [59, 29]
93 |         sources: [0,1]
94 |       - dataset: *dataset
95 |         end_idcs: [59, 29]
96 |         sources: [0,1]
97 | 
98 | datamodule_specific:
99 |   file_path: dataset/Datamodule_nuargo_mix.py
100 |   module_name: DataModuleNuArgoMix
101 |   kwargs:
102 |     nu_root: data/nuScenes
103 |     Argo_root: data/argodataset
104 |     nu_dir: preprocessed/nuScenes_hivt
105 |     Argo_dir: preprocessed/Argoverse
106 |     train_batch_size: 128
107 |     val_batch_size: 128
108 |     num_workers: 32
109 |     dataset_file_path: dataset/nuScenes_Argoverse/nuScenes_Argoverse.py
110 |     dataset_module_name: nuArgoDataset
111 |     shuffle: true
112 |     pin_memory: true
113 |     persistent_workers: true
114 |     tr_dataset_args:
115 |       type: grid
116 |       nus: true
117 |       Argo: true
118 |       ref_time: *ref_time
119 |       random_flip: true
120 |       random_rotate: false
121 |       is_gtabs: *is_gtabs
122 |     val_dataset_args:
123 |       type: grid
124 |       nus: true
125 |       Argo: false
126 |       ref_time: *ref_time
127 |       random_flip: false
128 |       random_rotate: false
129 |       is_gtabs: *is_gtabs
130 |     test_dataset_args:
131 |       type: grid
132 |       nus: true
133 |       Argo: false
134 |       ref_time: *ref_time
135 |       random_flip: false
136 |       random_rotate: false
137 |       is_gtabs: *is_gtabs
138 | 
139 | 
--------------------------------------------------------------------------------
/models/utils/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import List, Optional
15 | 
16 | import torch
17 | import torch.nn as nn
18 | 
19 | from models.utils.util import init_weights
20 | 
21 | 
22 | class SingleInputEmbedding(nn.Module):
23 | 
24 |     def __init__(self,
25 |                  in_channel: int,
26 |                  out_channel: int) -> None:
27 |         super(SingleInputEmbedding, self).__init__()
28 |         self.embed = nn.Sequential(
29 |             nn.Linear(in_channel, out_channel),
30 |             nn.LayerNorm(out_channel),
31 |             nn.ReLU(inplace=True),
32 |             nn.Linear(out_channel, out_channel),
33 |             nn.LayerNorm(out_channel),
34 |             nn.ReLU(inplace=True),
35 |             nn.Linear(out_channel, out_channel),
36 |             nn.LayerNorm(out_channel))
37 |         self.apply(init_weights)
38 | 
39 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
40 |         return self.embed(x)
41 | 
42 | 
43 | class MultipleInputEmbedding(nn.Module):
44 | 
45 |     def __init__(self,
46 |                  in_channels: List[int],
47 |                  out_channel: int) -> None:
48 |         super(MultipleInputEmbedding, self).__init__()
49 |         self.module_list = nn.ModuleList(
50 |             [nn.Sequential(nn.Linear(in_channel, out_channel),
51 |                            nn.LayerNorm(out_channel),
52 |                            nn.ReLU(inplace=True),
53 |                            nn.Linear(out_channel, out_channel))
54 |              for in_channel in in_channels])
55 |         self.aggr_embed = nn.Sequential(
56 |             nn.LayerNorm(out_channel),
57 |             nn.ReLU(inplace=True),
58 |             nn.Linear(out_channel, out_channel),
59 |             nn.LayerNorm(out_channel))
60 |         self.apply(init_weights)
61 | 
62 |     def forward(self,
63 |                 continuous_inputs: List[torch.Tensor],
64 |                 categorical_inputs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
65 |         for i in range(len(self.module_list)):
66 |             continuous_inputs[i] = self.module_list[i](continuous_inputs[i])
67 |         output = torch.stack(continuous_inputs).sum(dim=0)
68 |         if categorical_inputs is not None:
69 |             output += torch.stack(categorical_inputs).sum(dim=0)
70 |         return self.aggr_embed(output)
71 | 
72 | class MultipleInputConcatEmbedding(nn.Module):
73 | 
74 |     def __init__(self,
75 |                  in_channels: List[int],
76 |                  out_channel: int) -> None:
77 |         super(MultipleInputConcatEmbedding, self).__init__()
78 |         self.module_list = nn.ModuleList(
79 |             [nn.Sequential(nn.Linear(in_channel, out_channel//2),
80 |                            nn.LayerNorm(out_channel//2),
81 |                            nn.ReLU(inplace=True),
82 |                            nn.Linear(out_channel//2, out_channel//2))
83 |              for in_channel in in_channels])
84 |         self.aggr_embed = nn.Sequential(
85 |             nn.LayerNorm(out_channel),
86 |             nn.ReLU(inplace=True),
87 |             nn.Linear(out_channel, out_channel),
88 |             nn.LayerNorm(out_channel))
89 |         self.apply(init_weights)
90 | 
91 |     def forward(self,
92 |                 continuous_inputs: List[torch.Tensor],
93 |                 categorical_inputs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
94 |         for i in range(len(self.module_list)):
95 |             continuous_inputs[i] = self.module_list[i](continuous_inputs[i])
96 |         output = torch.hstack(continuous_inputs)
97 |         if categorical_inputs is not None:
98 |             output += torch.stack(categorical_inputs).sum(dim=0)
99 |         return self.aggr_embed(output)
--------------------------------------------------------------------------------
/models/utils/dec_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # from datasets.interface import SingleAgentDataset
3 | import numpy as np
4 | from sklearn.cluster import KMeans
5 | import psutil
6 | import ray
7 | from scipy.spatial.distance import cdist
8 | 
9 | # # Initialize ray:
10 | # num_cpus = psutil.cpu_count(logical=False)
11 | # ray.init(num_cpus=num_cpus, log_to_driver=False)
12 | 
13 | 
14 | @ray.remote
15 | def cluster_and_rank(k: int, data: np.ndarray):
16 |     """
17 |     Combines the clustering and ranking steps so that ray.remote gets called just once
18 |     """
19 | 
20 |     def cluster(n_clusters: int, x: np.ndarray):
21 |         """
22 |         Cluster using Scikit learn
23 |         """
24 |         clustering_op = KMeans(n_clusters=n_clusters, n_init=1, max_iter=100, init='random').fit(x)
25 |         return clustering_op.labels_, clustering_op.cluster_centers_
26 | 
27 |     def rank_clusters(cluster_counts, cluster_centers):
28 |         """
29 |         Rank the K clustered trajectories using Ward's criterion. Start with K cluster centers and cluster counts.
30 |         Find the two clusters to merge based on Ward's criterion. Smaller of the two will get assigned rank K.
31 |         Merge the two clusters. Repeat process to assign ranks K-1, K-2, ..., 2.
32 |         """
33 | 
34 |         num_clusters = len(cluster_counts)
35 |         cluster_ids = np.arange(num_clusters)
36 |         ranks = np.ones(num_clusters)
37 | 
38 |         for i in range(num_clusters, 0, -1):
39 |             # Compute Ward distances:
40 |             centroid_dists = cdist(cluster_centers, cluster_centers)
41 |             n1 = cluster_counts.reshape(1, -1).repeat(len(cluster_counts), axis=0)
42 |             n2 = n1.transpose()
43 |             wts = n1 * n2 / (n1 + n2)
44 |             dists = wts * centroid_dists + np.diag(np.inf * np.ones(len(cluster_counts)))
45 | 
46 |             # Get clusters with min Ward distance and select cluster with fewer counts
47 |             c1, c2 = np.unravel_index(dists.argmin(), dists.shape)
48 |             c = c1 if cluster_counts[c1] <= cluster_counts[c2] else c2
49 |             c_ = c2 if cluster_counts[c1] <= cluster_counts[c2] else c1
50 | 
51 |             # Assign rank i to selected cluster
52 |             ranks[cluster_ids[c]] = i
53 | 
54 |             # Merge clusters and update identity of merged cluster
55 |             cluster_centers[c_] = (cluster_counts[c_] * cluster_centers[c_] + cluster_counts[c] * cluster_centers[c]) /\
56 |                                   (cluster_counts[c_] + cluster_counts[c])
57 |             cluster_counts[c_] += cluster_counts[c]
58 | 
59 |             # Discard merged cluster
60 |             cluster_ids = np.delete(cluster_ids, c)
61 |             cluster_centers = np.delete(cluster_centers, c, axis=0)
62 |             cluster_counts = np.delete(cluster_counts, c)
63 | 
64 |         return ranks
65 | 
66 |     cluster_lbls, cluster_ctrs = cluster(k, data)
67 |     cluster_cnts = np.unique(cluster_lbls, return_counts=True)[1]
68 |     cluster_ranks = rank_clusters(cluster_cnts.copy(), cluster_ctrs.copy())
69 |     return {'lbls': cluster_lbls, 'ranks': cluster_ranks, 'counts': cluster_cnts}
70 | 
71 | 
72 | def cluster_traj(k: int, traj: torch.Tensor, interval: int = 3):
73 |     """
74 |     clusters sampled trajectories to output K modes.
75 |     :param k: number of clusters
76 |     :param traj: set of sampled trajectories, shape [batch_size, num_samples, traj_len, 2]
77 |     :return: traj_clustered: set of clustered trajectories, shape [batch_size, k, traj_len, 2]
78 |              scores: scores for clustered trajectories (basically 1/rank), shape [batch_size, k]
79 |     """
80 | 
81 |     # Initialize output tensors
82 |     batch_size = traj.shape[0]
83 |     num_samples = traj.shape[1]
84 |     traj_len = traj.shape[2]
85 | 
86 |     # Down-sample traj along time dimension for faster clustering
87 |     data = traj[:, :, 0::interval, :]
88 |     data = data.reshape(batch_size, num_samples, -1).detach().cpu().numpy()
89 | 
90 |     # Cluster and rank
91 |     cluster_ops = ray.get([cluster_and_rank.remote(k, data_slice) for data_slice in data])
92 |     # cluster_ops = [ray.get(cluster_and_rank.remote(k, data_slice)) for data_slice in data]
93 |     # cluster_ops = [cluster_and_rank(k, data_slice) for data_slice in data]
94 |     cluster_lbls = [cluster_op['lbls'] for cluster_op in cluster_ops]
95 |     cluster_counts = [cluster_op['counts'] for cluster_op in cluster_ops]
96 |     cluster_ranks = [cluster_op['ranks'] for cluster_op in cluster_ops]
97 | 
98 |     # Compute mean (clustered) traj and scores
99 |     lbls = torch.as_tensor(np.array(cluster_lbls), device=traj.device).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, traj_len, 2).long()
100 |     traj_summed = torch.zeros(batch_size, k, traj_len, 2, device=traj.device).scatter_add(1, lbls, traj)
101 |     cnt_tensor = torch.as_tensor(np.array(cluster_counts), device=traj.device).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, traj_len, 2)
102 |     traj_clustered = traj_summed / cnt_tensor
103 |     scores = 1 / torch.as_tensor(np.array(cluster_ranks), device=traj.device)
104 |     scores = scores / torch.sum(scores, dim=1)[0]  # dividing by the first row's sum is safe: each row's ranks are a permutation of 1..k, so every row has the same sum
105 | 
106 |     return traj_clustered, scores
107 | 
--------------------------------------------------------------------------------
/models/aggregators/agg_hivt.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from torch_geometric.nn.conv import MessagePassing
6 | from torch_geometric.typing import Adj
7 | from torch_geometric.typing import OptTensor
8 | from torch_geometric.typing import Size
9 | from torch_geometric.utils import softmax
10 | from torch_geometric.utils import subgraph
11 | 
12 | from models.utils.embedding import MultipleInputEmbedding
13 | from models.utils.embedding import SingleInputEmbedding
14 | from models.utils.util import TemporalData
15 | from models.utils.util import init_weights
16 | 
17 | 
18 | class GlobalInteractor(nn.Module):
19 | 
20 |     def __init__(self,
21 |                  **kwargs) -> None:
22 |         super(GlobalInteractor, self).__init__()
23 | 
24 |         for key, value in kwargs.items():
25 |             setattr(self, key, value)
26 | 
27 |         if self.rotate:
28 |             self.rel_embed = MultipleInputEmbedding(in_channels=[self.edge_dim, self.edge_dim], out_channel=self.embed_dim)
29 |         else:
30 |             self.rel_embed = SingleInputEmbedding(in_channel=self.edge_dim, out_channel=self.embed_dim)
31 |         self.global_interactor_layers = nn.ModuleList(
32 |             [GlobalInteractorLayer(embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout)
33 |              for _ in range(self.num_layers)])
34 |         self.norm = nn.LayerNorm(self.embed_dim)
35 |         self.multihead_proj = nn.Linear(self.embed_dim, self.num_modes * self.embed_dim)
36 |         self.apply(init_weights)
37 | 
38 |     def forward(self,
39 |                 data: TemporalData,
40 |                 local_embed: torch.Tensor) -> torch.Tensor:
41 |         edge_index, _ = subgraph(subset=~data['padding_mask'][:, self.historical_steps - 1], edge_index=data.edge_index)
42 |         rel_pos = data['positions'][edge_index[0], self.historical_steps - 1] - data['positions'][
43 |             edge_index[1], self.historical_steps - 1]
44 |         if data['rotate_mat'] is None:
45 |             rel_embed = self.rel_embed(rel_pos)
46 |         else:
47 |             rel_pos = torch.bmm(rel_pos.unsqueeze(-2), data['rotate_mat'][edge_index[1]]).squeeze(-2)
48 |             rel_theta = data['rotate_angles'][edge_index[0]] - data['rotate_angles'][edge_index[1]]
49 |             rel_theta_cos = torch.cos(rel_theta).unsqueeze(-1)
50 |             rel_theta_sin = torch.sin(rel_theta).unsqueeze(-1)
51 |             rel_embed = self.rel_embed([rel_pos, torch.cat((rel_theta_cos, rel_theta_sin), dim=-1)])
52 |         x = local_embed
53 |         for layer in self.global_interactor_layers:
54 |             x = layer(x, edge_index, rel_embed)
55 |         x = self.norm(x)  # [N, D]
56 |         x = self.multihead_proj(x).view(-1, self.num_modes, self.embed_dim)  # [N, F, D]
57 |         x = x.transpose(0, 1)  # [F, N, D]
58 |         return x
59 | 
60 | 
61 | class GlobalInteractorLayer(MessagePassing):
62 | 
63 |     def __init__(self,
64 |                  embed_dim: int,
65 |                  num_heads: int = 8,
66 |                  dropout: float = 0.1,
67 |                  **kwargs) -> None:
68 |         super(GlobalInteractorLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
69 |         self.embed_dim = embed_dim
70 |         self.num_heads = num_heads
71 | 
72 |         self.lin_q_node = nn.Linear(embed_dim, embed_dim)
73 |         self.lin_k_node = nn.Linear(embed_dim, embed_dim)
74 |         self.lin_k_edge = nn.Linear(embed_dim, embed_dim)
75 |         self.lin_v_node = nn.Linear(embed_dim, embed_dim)
76 |         self.lin_v_edge = nn.Linear(embed_dim, embed_dim)
77 |         self.lin_self = nn.Linear(embed_dim, embed_dim)
78 |         self.attn_drop = nn.Dropout(dropout)
79 |         self.lin_ih = nn.Linear(embed_dim, embed_dim)
80 |         self.lin_hh = nn.Linear(embed_dim, embed_dim)
81 |         self.out_proj = nn.Linear(embed_dim, embed_dim)
82 |         self.proj_drop = nn.Dropout(dropout)
83 |         self.norm1 = nn.LayerNorm(embed_dim)
84 |         self.norm2 = nn.LayerNorm(embed_dim)
85 |         self.mlp = nn.Sequential(
86 |             nn.Linear(embed_dim, embed_dim * 4),
87 |             nn.ReLU(inplace=True),
88 |             nn.Dropout(dropout),
89 |             nn.Linear(embed_dim * 4, embed_dim),
90 |             nn.Dropout(dropout))
91 | 
92 |     def forward(self,
93 |                 x: torch.Tensor,
94 |                 edge_index: Adj,
95 |                 edge_attr: torch.Tensor,
96 |                 size: Size = None) -> torch.Tensor:
97 |         x = x + self._mha_block(self.norm1(x), edge_index, edge_attr, size)
98 |         x = x + self._ff_block(self.norm2(x))
99 |         return x
100 | 
101 |     def message(self,
102 |                 x_i: torch.Tensor,
103 |                 x_j: torch.Tensor,
104 |                 edge_attr: torch.Tensor,
105 |                 index: torch.Tensor,
106 |                 ptr: OptTensor,
107 |                 size_i: Optional[int]) -> torch.Tensor:
108 |         query = self.lin_q_node(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
109 |         key_node = self.lin_k_node(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
110 |         key_edge = self.lin_k_edge(edge_attr).view(-1, self.num_heads, self.embed_dim // self.num_heads)
111 |         value_node = self.lin_v_node(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
112 |         value_edge = self.lin_v_edge(edge_attr).view(-1, self.num_heads, self.embed_dim // self.num_heads)
113 |         scale = (self.embed_dim // self.num_heads) ** 0.5
114 |         alpha = (query * (key_node + key_edge)).sum(dim=-1) / scale
115 |         alpha = softmax(alpha, index, ptr, size_i)
116 |         alpha = self.attn_drop(alpha)
117 |         return (value_node + value_edge) * alpha.unsqueeze(-1)
118 | 
119 |     def update(self,
120 |                inputs: torch.Tensor,
121 |                x: torch.Tensor) -> torch.Tensor:
122 |         inputs = inputs.view(-1, self.embed_dim)
123 |         gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x))
124 |         return inputs + gate * (self.lin_self(x) - inputs)
125 | 
126 |     def _mha_block(self,
127 |                    x: torch.Tensor,
128 |                    edge_index: Adj,
129 |                    edge_attr: torch.Tensor,
130 |                    size: Size) -> torch.Tensor:
131 |         x = self.out_proj(self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, size=size))
132 |         return self.proj_drop(x)
133 | 
134 |     def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
135 |         return self.mlp(x)
--------------------------------------------------------------------------------
/models/utils/sde_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.distributions import Normal, kl_divergence
5 | import pytorch_lightning as pl
6 | import torchsde
7 | from torchsde import sdeint, sdeint_adjoint
8 | 
9 | 
10 | 
11 | class SAdjDiffeqSolver(nn.Module):
12 |     def __init__(self, ode_func, method, dt=0.05, odeint_rtol=1e-4, odeint_atol=1e-5):
13 |         super(SAdjDiffeqSolver, self).__init__()
14 |         self.dt = dt
15 |         self.ode_method = method
16 |         self.ode_func = ode_func
17 | 
18 |         self.odeint_rtol = odeint_rtol
19 |         self.odeint_atol = odeint_atol
20 | 
21 | 
22 |     def forward(self, first_point, time_steps_to_predict):
23 |         """
24 |         Decode the trajectory through the SDE solver (adjoint method).
25 |         """
26 |         n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
27 | 
28 |         pred_y = sdeint_adjoint(self.ode_func, first_point, time_steps_to_predict, dt = self.dt,
29 |                                 rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method)
30 |         pred_y = pred_y.permute(1,2,0)
31 | 
32 |         # assert(torch.mean(pred_y - first_point) < 0.001)
33 |         assert(pred_y.size()[0] == n_traj_samples)
34 |         assert(pred_y.size()[1] == n_traj)
35 | 
36 |         return pred_y
37 | 
38 | class SDiffeqSolver(nn.Module):
39 |     def __init__(self, ode_func, method, dt=0.05, odeint_rtol=1e-4, odeint_atol=1e-5):
40 |         super(SDiffeqSolver, self).__init__()
41 |         self.dt = dt
42 |         self.ode_method = method
43 |         self.ode_func = ode_func
44 | 
45 |         self.odeint_rtol = odeint_rtol
46 |         self.odeint_atol = odeint_atol
47 | 
48 | 
49 |     def forward(self, first_point, time_steps_to_predict):
50 |         """
51 |         Decode the trajectory through the SDE solver.
52 |         """
53 |         n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
54 | 
55 |         pred_y = sdeint(self.ode_func, first_point, time_steps_to_predict, dt = self.dt,
56 |                         rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method)
57 |         pred_y = pred_y.permute(1,2,0)
58 | 
59 |         # assert(torch.mean(pred_y - first_point) < 0.001)
60 |         assert(pred_y.size()[0] == n_traj_samples)
61 |         assert(pred_y.size()[1] == n_traj)
62 | 
63 |         return pred_y
64 | 
65 |     def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict):
66 |         """
67 |         Decode the trajectory through the SDE solver using samples from the prior.
68 |         time_steps_to_predict: time steps at which we want to sample the new trajectory
69 |         """
70 |         func = self.ode_func.sample_next_point_from_prior
71 | 
72 |         pred_y = sdeint(func, starting_point_enc, time_steps_to_predict,
73 |                         rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method)
74 |         pred_y = pred_y.permute(1,2,0,3)
75 |         return pred_y
76 | 
77 | class SDiffeqSolverAug(torchsde.SDEIto):
78 |     def __init__(self, ode_func, method, dt=0.05, odeint_rtol=1e-4, odeint_atol=1e-5):
79 |         super(SDiffeqSolverAug, self).__init__(noise_type="diagonal")
80 |         self.dt = dt
81 |         self.ode_method = method
82 |         self.ode_func = ode_func
83 | 
84 |         self.odeint_rtol = odeint_rtol
85 |         self.odeint_atol = odeint_atol
86 | 
87 | 
88 |     def forward(self, first_point, time_steps_to_predict):
89 |         """
90 |         Decode the trajectory through the SDE solver; the last state channel accumulates the logqp (KL) path integral.
91 |         """
92 |         n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
93 | 
94 |         first_point_aug = torch.cat([first_point, torch.zeros(n_traj_samples, 1).to(first_point)], dim=1)
95 | 
96 |         pred_y_aug = sdeint(self.ode_func, first_point_aug, time_steps_to_predict, dt = self.dt,
97 |                             rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method,
98 |                             names={'drift': 'f_aug', 'diffusion': 'g_aug'})
99 | 
100 |         pred_y = pred_y_aug[:,:,:-1].permute(1,2,0)
101 |         logqp_path = pred_y_aug[-1,:,-1]
102 | 
103 |         # assert(torch.mean(pred_y - first_point) < 0.001)
104 |         assert(pred_y.size()[0] == n_traj_samples)
105 |         assert(pred_y.size()[1] == n_traj)
106 | 
107 |         return pred_y, logqp_path.mean(dim=0)
108 | 
109 | class SDEFunc(nn.Module):
110 |     def __init__(self, f, g, order=1):
111 |         super().__init__()
112 |         self.order, self.intloss, self.sensitivity = order, None, None
113 |         self.f_func, self.g_func = f, g
114 |         self.fnfe, self.gnfe = 0, 0
115 | 
116 |     def forward(self, s, x):
117 |         pass
118 | 
119 |     def f(self, s, x):
120 |         """Posterior drift."""
121 |         self.fnfe += 1
122 |         return self.f_func(x)
123 | 
124 |     def g(self, s, x):
125 |         """Diffusion"""
126 |         self.gnfe += 1
127 |         return self.g_func(x).diag_embed()
128 | 
129 | class LSDEFunc(torchsde.SDEIto):
130 |     def __init__(self, f, g, h, order=1):
131 |         super().__init__(noise_type="diagonal")
132 |         self.order, self.intloss, self.sensitivity = order, None, None
133 |         self.f_func, self.g_func, self.h_func = f, g, h
134 |         self.fnfe, self.gnfe, self.hnfe = 0, 0, 0
135 | 
136 |     def forward(self, s, x):
137 |         pass
138 | 
139 |     def h(self, s, x):
140 |         """ Prior drift
141 |         :param s:
142 |         :param x:
143 |         """
144 |         self.hnfe += 1
145 |         return self.h_func(t=s, y=x)
146 | 
147 |     def f(self, s, x):
148 |         """Posterior drift.
149 |         :param s:
150 |         :param x:
151 |         """
152 |         self.fnfe += 1
153 |         return self.f_func(t=s, y=x)
154 | 
155 |     def g(self, s, x):
156 |         """Diffusion.
157 | :param s: 158 | :param x: 159 | """ 160 | self.gnfe += 1 161 | return self.g_func(t=s, y=x) 162 | 163 | 164 | class LSDEFuncAug(torchsde.SDEIto): 165 | def __init__(self, f, g, h, order=1): 166 | super().__init__(noise_type="diagonal") 167 | self.order, self.intloss, self.sensitivity = order, None, None 168 | self.f_func, self.g_func, self.h_func = f, g, h 169 | self.fnfe, self.gnfe, self.hnfe = 0, 0, 0 170 | 171 | def forward(self, s, x): 172 | pass 173 | 174 | def h(self, s, x): 175 | """ Prior drift 176 | :param s: 177 | :param x: 178 | """ 179 | self.hnfe += 1 180 | return self.h_func(t=s, y=x) 181 | 182 | def f(self, s, x): 183 | """Posterior drift. 184 | :param s: 185 | :param x: 186 | """ 187 | self.fnfe += 1 188 | return self.f_func(t=s, y=x) 189 | 190 | def g(self, s, x): 191 | """Diffusion. 192 | :param s: 193 | :param x: 194 | """ 195 | self.gnfe += 1 196 | return self.g_func(t=s, y=x) 197 | 198 | def f_aug(self, t, y): # Drift for augmented dynamics with logqp term. 199 | y = y[:,:-1] 200 | f, g, h = self.f(t, y), self.g(t, y), self.h(t, y) 201 | u = _stable_division(f - h, g) 202 | f_logqp = .5 * (u ** 2).sum(dim=1, keepdim=True) 203 | return torch.cat([f, f_logqp], dim=1) 204 | 205 | def g_aug(self, t, y): # Diffusion for augmented dynamics with logqp term. 206 | y = y[:,:-1] 207 | g = self.g(t, y) 208 | g_logqp = torch.zeros(y.size(0),1).to(y) 209 | return torch.cat([g, g_logqp], dim=1) 210 | 211 | def _stable_division(a, b, epsilon=1e-7): 212 | b = torch.where(b.abs().detach() > epsilon, b, torch.full_like(b, fill_value=epsilon) * b.sign()) 213 | return a / b -------------------------------------------------------------------------------- /models/decoders/dec_hivt_nusargo_sde.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import torchsde 9 | 10 | from models.utils.util import init_weights 11 | from torchsde import sdeint 12 | 13 | 14 | class SDEDecoder(nn.Module): 15 | 16 | def __init__(self, 17 | **kwargs) -> None: 18 | super(SDEDecoder, self).__init__() 19 | 20 | for key, value in kwargs.items(): 21 | setattr(self, key, value) 22 | 23 | self.input_size = self.global_channels 24 | self.hidden_size = self.local_channels 25 | 26 | self.aggr_embed = nn.Sequential( 27 | nn.Linear(self.input_size + self.hidden_size, self.hidden_size), 28 | nn.LayerNorm(self.hidden_size), 29 | nn.ReLU(inplace=True)) 30 | 31 | # ode_func_netD = create_net(self.hidden_size, self.hidden_size, 32 | # n_layers=self.ode_func_layers, 33 | # n_units=self.hidden_size, 34 | # nonlinear=nn.Tanh) 35 | 36 | # gen_ode_func = ODEFunc(ode_func_net=ode_func_netD) 37 | 38 | # self.diffeq_solver = DiffeqSolver(gen_ode_func, 39 | # 'euler', 40 | # odeint_rtol=self.rtol, 41 | # odeint_atol=self.atol) 42 | 43 | sigma, theta, mu = 0.5, 1.0, 0.0 44 | post_drift = FFunc(self.hidden_size) 45 | prior_drift = HFunc(theta=theta, mu=mu) 46 | diffusion= GFunc(self.hidden_size, sigma=sigma) 47 | self.lsde_func = LSDEFunc(f=post_drift, g=diffusion, h=prior_drift, embed_dim=self.hidden_size) 48 | self.lsde_func.noise_type, self.lsde_func.sde_type = 'diagonal', 'ito' 49 | 50 | self.decoder = nn.Sequential( 51 | nn.Linear(self.hidden_size, self.hidden_size), 52 | nn.LayerNorm(self.hidden_size), 53 | nn.ReLU(inplace=True), 54 | nn.Linear(self.hidden_size, 2)) 55 | 56 | if self.uncertain: 57 | self.scale = nn.Sequential( 58 | 
nn.Linear(self.hidden_size, self.hidden_size),
59 |                 nn.LayerNorm(self.hidden_size),
60 |                 nn.ReLU(inplace=True),
61 |                 nn.Linear(self.hidden_size, 2))
62 | 
63 |         self.pi = nn.Sequential(
64 |             nn.Linear(self.hidden_size + self.input_size, self.hidden_size),
65 |             nn.LayerNorm(self.hidden_size),
66 |             nn.ReLU(inplace=True),
67 |             nn.Linear(self.hidden_size, 1))
68 | 
69 |         self.hidden = nn.Parameter(torch.Tensor(self.hidden_size))
70 |         nn.init.normal_(self.hidden, mean=0., std=.02)
71 | 
72 |         self.ts_pred = torch.linspace(0, self.max_fut_t, self.future_steps+1)
73 |         # self.tstp, self.tstp_mask = self.interp_timesteps(self.ts_to_predict, self.min_stepsize, return_mask=True)
74 | 
75 |         self.apply(init_weights)
76 | 
77 |     def forward(self,
78 |                 data,
79 |                 local_embed: torch.Tensor,
80 |                 global_embed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
81 | 
82 |         loc_emb = self.aggr_embed(torch.cat((global_embed, local_embed.expand(self.num_modes, *local_embed.shape)), dim=-1))
83 | 
84 |         ### Diffeq solver unrolled inline ### (equivalent to: sol_y = self.diffeq_solver(loc_emb, self.ts_pred))
85 |         _, num_actors, _ = loc_emb.shape
86 |         hidden_0 = loc_emb.view(self.num_modes*num_actors, self.hidden_size)
87 | 
88 |         sol_y = sdeint(self.lsde_func, hidden_0, self.ts_pred, dt=self.min_stepsize, dt_min=self.min_stepsize, rtol=self.rtol, atol=self.atol, method=self.method)[1:].permute(1,0,2)  # drop t=0, then [H, F*N, D] -> [F*N, H, D]
89 | 
90 |         ##########################
91 | 
92 | 
93 |         pi = self.pi(torch.cat((local_embed.expand(self.num_modes, *local_embed.shape),
94 |                                 global_embed), dim=-1)).squeeze(-1).t()
95 |         loc = self.decoder(sol_y).view(self.num_modes, num_actors, self.future_steps, 2)  # [F, N, H, 2]
96 |         if self.uncertain:
97 |             scale = F.elu_(self.scale(sol_y), alpha=1.0).view(self.num_modes, -1, self.future_steps, 2) + 1.0
98 |             scale = scale + self.min_scale  # [F, N, H, 2]
99 |             out = {'loc': torch.cat((loc, scale), dim=-1), 'pi': pi,}  # [F, N, H, 4], [N, F]
100 | 
101 |         else:
102 |             out = {'loc': loc, 'pi': pi}  # [F, N, H, 2], [N, F]
103 | 
104 |         out['reg_mask'] = ~data['padding_mask'][:,-self.future_steps:]
105 |         return out
106 | 
107 | class FFunc(nn.Module):
108 |     """Posterior drift."""
109 |     def __init__(self, embed_dim):
110 |         super(FFunc, self).__init__()
111 |         self.net = nn.Sequential(
112 |             nn.Linear(embed_dim+2, embed_dim),
113 |             nn.Tanh(),
114 |             nn.Linear(embed_dim, embed_dim),
115 |             nn.Tanh(),
116 |             nn.Linear(embed_dim, embed_dim)
117 |         )
118 | 
119 |     def forward(self, t, y):
120 |         # if t.dim() == 0:
121 |         #     t = float(t) * torch.ones_like(y)
122 |         # # Positional encoding in transformers; must use `t`, since the posterior is likely inhomogeneous.
123 | # inp = torch.cat((torch.sin(t), torch.cos(t), y), dim=-1) 124 | _t = torch.ones(y.size(0), 1) * float(t) 125 | _t = _t.to(y) 126 | inp = torch.cat((y,torch.sin(_t), torch.cos(_t)), dim=-1) 127 | return self.net(inp) 128 | 129 | 130 | class HFunc(nn.Module): 131 | """Prior drift""" 132 | def __init__(self, theta=1.0, mu=0.0): 133 | super(HFunc, self).__init__() 134 | self.theta = nn.Parameter(torch.tensor([[theta]]), requires_grad=False) 135 | self.mu = nn.Parameter(torch.tensor([[mu]]), requires_grad=False) 136 | 137 | def forward(self, t, y): 138 | return self.theta * (self.mu - y) 139 | 140 | 141 | class GFunc(nn.Module): 142 | """Diffusion""" 143 | def __init__(self, embed_dim, sigma=0.5): 144 | super(GFunc, self).__init__() 145 | # self.sigma = nn.Parameter(torch.tensor([[sigma]]), requires_grad=False) 146 | self.net = nn.Sequential( 147 | nn.Linear(embed_dim+2, embed_dim), 148 | nn.Tanh(), 149 | nn.Linear(embed_dim, embed_dim), 150 | nn.Tanh(), 151 | nn.Linear(embed_dim, 1) 152 | ) 153 | 154 | def forward(self, t, y): 155 | _t = torch.ones(y.size(0), 1) * float(t) 156 | _t = _t.to(y) 157 | out = self.net(torch.cat((y, torch.sin(_t), torch.cos(_t)), dim=-1)) 158 | return torch.sigmoid(out) 159 | 160 | class LSDEFunc(torchsde.SDEIto): 161 | def __init__(self, f, g, h, embed_dim, order=1): 162 | super().__init__(noise_type="diagonal") 163 | self.order, self.intloss, self.sensitivity = order, None, None 164 | self.f_func, self.g_func, self.h_func = f, g, h 165 | self.fnfe, self.gnfe, self.hnfe = 0, 0, 0 166 | self.embed_dim = embed_dim 167 | 168 | 169 | def forward(self, s, x): 170 | pass 171 | 172 | def h(self, s, x): 173 | """ Prior drift 174 | :param s: 175 | :param x: 176 | """ 177 | self.hnfe += 1 178 | return self.h_func(t=s, y=x) 179 | 180 | def f(self, s, x): 181 | """Posterior drift. 182 | :param s: 183 | :param x: 184 | """ 185 | self.fnfe += 1 186 | return self.f_func(t=s, y=x) 187 | 188 | def g(self, s, x): 189 | """Diffusion. 190 | :param s: 191 | :param x: 192 | """ 193 | self.gnfe += 1 194 | out = self.g_func(t=s, y=x).repeat(1,self.embed_dim) 195 | return out -------------------------------------------------------------------------------- /models/utils/ode_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import os 3 | import shutil 4 | import time 5 | import torch 6 | 7 | from torchdiffeq import odeint as odeint 8 | 9 | ##################################################################################################### 10 | class ODE_T_Func(nn.Module): 11 | def __init__(self, ode_func_net): 12 | """ 13 | ode_func_net: neural net that used to transform hidden state in ode 14 | """ 15 | super(ODE_T_Func, self).__init__() 16 | self.gradient_net = ode_func_net 17 | 18 | def forward(self, t_local, y, backwards = False): 19 | """ 20 | Perform one step in solving ODE. 
Given current data point y and 21 | current time point t_local, returns gradient dy/dt at this time point 22 | 23 | t_local: current time point 24 | y: value at the current time point 25 | """ 26 | grad = self.get_ode_gradient_nn(t_local, y) 27 | if backwards: 28 | grad = -grad 29 | return grad 30 | 31 | def get_ode_gradient_nn(self, t_local, y): 32 | return self.gradient_net(y, t_local) 33 | 34 | def sample_next_point_from_prior(self, t_local, y): 35 | """ 36 | t_local: current time point 37 | y: value at the current time point 38 | """ 39 | return self.get_ode_gradient_nn(t_local, y) 40 | 41 | class ODEFunc(nn.Module): 42 | def __init__(self, ode_func_net): 43 | """ 44 | ode_func_net: neural net that used to transform hidden state in ode 45 | """ 46 | super(ODEFunc, self).__init__() 47 | self.gradient_net = ode_func_net 48 | 49 | def forward(self, t_local, y, backwards = False): 50 | """ 51 | Perform one step in solving ODE. Given current data point y and 52 | current time point t_local, returns gradient dy/dt at this time point 53 | 54 | t_local: current time point 55 | y: value at the current time point 56 | """ 57 | grad = self.get_ode_gradient_nn(t_local, y) 58 | if backwards: 59 | grad = -grad 60 | return grad 61 | 62 | def get_ode_gradient_nn(self, t_local, y): 63 | return self.gradient_net(y) 64 | 65 | def sample_next_point_from_prior(self, t_local, y): 66 | """ 67 | t_local: current time point 68 | y: value at the current time point 69 | """ 70 | return self.get_ode_gradient_nn(t_local, y) 71 | 72 | 73 | class DiffeqSolver(nn.Module): 74 | def __init__(self, ode_func, method, odeint_rtol=1e-4, odeint_atol=1e-5): 75 | super(DiffeqSolver, self).__init__() 76 | 77 | self.ode_method = method 78 | self.ode_func = ode_func 79 | 80 | self.odeint_rtol = odeint_rtol 81 | self.odeint_atol = odeint_atol 82 | 83 | def forward(self, first_point, time_steps_to_predict): 84 | """ 85 | Decode the trajectory through ODE Solver. 
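        Returns pred_y of shape [n_traj_samples, n_traj, len(time_steps_to_predict), latent_dim].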
86 | """ 87 | n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1] 88 | 89 | pred_y = odeint(self.ode_func, first_point, time_steps_to_predict, 90 | rtol = self.odeint_rtol, atol = self.odeint_atol, method = self.ode_method) 91 | pred_y = pred_y.permute(1,2,0,3) 92 | 93 | # assert(torch.mean(pred_y[:, :, 0, :] - first_point) < 0.001) 94 | assert(pred_y.size()[0] == n_traj_samples) 95 | assert(pred_y.size()[1] == n_traj) 96 | 97 | return pred_y 98 | 99 | def sample_traj_from_prior(self, starting_point_enc, time_steps_to_predict): 100 | """ 101 | Decode the trajectory through ODE Solver using samples from the prior 102 | time_steps_to_predict: time steps at which we want to sample the new trajectory 103 | """ 104 | func = self.ode_func.sample_next_point_from_prior 105 | 106 | pred_y = odeint(func, starting_point_enc, time_steps_to_predict, 107 | rtol=self.odeint_rtol, atol=self.odeint_atol, method = self.ode_method) 108 | pred_y = pred_y.permute(1,2,0,3) 109 | return pred_y 110 | 111 | class GRU_Unit(nn.Module): 112 | def __init__(self, latent_dim, input_dim, n_units=100): 113 | super(GRU_Unit, self).__init__() 114 | 115 | self.update_gate = nn.Sequential( 116 | nn.Linear(latent_dim + input_dim, n_units), 117 | nn.Tanh(), 118 | nn.Linear(n_units, latent_dim), 119 | nn.Sigmoid()) 120 | init_network_weights(self.update_gate) 121 | 122 | self.reset_gate = nn.Sequential( 123 | nn.Linear(latent_dim + input_dim, n_units), 124 | nn.Tanh(), 125 | nn.Linear(n_units, latent_dim), 126 | nn.Sigmoid()) 127 | init_network_weights(self.reset_gate) 128 | 129 | self.new_state_net = nn.Sequential( 130 | nn.Linear(latent_dim + input_dim, n_units), 131 | nn.Tanh(), 132 | nn.Linear(n_units, latent_dim)) 133 | init_network_weights(self.new_state_net) 134 | 135 | 136 | def forward(self, h_cur, input_tensor, mask): 137 | y_concat = torch.cat([h_cur, input_tensor], -1) 138 | 139 | update_gate = self.update_gate(y_concat) 140 | reset_gate = self.reset_gate(y_concat) 141 | 142 | combined = torch.cat([input_tensor, reset_gate * h_cur], dim=1) 143 | new_state = self.new_state_net(combined) 144 | 145 | h_next = (1 - update_gate) * new_state + update_gate * h_cur 146 | 147 | # mask = (torch.sum(mask, -1, keepdim=True) > 0).float() 148 | # mask = mask.unsqueeze(-1) 149 | 150 | h_next = mask.unsqueeze(-1) * h_next + ~mask.unsqueeze(-1) * h_cur 151 | 152 | return h_next 153 | 154 | class GRU_Unit2(nn.Module): 155 | def __init__(self, latent_dim, input_dim, n_units=100): 156 | super(GRU_Unit, self).__init__() 157 | 158 | self.update_gate = nn.Sequential( 159 | nn.Linear(latent_dim + input_dim, n_units), 160 | nn.Tanh(), 161 | nn.Linear(n_units, latent_dim), 162 | nn.Sigmoid()) 163 | init_network_weights(self.update_gate) 164 | 165 | self.reset_gate = nn.Sequential( 166 | nn.Linear(latent_dim + input_dim, n_units), 167 | nn.Tanh(), 168 | nn.Linear(n_units, latent_dim), 169 | nn.Sigmoid()) 170 | init_network_weights(self.reset_gate) 171 | 172 | self.new_state_net = nn.Sequential( 173 | nn.Linear(latent_dim + input_dim, n_units), 174 | nn.Tanh(), 175 | nn.Linear(n_units, latent_dim)) 176 | init_network_weights(self.new_state_net) 177 | 178 | 179 | def forward(self, h_cur, input_tensor, mask): 180 | y_concat = torch.cat([h_cur, input_tensor], -1) 181 | 182 | update_gate = self.update_gate(y_concat) 183 | reset_gate = self.reset_gate(y_concat) 184 | 185 | combined = torch.cat([input_tensor, reset_gate * h_cur], dim=1) 186 | new_state = self.new_state_net(combined) 187 | 188 | h_next = (1 - update_gate) * new_state 
+ update_gate * h_cur 189 | 190 | # mask = (torch.sum(mask, -1, keepdim=True) > 0).float() 191 | # mask = mask.unsqueeze(-1) 192 | 193 | h_next = mask.unsqueeze(-1) * h_next + ~mask.unsqueeze(-1) * h_cur 194 | 195 | return h_next 196 | 197 | def get_timesteps(dataset): 198 | if dataset == 'Argoverse': 199 | ref_step = 19 200 | past_t, future_t = 2, 3 201 | t_res = 10 202 | elif dataset == 'nuScenes': 203 | ref_step = 4 204 | past_t, future_t = 2, 6 205 | t_res = 2 206 | 207 | timesteps = torch.arange(0,past_t+future_t, 1/t_res) - past_t + 1/t_res 208 | timesteps[ref_step] = 0 209 | return timesteps 210 | 211 | def init_network_weights(net, std=0.1): 212 | for m in net.modules(): 213 | if isinstance(m, nn.Linear): 214 | nn.init.normal_(m.weight, mean=0, std=std) 215 | nn.init.constant_(m.bias, val=0) 216 | 217 | def create_net(n_inputs, n_outputs, n_layers=1, n_units=100, nonlinear=nn.Tanh): 218 | layers = [nn.Linear(n_inputs, n_units)] 219 | for i in range(n_layers): 220 | layers.append(nonlinear()) 221 | layers.append(nn.Linear(n_units, n_units)) 222 | 223 | layers.append(nonlinear()) 224 | layers.append(nn.Linear(n_units, n_outputs)) 225 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: trajsde 2 | channels: 3 | - pyg 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotlipy=0.7.0=py38h27cfd23_1003 11 | - bzip2=1.0.8=h7b6447c_0 12 | - ca-certificates=2022.10.11=h06a4308_0 13 | - certifi=2022.9.24=py38h06a4308_0 14 | - cffi=1.15.1=py38h74dc2b5_0 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - cryptography=38.0.1=py38h9ce1e76_0 17 | - cudatoolkit=11.3.1=h2bc3f7f_2 18 | - ffmpeg=4.3=hf484d3e_0 19 | - fftw=3.3.9=h27cfd23_1 20 | - freetype=2.12.1=h4a9f257_0 21 | - giflib=5.2.1=h7b6447c_0 22 | - gmp=6.2.1=h295c915_3 23 | - gnutls=3.6.15=he1e5248_0 24 | - idna=3.4=py38h06a4308_0 25 | - intel-openmp=2021.4.0=h06a4308_3561 26 | - jinja2=3.1.2=py38h06a4308_0 27 | - joblib=1.1.1=py38h06a4308_0 28 | - jpeg=9e=h7f8727e_0 29 | - lame=3.100=h7b6447c_0 30 | - lcms2=2.12=h3be6417_0 31 | - ld_impl_linux-64=2.38=h1181459_1 32 | - lerc=3.0=h295c915_0 33 | - libdeflate=1.8=h7f8727e_5 34 | - libffi=3.4.2=h295c915_4 35 | - libgcc-ng=11.2.0=h1234567_1 36 | - libgfortran-ng=11.2.0=h00389a5_1 37 | - libgfortran5=11.2.0=h1234567_1 38 | - libgomp=11.2.0=h1234567_1 39 | - libiconv=1.16=h7f8727e_2 40 | - libidn2=2.3.2=h7f8727e_0 41 | - libpng=1.6.37=hbc83047_0 42 | - libstdcxx-ng=11.2.0=h1234567_1 43 | - libtasn1=4.16.0=h27cfd23_0 44 | - libtiff=4.4.0=hecacb30_2 45 | - libunistring=0.9.10=h27cfd23_0 46 | - libwebp=1.2.4=h11a3e52_0 47 | - libwebp-base=1.2.4=h5eee18b_0 48 | - lz4-c=1.9.3=h295c915_1 49 | - markupsafe=2.1.1=py38h7f8727e_0 50 | - mkl=2021.4.0=h06a4308_640 51 | - mkl-service=2.4.0=py38h7f8727e_0 52 | - mkl_fft=1.3.1=py38hd3c417c_0 53 | - mkl_random=1.2.2=py38h51133e4_0 54 | - ncurses=6.3=h5eee18b_3 55 | - nettle=3.7.3=hbbd107a_1 56 | - openh264=2.1.1=h4ff587b_0 57 | - openssl=1.1.1s=h7f8727e_0 58 | - pillow=9.2.0=py38hace64e9_1 59 | - pip=22.2.2=py38h06a4308_0 60 | - pycparser=2.21=pyhd3eb1b0_0 61 | - pyopenssl=22.0.0=pyhd3eb1b0_0 62 | - pyparsing=3.0.9=py38h06a4308_0 63 | - pysocks=1.7.1=py38h06a4308_0 64 | - python=3.8.15=h3fd9d12_0 65 | - pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0 66 | - pytorch-cluster=1.6.0=py38_torch_1.12.0_cu113 67 | - 
pytorch-mutex=1.0=cuda 68 | - pytorch-scatter=2.0.9=py38_torch_1.12.0_cu113 69 | - pytorch-sparse=0.6.15=py38_torch_1.12.0_cu113 70 | - readline=8.2=h5eee18b_0 71 | - requests=2.28.1=py38h06a4308_0 72 | - scikit-learn=1.1.3=py38h6a678d5_0 73 | - six=1.16.0=pyhd3eb1b0_1 74 | - sqlite=3.39.3=h5082296_0 75 | - tk=8.6.12=h1ccaba5_0 76 | - torchvision=0.13.1=py38_cu113 77 | - tqdm=4.64.1=py38h06a4308_0 78 | - typing_extensions=4.3.0=py38h06a4308_0 79 | - urllib3=1.26.12=py38h06a4308_0 80 | - wheel=0.37.1=pyhd3eb1b0_0 81 | - xz=5.2.6=h5eee18b_0 82 | - zlib=1.2.13=h5eee18b_0 83 | - zstd=1.5.2=ha4553b6_0 84 | - pip: 85 | - absl-py==1.3.0 86 | - aiohttp==3.8.3 87 | - aiosignal==1.3.1 88 | - alembic==1.11.1 89 | - algebra==1.2.1 90 | - antlr4-python3-runtime==4.8 91 | - anyio==3.6.2 92 | - argcomplete==3.0.8 93 | - argon2-cffi==21.3.0 94 | - argon2-cffi-bindings==21.2.0 95 | # - argoverse==1.1.0 96 | - asttokens==2.1.0 97 | - astunparse==1.6.3 98 | - async-timeout==4.0.2 99 | - attrs==22.1.0 100 | - autopage==0.5.1 101 | - av==10.0.0 102 | - av2==0.2.1 103 | - ax-platform==0.3.2 104 | - backcall==0.2.0 105 | - backends==1.5.3 106 | - backends-matrix==1.3.0 107 | - bayesian-optimization==1.4.3 108 | - beartype==0.12.0 109 | - beautifulsoup4==4.11.1 110 | - bleach==5.0.1 111 | - boltons==23.0.0 112 | - botorch==0.8.5 113 | - cachetools==5.2.0 114 | - cftime==1.6.2 115 | - click==8.0.4 116 | - cliff==4.3.0 117 | - cloudpickle==2.2.1 118 | - cmaes==0.9.1 119 | - cmd2==2.4.3 120 | - colorama==0.4.6 121 | - colorlog==6.7.0 122 | - colour==0.1.5 123 | - contourpy==1.0.6 124 | - crc32c==2.3.post0 125 | - cycler==0.11.0 126 | - cython==0.29.33 127 | - debugpy==1.6.3 128 | - decorator==4.4.2 129 | - defusedxml==0.7.1 130 | - descartes==1.1.0 131 | - distlib==0.3.6 132 | - dragonfly-opt==0.1.7 133 | - entrypoints==0.4 134 | - executing==1.2.0 135 | - fastjsonschema==2.16.2 136 | - fdm==0.4.1 137 | - filelock==3.8.0 138 | - fire==0.4.0 139 | - flaml==1.2.3 140 | - fonttools==4.38.0 141 | - frozenlist==1.3.3 142 | - fsspec==2022.11.0 143 | - future==0.18.2 144 | - gast==0.3.3 145 | - google-auth==2.14.1 146 | - google-auth-oauthlib==0.4.6 147 | - google-pasta==0.2.0 148 | - gpytorch==1.10 149 | - greenlet==2.0.2 150 | - grpcio==1.50.0 151 | - gviz-api==1.10.0 152 | - h5py==2.10.0 153 | - hydra-core==1.1.0 154 | - hyperopt==0.2.7 155 | - imageio==2.22.4 156 | - importlib-metadata==5.0.0 157 | - importlib-resources==5.10.0 158 | - ipykernel==6.17.1 159 | - ipython==8.6.0 160 | - ipython-genutils==0.2.0 161 | - ipywidgets==8.0.2 162 | - jedi==0.18.1 163 | - jsonschema==4.17.0 164 | - jupyter==1.0.0 165 | - jupyter-client==7.4.7 166 | - jupyter-console==6.4.4 167 | - jupyter-core==5.0.0 168 | - jupyter-server==1.23.2 169 | - jupyterlab-pygments==0.2.2 170 | - jupyterlab-widgets==3.0.3 171 | - keras-preprocessing==1.1.2 172 | - kiwisolver==1.4.4 173 | # - l5kit==1.5.0 174 | - lapsolver==1.1.0 175 | - lightgbm==3.3.5 176 | - lightning-utilities==0.3.0 177 | - linear-operator==0.4.0 178 | - llvmlite==0.39.1 179 | - mako==1.2.4 180 | - markdown==3.4.1 181 | - markdown-it-py==2.2.0 182 | - matplotlib==3.6.2 183 | - matplotlib-inline==0.1.6 184 | - mdurl==0.1.2 185 | - mistune==2.0.4 186 | - mlkernels==0.4.0 187 | - motmetrics==1.1.3 188 | - mpmath==1.3.0 189 | - msgpack==1.0.4 190 | - multidict==6.0.2 191 | - multipledispatch==0.6.0 192 | - nbclassic==0.4.8 193 | - nbclient==0.7.0 194 | - nbconvert==7.2.5 195 | - nbformat==5.7.0 196 | - nest-asyncio==1.5.6 197 | - netcdf4==1.6.3 198 | - networkx==3.0 199 | - 
notebook==6.5.2 200 | - notebook-shim==0.2.2 201 | - nox==2023.4.22 202 | - numba==0.56.4 203 | # - numpy==1.23.5 204 | - nuscenes-devkit==1.1.9 205 | - oauthlib==3.2.2 206 | - omegaconf==2.1.0 207 | - opencv-python==4.6.0.66 208 | - opt-einsum==3.3.0 209 | - optuna==2.8.0 210 | - packaging==21.3 211 | - pandas==1.3.5 212 | - pandocfilters==1.5.0 213 | - parmap==1.6.0 214 | - parso==0.8.3 215 | - pbr==5.11.1 216 | - pexpect==4.8.0 217 | - pickleshare==0.7.5 218 | - pkgutil-resolve-name==1.3.10 219 | - platformdirs==2.5.4 220 | - plotly==5.14.1 221 | - plum-dispatch==2.0.1 222 | - polars==0.15.2 223 | - prettytable==3.7.0 224 | - prometheus-client==0.15.0 225 | - prompt-toolkit==3.0.32 226 | - protobuf==3.20.1 227 | - psutil==5.9.4 228 | - ptyprocess==0.7.0 229 | - pure-eval==0.2.2 230 | - py4j==0.10.9.7 231 | - pyarrow==11.0.0 232 | - pyasn1==0.4.8 233 | - pyasn1-modules==0.2.8 234 | - pycocotools==2.0.6 235 | - pydeprecate==0.3.1 236 | # - pyg-lib==0.1.0+pt112cu113 237 | - pygments==2.13.0 238 | - pyntcloud==0.3.1 239 | - pyperclip==1.8.2 240 | - pyproj==3.5.0 241 | - pyquaternion==0.9.9 242 | - pyro-api==0.1.2 243 | - pyro-ppl==1.8.4 244 | - pyrsistent==0.19.2 245 | - python-dateutil==2.8.2 246 | - python-slugify==8.0.1 247 | - pytorch-lightning==1.6.5 248 | - pytz==2022.6 249 | - pyyaml==6.0 250 | - pyzmq==24.0.1 251 | - qtconsole==5.4.0 252 | - qtpy==2.3.0 253 | - ray==2.1.0 254 | - ray-lightning==0.3.0 255 | - requests-oauthlib==1.3.1 256 | - rich==13.3.5 257 | - rsa==4.9 258 | # - scipy==1.4.1 259 | - seaborn==0.12.2 260 | - send2trash==1.8.0 261 | - setuptools==59.5.0 262 | - shapely==1.8.5.post1 263 | # - sklearn==0.0.post1 264 | - sniffio==1.3.0 265 | - soupsieve==2.3.2.post1 266 | - sqlalchemy==2.0.13 267 | - stack-data==0.6.1 268 | - stevedore==5.1.0 269 | - stheno==1.4.1 270 | - sympy==1.11.1 271 | - tabulate==0.9.0 272 | - tenacity==8.2.2 273 | - tensorboard==2.11.0 274 | - tensorboard-data-server==0.6.1 275 | - tensorboard-plugin-profile==2.13.0 276 | - tensorboard-plugin-wit==1.8.1 277 | - tensorboardx==2.6 278 | # - tensorflow-estimator==2.3.0 279 | # - tensorflow-gpu==2.3.0 280 | - termcolor==2.1.0 281 | - terminado==0.17.0 282 | - text-unidecode==1.3 283 | - tfrecord==1.14.1 284 | - threadpoolctl==3.1.0 285 | - tinycss2==1.2.1 286 | - torch-geometric==2.2.0 287 | - torch-geometric-temporal==0.54.0 288 | - torch-tb-profiler==0.4.0 289 | - torchcde==0.2.5 290 | - torchdiffeq==0.2.3 291 | - torchdyn==1.0.4 292 | - torchmetrics==0.10.3 293 | - torchsde==0.2.5 294 | - tornado==6.2 295 | - traitlets==5.5.0 296 | - trampoline==0.1.2 297 | - typeguard==2.13.3 298 | - varz==0.8.1 299 | - virtualenv==20.16.7 300 | - waymo-open-dataset-tf-2-3-0==1.3.1 301 | - wbml==0.4.1 302 | - wcwidth==0.2.5 303 | - webencodings==0.5.1 304 | - websocket-client==1.4.2 305 | - werkzeug==2.2.2 306 | - widgetsnbextension==4.0.3 307 | - wrapt==1.15.0 308 | - xarray==2023.1.0 309 | - xgboost==1.7.5 310 | - yarl==1.8.1 311 | - zipp==3.10.0 312 | prefix: /home/user/miniconda3/envs/jw 313 | -------------------------------------------------------------------------------- /models/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import List, Optional, Tuple 15 | 16 | import torch 17 | import torch.nn as nn 18 | from torch_geometric.data import Data 19 | 20 | 21 | class TemporalData(Data): 22 | 23 | def __init__(self, 24 | x: Optional[torch.Tensor] = None, 25 | positions: Optional[torch.Tensor] = None, 26 | states: Optional[torch.Tensor] = None, 27 | category: Optional[torch.Tensor] = None, 28 | edge_index: Optional[torch.Tensor] = None, 29 | edge_attrs: Optional[List[torch.Tensor]] = None, 30 | y: Optional[torch.Tensor] = None, 31 | num_nodes: Optional[int] = None, 32 | padding_mask: Optional[torch.Tensor] = None, 33 | bos_mask: Optional[torch.Tensor] = None, 34 | rotate_angles: Optional[torch.Tensor] = None, 35 | lane_positions: Optional[torch.Tensor] = None, 36 | lane_vectors: Optional[torch.Tensor] = None, 37 | lane_rotate_angles: Optional[torch.Tensor] = None, 38 | lane_edge_index: Optional[torch.Tensor] = None, 39 | lane_edge_type: Optional[torch.Tensor] = None, 40 | lane_paddings: Optional[torch.Tensor] = None, 41 | lane_lengths: Optional[torch.Tensor] = None, 42 | lane_actor_index: Optional[torch.Tensor] = None, 43 | lane_actor_vectors: Optional[torch.Tensor] = None, 44 | lane_edge_index2_succ: Optional[torch.Tensor] = None, 45 | lane_edge_index2_pred: Optional[torch.Tensor] = None, 46 | lane_edge_index2_neigh: Optional[torch.Tensor] = None, 47 | goal_idcs: Optional[torch.Tensor] = None, 48 | has_goal: Optional[torch.Tensor] = None, 49 | seq_id: Optional[int] = None, 50 | **kwargs) -> None: 51 | if x is None: 52 | super(TemporalData, self).__init__() 53 | return 54 | super(TemporalData, self).__init__(x=x, positions=positions, states=states, category=category, edge_index=edge_index, y=y, num_nodes=num_nodes, 55 | padding_mask=padding_mask, bos_mask=bos_mask, rotate_angles=rotate_angles, 56 | lane_positions=lane_positions, lane_vectors=lane_vectors, lane_rotate_angles=lane_rotate_angles, 57 | lane_edge_index=lane_edge_index, lane_edge_type=lane_edge_type, 58 | lane_paddings=lane_paddings, lane_lengths=lane_lengths, 59 | lane_actor_index=lane_actor_index, lane_actor_vectors=lane_actor_vectors, 60 | lane_edge_index2_succ=lane_edge_index2_succ, lane_edge_index2_pred=lane_edge_index2_pred, lane_edge_index2_neigh=lane_edge_index2_neigh, 61 | goal_idcs=goal_idcs, has_goal=has_goal, 62 | seq_id=seq_id, **kwargs) 63 | if edge_attrs is not None: 64 | for t in range(self.x.size(1)): 65 | self[f'edge_attr_{t}'] = edge_attrs[t] 66 | 67 | def __inc__(self, key, value, *args, **kwargs): 68 | if key == 'lane_actor_index': 69 | return torch.tensor([[self['lane_vectors'].size(0)], [self.num_nodes]]) 70 | elif key == 'lane_edge_index': 71 | return torch.tensor([[self['lane_vectors'].size(0)], [self['lane_vectors'].size(0)]]) 72 | elif 'lane_edge_index2' in key: 73 | return torch.tensor([[self['lane_actor_index'].size(1)], [self['lane_actor_index'].size(1)]]) 74 | else: 75 | return super().__inc__(key, value, *args, **kwargs) 76 | 77 | 78 | class DistanceDropEdge(object): 79 | 80 | def __init__(self, max_distance: Optional[float] = None) -> None: 81 | self.max_distance = 
max_distance 82 | 83 | def __call__(self, 84 | edge_index: torch.Tensor, 85 | edge_attr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 86 | if self.max_distance is None: 87 | return edge_index, edge_attr 88 | row, col = edge_index 89 | mask = torch.norm(edge_attr, p=2, dim=-1) < self.max_distance 90 | edge_index = torch.stack([row[mask], col[mask]], dim=0) 91 | edge_attr = edge_attr[mask] 92 | return edge_index, edge_attr 93 | 94 | def init_weights(m: nn.Module) -> None: 95 | if isinstance(m, nn.Linear): 96 | nn.init.xavier_uniform_(m.weight) 97 | if m.bias is not None: 98 | nn.init.zeros_(m.bias) 99 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 100 | fan_in = m.in_channels / m.groups 101 | fan_out = m.out_channels / m.groups 102 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 103 | nn.init.uniform_(m.weight, -bound, bound) 104 | if m.bias is not None: 105 | nn.init.zeros_(m.bias) 106 | elif isinstance(m, nn.Embedding): 107 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 108 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 109 | nn.init.ones_(m.weight) 110 | nn.init.zeros_(m.bias) 111 | elif isinstance(m, nn.LayerNorm): 112 | nn.init.ones_(m.weight) 113 | nn.init.zeros_(m.bias) 114 | elif isinstance(m, nn.MultiheadAttention): 115 | if m.in_proj_weight is not None: 116 | fan_in = m.embed_dim 117 | fan_out = m.embed_dim 118 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 119 | nn.init.uniform_(m.in_proj_weight, -bound, bound) 120 | else: 121 | nn.init.xavier_uniform_(m.q_proj_weight) 122 | nn.init.xavier_uniform_(m.k_proj_weight) 123 | nn.init.xavier_uniform_(m.v_proj_weight) 124 | if m.in_proj_bias is not None: 125 | nn.init.zeros_(m.in_proj_bias) 126 | nn.init.xavier_uniform_(m.out_proj.weight) 127 | if m.out_proj.bias is not None: 128 | nn.init.zeros_(m.out_proj.bias) 129 | if m.bias_k is not None: 130 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02) 131 | if m.bias_v is not None: 132 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02) 133 | elif isinstance(m, nn.LSTM): 134 | for name, param in m.named_parameters(): 135 | if 'weight_ih' in name: 136 | for ih in param.chunk(4, 0): 137 | nn.init.xavier_uniform_(ih) 138 | elif 'weight_hh' in name: 139 | for hh in param.chunk(4, 0): 140 | nn.init.orthogonal_(hh) 141 | elif 'weight_hr' in name: 142 | nn.init.xavier_uniform_(param) 143 | elif 'bias_ih' in name: 144 | nn.init.zeros_(param) 145 | elif 'bias_hh' in name: 146 | nn.init.zeros_(param) 147 | nn.init.ones_(param.chunk(4, 0)[1]) 148 | elif isinstance(m, nn.GRU): 149 | for name, param in m.named_parameters(): 150 | if 'weight_ih' in name: 151 | for ih in param.chunk(3, 0): 152 | nn.init.xavier_uniform_(ih) 153 | elif 'weight_hh' in name: 154 | for hh in param.chunk(3, 0): 155 | nn.init.orthogonal_(hh) 156 | elif 'bias_ih' in name: 157 | nn.init.zeros_(param) 158 | elif 'bias_hh' in name: 159 | nn.init.zeros_(param) 160 | 161 | class MLP(nn.Module): 162 | def __init__(self, input_dim, output_dim, num_layers, dropout_rate, reduction_factor=0.5): 163 | super(MLP, self).__init__() 164 | 165 | self.layers = nn.ModuleList() 166 | 167 | # Compute minimum feature dimension based on output dimension 168 | min_feature_dim = max(output_dim, int(input_dim/reduction_factor)) 169 | 170 | # Add input layer 171 | self.layers.append(nn.Linear(input_dim, min_feature_dim)) 172 | self.layers.append(nn.LayerNorm(min_feature_dim)) 173 | self.layers.append(nn.ReLU()) 174 | self.layers.append(nn.Dropout(dropout_rate)) 175 | 176 | # Add hidden layers 177 | for i in range(num_layers-1): 
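            # Note: with the default reduction_factor=0.5, input_dim/(reduction_factor**(i+1))
            # doubles at each hidden layer, so the widths grow rather than shrink; the max(...)
            # below only floors each width at output_dim. A reduction_factor > 1 would give an
            # actually shrinking MLP.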
178 | feature_dim = max(output_dim, int(input_dim/(reduction_factor**(i+1)))) 179 | self.layers.append(nn.Linear(min_feature_dim, feature_dim)) 180 | self.layers.append(nn.LayerNorm(feature_dim)) 181 | self.layers.append(nn.ReLU()) 182 | self.layers.append(nn.Dropout(dropout_rate)) 183 | min_feature_dim = feature_dim 184 | 185 | # Add output layer 186 | self.layers.append(nn.Linear(min_feature_dim, output_dim)) 187 | 188 | def forward(self, x): 189 | for layer in self.layers: 190 | x = layer(x) 191 | return x -------------------------------------------------------------------------------- /models/model_base_mix_sde.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import pickle as pkl 4 | from copy import deepcopy 5 | import json 6 | import pytorch_lightning as pl 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch_geometric.nn.dense import Linear 11 | 12 | from models.utils.util import TemporalData 13 | from debug_util import viz_result_batch_base, viz_result_batch_goalpred, viz_result_batch_ood, viz_result_batch_ood_load 14 | 15 | import importlib 16 | from importlib.machinery import SourceFileLoader 17 | 18 | import matplotlib.pyplot as plt 19 | import sys 20 | 21 | 22 | class PredictionModelSDENet(pl.LightningModule): 23 | 24 | def __init__(self, 25 | **kwargs) -> None: 26 | 27 | super(PredictionModelSDENet, self).__init__() 28 | self.save_hyperparameters() 29 | 30 | for key, value in kwargs.items(): 31 | if key == 'training_specific': 32 | for k, v in value.items(): 33 | setattr(self, k, v) 34 | elif key == 'model_specific': 35 | for k, v in value['kwargs'].items(): 36 | setattr(self, k, v) 37 | 38 | enc_args, agg_args, dec_args = kwargs['encoder'], kwargs['aggregator'], kwargs['decoder'] 39 | encoder = getattr(SourceFileLoader(enc_args['module_name'], enc_args['file_path']).load_module(enc_args['module_name']), enc_args['module_name']) 40 | aggregator = getattr(SourceFileLoader(agg_args['module_name'], agg_args['file_path']).load_module(agg_args['module_name']), agg_args['module_name']) 41 | decoder = getattr(SourceFileLoader(dec_args['module_name'], dec_args['file_path']).load_module(dec_args['module_name']), dec_args['module_name']) 42 | 43 | self.encoder = encoder(**dict(kwargs['encoder']['kwargs'])) 44 | self.aggregator = aggregator(**dict(kwargs['aggregator']['kwargs'])) 45 | self.decoder = decoder(**dict(kwargs['decoder']['kwargs'])) 46 | 47 | self.losses = [] 48 | self.loss_names = [] 49 | for i, loss_path in enumerate(kwargs['losses']): 50 | loss_module_name = kwargs['losses_module'][i] 51 | 52 | loss = getattr(SourceFileLoader(loss_module_name, loss_path).load_module(loss_module_name), loss_module_name) 53 | loss = loss(**dict(kwargs['loss_args'][i])) 54 | self.losses.append(loss) 55 | self.loss_names.append(loss_module_name) 56 | self.loss_weights = kwargs['loss_weights'] 57 | 58 | self.metrics_tr = [] 59 | self.metrics_vl = [] 60 | self.metric_names = [] 61 | for i, metric_path in enumerate(kwargs['metrics']): 62 | metric_module_name = kwargs['metrics_module'][i] 63 | 64 | metric = getattr(SourceFileLoader(metric_module_name, metric_path).load_module(metric_module_name), metric_module_name) 65 | metric = metric(**dict(kwargs['metric_args'][i])) 66 | self.metrics_tr.append(metric) 67 | self.metrics_vl.append(deepcopy(metric)) 68 | self.metric_names.append(metric_module_name) 69 | 70 | if hasattr(self, 'stds_fn'): 71 | with open(self.stds_fn, 'rb') as 
f:
72 |                 self.stds_loaded = pkl.load(f)
73 | 
74 |     def forward(self, data: TemporalData):
75 |         if self.rotate:
76 |             rotate_mat = torch.empty(data.num_nodes, 2, 2, device=self.device)
77 |             sin_vals = torch.sin(data['rotate_angles'])
78 |             cos_vals = torch.cos(data['rotate_angles'])
79 |             rotate_mat[:, 0, 0] = cos_vals
80 |             rotate_mat[:, 0, 1] = -sin_vals
81 |             rotate_mat[:, 1, 0] = sin_vals
82 |             rotate_mat[:, 1, 1] = cos_vals
83 |             if data.y is not None:
84 |                 data.y = torch.bmm(data.y, rotate_mat)
85 |             data['rotate_mat'] = rotate_mat
86 |         else:
87 |             data['rotate_mat'] = False
88 | 
89 |         if hasattr(self, 'ood') and self.ood:
90 |             local_embed, stds = self.encoder.forward_ood(data=data)
91 |         else:
92 |             local_embed, diffusions_in, diffusions_out, in_labels, out_labels = self.encoder(data=data)
93 | 
94 |         global_embed = self.aggregator(data=data, local_embed=local_embed)
95 |         out = self.decoder(data=data, local_embed=local_embed, global_embed=global_embed)
96 | 
97 |         if hasattr(self, 'ood') and self.ood:
98 |             out['stds'] = stds
99 |         else:
100 |             out['diff_in'], out['diff_out'], out['label_in'], out['label_out'] = diffusions_in, diffusions_out, in_labels, out_labels
101 | 
102 |         return out
103 | 
104 |     def training_step(self, data, batch_idx):
105 |         output = self(data)
106 | 
107 |         loss = 0
108 |         for lidx, lossfn in enumerate(self.losses):
109 |             lossname = self.loss_names[lidx]
110 |             loss_i = lossfn(data, output)
111 |             loss = loss + self.loss_weights[lidx]*loss_i
112 |             self.log(f'train/{lossname}', loss_i, prog_bar=True, on_step=True, on_epoch=True, batch_size=output['loc'].size(1))
113 |         self.log('lr', self.scheduler.get_last_lr()[0], prog_bar=False, on_step=False, on_epoch=True, batch_size=1)
114 | 
115 |         return loss
116 | 
117 |     def validation_step(self, data, batch_idx):
118 |         output = self(data)
119 | 
120 |         y_hat_agent = output['loc'][:, data['agent_index'], :, : 2]
121 |         y_agent = data.y[data['agent_index']]
122 |         agent_reg_mask = output['reg_mask'][data['agent_index']]
123 |         agent_source = data['source']
124 | 
125 |         if not self.is_gtabs:
126 |             y_hat_agent = torch.cumsum(y_hat_agent, dim=-2)
127 |             y_agent = torch.cumsum(y_agent, dim=-2)
128 | 
129 |         for midx, metric in enumerate(self.metrics_vl):
130 |             metricname = self.metric_names[midx]
131 |             metric.update(y_hat_agent.detach().cpu(), y_agent.detach().cpu(), agent_reg_mask.detach().cpu(), agent_source.detach().cpu())
132 | 
133 |     def test_step(self, data, batch_idx):
134 |         output = self(data)
135 | 
136 |         if self.only_agent:
137 |             self.leave_only_agent(data, output)
138 | 
139 |         y_hat_agent = output['loc'][:, data['agent_index'], :, : 2]
140 |         if data.y is not None: y_agent = data.y[data['agent_index']]
141 |         pi_agent = output['pi'][data['agent_index']]
142 |         origin_agent = data['positions'][data['agent_index'], self.ref_time]
143 |         agent_reg_mask = output['reg_mask'][data['agent_index']]
144 |         agent_source = data['source']
145 | 
146 |         if data.y is not None:
147 |             for metric in self.metrics_vl:
148 |                 metric.update(y_hat_agent.detach().cpu(), y_agent.detach().cpu(), agent_reg_mask.detach().cpu(), agent_source.detach().cpu())
149 | 
150 |     def test_epoch_end(self, outputs) -> None:
151 | 
152 |         ckpt_path = Path(self.trainer._ckpt_path)
153 |         out_dir = os.path.join(ckpt_path.parent.parent, 'out')
154 |         if not os.path.isdir(out_dir):
155 |             os.mkdir(out_dir)
156 | 
157 |         metrics = dict()
158 |         for midx, metric in enumerate(self.metrics_vl):
159 |             metricname = self.metric_names[midx]
160 |             metrics[metricname] = metric.compute().item()
161 | 
162 |         ckpt_name = 
ckpt_path.stem 163 | ckpt_fn = os.path.join(out_dir, f'result_{ckpt_name}.json') 164 | with open(ckpt_fn, 'w') as f: 165 | json.dump(metrics, f) 166 | 167 | 168 | @staticmethod 169 | def leave_only_agent(data, output): 170 | data.num_nodes = data.x.size(0) 171 | data.bos_mask = data.bos_mask[data['agent_index']] 172 | data.y = data.y[data['agent_index']] 173 | data.x = data.x[data['agent_index']] 174 | if 'category' in data.keys: data.category = data.category[data['agent_index']] 175 | data.positions = data.positions[data['agent_index']] 176 | data.rotate_mat = data.rotate_mat[data['agent_index']] 177 | data.rotate_angles = data.rotate_angles[data['agent_index']] 178 | data.has_goal = data.has_goal[data['agent_index']] 179 | data.padding_mask = data.padding_mask[data['agent_index']] 180 | 181 | al_agent_mask = torch.isin(data['lane_actor_index'][1], data['agent_index']) 182 | agent_has_lane = torch.isin(data['agent_index'], data['lane_actor_index'][1]) 183 | data.goal_idcs = data.goal_idcs[al_agent_mask] 184 | data.lane_actor_vectors = data.lane_actor_vectors[al_agent_mask] 185 | 186 | output['loc'] = output['loc'][:,data['agent_index']] 187 | output['pi'] = output['pi'][data['agent_index']] 188 | output['reg_mask'] = output['reg_mask'][data['agent_index']] 189 | if 'cls_mask' in output: output['cls_mask'] = output['cls_mask'][data['agent_index']] 190 | if 'goal_prob' in output: 191 | output['goal_prob'] = output['goal_prob'][al_agent_mask] 192 | if 'goal_cls_mask' in output: 193 | output['goal_cls_mask'] = output['goal_cls_mask'][al_agent_mask] 194 | 195 | data.lane_actor_index = data.lane_actor_index[:,al_agent_mask] 196 | for i, agent_i in enumerate(data['agent_index']): 197 | if agent_has_lane[i]: 198 | data.lane_actor_index[1][data.lane_actor_index[1] == agent_i] = i 199 | 200 | data.agent_index = torch.arange(data.x.size(0)).to(data.x.device) 201 | data.av_index = torch.arange(data.x.size(0)).to(data.x.device) 202 | data.batch = torch.arange(data.x.size(0)).to(data.x.device) 203 | 204 | def configure_optimizers(self): 205 | self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay) 206 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=self.optimizer, T_max=self.T_max, eta_min=0.0) 207 | return [self.optimizer], [self.scheduler] 208 | 209 | -------------------------------------------------------------------------------- /models/model_base_mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from copy import deepcopy 4 | import json 5 | import pickle as pkl 6 | import pytorch_lightning as pl 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch_geometric.nn.dense import Linear 11 | 12 | from models.utils.util import TemporalData 13 | from debug_util import viz_result_batch_base, viz_result_batch_goalpred, viz_result_batch_ood_load 14 | 15 | import importlib 16 | from importlib.machinery import SourceFileLoader 17 | 18 | import matplotlib.pyplot as plt 19 | import sys 20 | 21 | 22 | class PredictionModel(pl.LightningModule): 23 | 24 | def __init__(self, 25 | **kwargs) -> None: 26 | 27 | super(PredictionModel, self).__init__() 28 | self.save_hyperparameters() 29 | 30 | for key, value in kwargs.items(): 31 | if key == 'training_specific': 32 | for k, v in value.items(): 33 | setattr(self, k, v) 34 | elif key == 'model_specific': 35 | for k, v in value['kwargs'].items(): 36 | setattr(self, k, v) 37 | 38 | 
enc_args, agg_args, dec_args = kwargs['encoder'], kwargs['aggregator'], kwargs['decoder'] 39 | encoder = getattr(SourceFileLoader(enc_args['module_name'], enc_args['file_path']).load_module(enc_args['module_name']), enc_args['module_name']) 40 | aggregator = getattr(SourceFileLoader(agg_args['module_name'], agg_args['file_path']).load_module(agg_args['module_name']), agg_args['module_name']) 41 | decoder = getattr(SourceFileLoader(dec_args['module_name'], dec_args['file_path']).load_module(dec_args['module_name']), dec_args['module_name']) 42 | 43 | self.encoder = encoder(**dict(kwargs['encoder']['kwargs'])) 44 | self.aggregator = aggregator(**dict(kwargs['aggregator']['kwargs'])) 45 | self.decoder = decoder(**dict(kwargs['decoder']['kwargs'])) 46 | 47 | self.losses = [] 48 | self.loss_names = [] 49 | for i, loss_path in enumerate(kwargs['losses']): 50 | loss_module_name = kwargs['losses_module'][i] 51 | 52 | loss = getattr(SourceFileLoader(loss_module_name, loss_path).load_module(loss_module_name), loss_module_name) 53 | loss = loss(**dict(kwargs['loss_args'][i])) 54 | self.losses.append(loss) 55 | self.loss_names.append(loss_module_name) 56 | self.loss_weights = kwargs['loss_weights'] 57 | 58 | self.metrics_tr = [] 59 | self.metrics_vl = [] 60 | self.metric_names = [] 61 | for i, metric_path in enumerate(kwargs['metrics']): 62 | metric_module_name = kwargs['metrics_module'][i] 63 | 64 | metric = getattr(SourceFileLoader(metric_module_name, metric_path).load_module(metric_module_name), metric_module_name) 65 | metric = metric(**dict(kwargs['metric_args'][i])) 66 | self.metrics_tr.append(metric) 67 | self.metrics_vl.append(deepcopy(metric)) 68 | self.metric_names.append(metric_module_name) 69 | 70 | if hasattr(self, 'stds_fn'): 71 | with open(self.stds_fn, 'rb') as f: 72 | self.stds_loaded = pkl.load(f) 73 | 74 | 75 | def forward(self, data: TemporalData): 76 | if self.rotate: 77 | rotate_mat = torch.empty(data.num_nodes, 2, 2, device=self.device) 78 | sin_vals = torch.sin(data['rotate_angles']) 79 | cos_vals = torch.cos(data['rotate_angles']) 80 | rotate_mat[:, 0, 0] = cos_vals 81 | rotate_mat[:, 0, 1] = -sin_vals 82 | rotate_mat[:, 1, 0] = sin_vals 83 | rotate_mat[:, 1, 1] = cos_vals 84 | if data.y is not None: 85 | data.y = torch.bmm(data.y, rotate_mat) 86 | data['rotate_mat'] = rotate_mat 87 | else: 88 | data['rotate_mat'] = False 89 | 90 | local_embed = self.encoder(data=data) 91 | global_embed = self.aggregator(data=data, local_embed=local_embed) 92 | out = self.decoder(data=data, local_embed=local_embed, global_embed=global_embed) 93 | return out 94 | 95 | def training_step(self, data, batch_idx): 96 | if self.ts_drop: 97 | masking_mask = (torch.rand(data.x.size(0), self.historical_steps, device=data.x.device) > (1-self.ts_drop)).bool() 98 | masking_mask[data.bos_mask] = False 99 | masking_mask[:,-1] = False 100 | data.x[masking_mask] = 0 101 | data.padding_mask[:,:self.historical_steps] = data.padding_mask[:,:self.historical_steps] + masking_mask 102 | 103 | output = self(data) 104 | 105 | if self.only_agent: 106 | self.leave_only_agent(data, output) 107 | 108 | loss = 0 109 | for lidx, lossfn in enumerate(self.losses): 110 | lossname = self.loss_names[lidx] 111 | loss_i = lossfn(data, output) 112 | loss = loss + self.loss_weights[lidx]*loss_i 113 | self.log(f'train/{lossname}', loss_i, prog_bar=True, on_step=True, on_epoch=True, batch_size=output['loc'].size(1)) 114 | self.log('lr', self.optimizer.param_groups[0]['lr'], prog_bar=False, on_step=False, on_epoch=True, batch_size=1) 
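        # Total objective: weighted sum over all configured criteria; loss_weights is
        # aligned index-wise with losses/loss_names as loaded from the config file.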
115 | 116 | return loss 117 | 118 | def validation_step(self, data, batch_idx): 119 | output = self(data) 120 | 121 | y_hat_agent = output['loc'][:, data['agent_index'], :, : 2] 122 | y_agent = data.y[data['agent_index']] 123 | agent_reg_mask = output['reg_mask'][data['agent_index']] 124 | agent_source = data['source'] 125 | 126 | for midx, metric in enumerate(self.metrics_vl): 127 | metricname = self.metric_names[midx] 128 | metric.update(y_hat_agent.detach().cpu(), y_agent.detach().cpu(), agent_reg_mask.detach().cpu(), agent_source.detach().cpu()) 129 | 130 | def test_step(self, data, batch_idx): 131 | output = self(data) 132 | 133 | if self.only_agent: 134 | self.leave_only_agent(data, output) 135 | 136 | y_hat_agent = output['loc'][:, data['agent_index'], :, : 2] 137 | if data.y is not None: y_agent = data.y[data['agent_index']] 138 | pi_agent = output['pi'][data['agent_index']] 139 | origin_agent = data['positions'][data['agent_index'], self.ref_time] 140 | agent_reg_mask = output['reg_mask'][data['agent_index']] 141 | agent_source = data['source'] 142 | 143 | if not self.is_gtabs: 144 | y_hat_agent = torch.cumsum(y_hat_agent, dim=-2) 145 | if data.y is not None: y_agent = torch.cumsum(y_agent, dim=-2) 146 | 147 | if data.y is not None: 148 | for metric in self.metrics_vl: 149 | metric.update(y_hat_agent.detach().cpu(), y_agent.detach().cpu(), agent_reg_mask.detach().cpu(), agent_source.detach().cpu()) 150 | 151 | 152 | def test_epoch_end(self, outputs) -> None: 153 | 154 | ckpt_path = Path(self.trainer._ckpt_path) 155 | out_dir = os.path.join(ckpt_path.parent.parent, 'out') 156 | if not os.path.isdir(out_dir): 157 | os.mkdir(out_dir) 158 | 159 | metrics = dict() 160 | for midx, metric in enumerate(self.metrics_vl): 161 | metricname = self.metric_names[midx] 162 | metrics[metricname] = metric.compute().item() 163 | 164 | ckpt_name = ckpt_path.stem 165 | ckpt_fn = os.path.join(out_dir, f'result_{ckpt_name}.json') 166 | with open(ckpt_fn, 'w') as f: 167 | json.dump(metrics, f) 168 | 169 | 170 | @staticmethod 171 | def leave_only_agent(data, output): 172 | data.num_nodes = data.x.size(0) 173 | data.bos_mask = data.bos_mask[data['agent_index']] 174 | data.y = data.y[data['agent_index']] 175 | data.x = data.x[data['agent_index']] 176 | if 'category' in data.keys: data.category = data.category[data['agent_index']] 177 | data.positions = data.positions[data['agent_index']] 178 | data.rotate_mat = data.rotate_mat[data['agent_index']] 179 | data.rotate_angles = data.rotate_angles[data['agent_index']] 180 | data.has_goal = data.has_goal[data['agent_index']] 181 | data.padding_mask = data.padding_mask[data['agent_index']] 182 | 183 | al_agent_mask = torch.isin(data['lane_actor_index'][1], data['agent_index']) 184 | agent_has_lane = torch.isin(data['agent_index'], data['lane_actor_index'][1]) 185 | data.goal_idcs = data.goal_idcs[al_agent_mask] 186 | data.lane_actor_vectors = data.lane_actor_vectors[al_agent_mask] 187 | 188 | output['loc'] = output['loc'][:,data['agent_index']] 189 | output['pi'] = output['pi'][data['agent_index']] 190 | output['reg_mask'] = output['reg_mask'][data['agent_index']] 191 | if 'cls_mask' in output: output['cls_mask'] = output['cls_mask'][data['agent_index']] 192 | if 'goal_prob' in output: 193 | output['goal_prob'] = output['goal_prob'][al_agent_mask] 194 | if 'goal_cls_mask' in output: 195 | output['goal_cls_mask'] = output['goal_cls_mask'][al_agent_mask] 196 | 197 | data.lane_actor_index = data.lane_actor_index[:,al_agent_mask] 198 | for i, agent_i in 
enumerate(data['agent_index']):
199 |             if agent_has_lane[i]:
200 |                 data.lane_actor_index[1][data.lane_actor_index[1] == agent_i] = i
201 | 
202 |         data.agent_index = torch.arange(data.x.size(0)).to(data.x.device)
203 |         data.av_index = torch.arange(data.x.size(0)).to(data.x.device)
204 |         data.batch = torch.arange(data.x.size(0)).to(data.x.device)
205 | 
206 |     def configure_optimizers(self):
207 |         self.optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
208 |         self.scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer, step_size=self.scheduler_step, gamma=self.scheduler_gamma)
209 |         return [self.optimizer], [self.scheduler]
210 | 
--------------------------------------------------------------------------------
/dataset/nuScenes_Argoverse/nuScenes_Argoverse.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from itertools import permutations
4 | from itertools import product
5 | from typing import Callable, Dict, List, Optional, Tuple, Union
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | from pyquaternion import Quaternion
10 | import torch
11 | 
12 | from multiprocessing import Process
13 | from multiprocessing import Pool
14 | from itertools import repeat
15 | 
16 | from nuscenes.eval.prediction.splits import get_prediction_challenge_split
17 | from scipy.spatial.distance import cdist
18 | from shapely.geometry import Point
19 | from shapely.geometry.polygon import Polygon
20 | 
21 | from torch_geometric.data import Data
22 | from torch_geometric.data import Dataset
23 | from torch_geometric.data.dataset import files_exist
24 | from tqdm import tqdm
25 | import pickle as pkl
26 | from shapely.geometry import LineString, Point
27 | import random
28 | 
29 | import sys
30 | sys.path.append('/mnt/ssd2/frm_lightning_backup/')
31 | from models.utils.util import TemporalData
32 | from debug_util import *
33 | 
34 | SPLIT_NAME = {'nuScenes': {'train': 'train', 'val': 'val', 'test': 'val', 'mini_train': 'mini_train', 'mini_val': 'mini_val'},
35 |               'Argoverse': {'train': 'train', 'val': 'train', 'test': 'test_obs', 'sample': 'forecasting_sample'}}
36 | DATA_SOURCE = {0: 'nuScenes', 1: 'Argoverse'}
37 | CATEGORY_INTEREST = [0, 1, 2, 3, 4, 5, 7, 8]
38 | 
39 | 
40 | class nuArgoDataset(Dataset):
41 | 
42 |     def __init__(self,
43 |                  split: str,
44 |                  nu_root,
45 |                  Argo_root,
46 |                  nu_dir,
47 |                  Argo_dir,
48 |                  spec_args=None) -> None:
49 |         self._split = split
50 |         self.nu_root = nu_root
51 |         self.Argo_root = Argo_root
52 |         self.nu_dir = nu_dir
53 |         self.Argo_dir = Argo_dir
54 | 
55 |         for k,v in spec_args.items():
56 |             self.__setattr__(k, v)
57 | 
58 |         self.nu_directory = SPLIT_NAME['nuScenes'][split]
59 |         self.Argo_directory = SPLIT_NAME['Argoverse'][split]
60 | 
61 |         nu_raw_file_names = sorted(get_prediction_challenge_split(self.nu_directory, self.nu_root))
62 |         nu_processed_file_names = [os.path.splitext(f)[0] + '.pt' for f in nu_raw_file_names]
63 |         nu_processed_paths = [os.path.join(self.nu_dir, self.nu_directory, f) for f in nu_processed_file_names]
64 | 
65 |         _Argo_raw_file_names_fn = os.path.join(self.Argo_dir, f'raw_processed_fns_{self.Argo_directory}.pt')
66 |         if os.path.isfile(_Argo_raw_file_names_fn):
67 |             with open(file=_Argo_raw_file_names_fn, mode='rb') as f:
68 |                 Argo_raw_file_names, Argo_processed_file_names, Argo_processed_paths = pkl.load(f)
69 |         else:
70 |             raise FileNotFoundError('Argo file name cache file must exist')
71 | 
72 |         self._raw_file_names = []
73 | 
self._processed_file_names = [] 74 | self._processed_paths = [] 75 | self._sources = [] 76 | if self.nus: 77 | self._raw_file_names = self._raw_file_names + nu_raw_file_names 78 | self._processed_file_names = self._processed_file_names + nu_processed_file_names 79 | self._processed_paths = self._processed_paths + nu_processed_paths 80 | self._sources = self._sources + [0]*len(nu_raw_file_names) 81 | if self.Argo: 82 | self._raw_file_names = self._raw_file_names + Argo_raw_file_names 83 | self._processed_file_names = self._processed_file_names + Argo_processed_file_names 84 | self._processed_paths = self._processed_paths + Argo_processed_paths 85 | self._sources = self._sources + [1]*len(Argo_raw_file_names) # 0 for nuScenes, 1 for Argoverse 86 | 87 | if self.type == 'grid': 88 | self.max_past, self.max_fut = 21, 60 89 | grid_past, grid_fut = torch.zeros(self.max_past, dtype=torch.bool), torch.zeros(self.max_fut, dtype=torch.bool) 90 | 91 | self.ts_pasts, self.ts_futs = torch.linspace(-20,0,21).int(), torch.linspace(0,60,61)[1:].int() 92 | self.ts_nus_past, self.ts_nus_fut = torch.linspace(-20,0,5).int(), torch.linspace(0,60,13)[1:].int() 93 | self.ts_Argo_past, self.ts_Argo_fut = torch.linspace(-20,0,21)[1:].int(), torch.linspace(0,30,31)[1:].int() 94 | 95 | self.mask_nus_past, self.mask_nus_fut = grid_past.clone(), grid_fut.clone() 96 | self.mask_nus_past[torch.isin(self.ts_pasts, self.ts_nus_past)] = True 97 | self.mask_nus_fut[torch.isin(self.ts_futs, self.ts_nus_fut)] = True 98 | self.mask_nus_tot = torch.cat((self.mask_nus_past, self.mask_nus_fut)) 99 | 100 | self.mask_Argo_past, self.mask_Argo_fut = grid_past.clone(), grid_fut.clone() 101 | self.mask_Argo_past[torch.isin(self.ts_pasts, self.ts_Argo_past)] = True 102 | self.mask_Argo_fut[torch.isin(self.ts_futs, self.ts_Argo_fut)] = True 103 | self.mask_Argo_tot = torch.cat((self.mask_Argo_past, self.mask_Argo_fut)) 104 | 105 | self.ts_nus_past, self.ts_nus_fut = self.ts_nus_past.float(), self.ts_nus_fut.float() 106 | self.ts_nus_tot = torch.cat((self.ts_nus_past,self.ts_nus_fut)) 107 | self.ts_Argo_past, self.ts_Argo_fut = self.ts_Argo_past.float(), self.ts_Argo_fut.float() 108 | self.ts_Argo_tot = torch.cat((self.ts_Argo_past, self.ts_Argo_fut)) 109 | 110 | elif self.type == 'continuous': 111 | self.nus_ts_past, self.nus_ts_fut = torch.linspace(-2,0,5), torch.linspace(0.5,6,12) 112 | self.nus_ts_tot = torch.cat((self.nus_ts_past, self.nus_ts_fut)) 113 | self.Argo_ts_past, self.Argo_ts_fut = torch.linspace(-1.9,0,20), torch.linspace(0.1,3,30) 114 | self.Argo_ts_tot = torch.cat((self.Argo_ts_past, self.Argo_ts_fut)) 115 | self.max_past, self.max_fut = 20, 30 116 | 117 | super(nuArgoDataset, self).__init__() 118 | 119 | def _download(self): 120 | return 121 | 122 | def _process(self): 123 | return 124 | 125 | @property 126 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 127 | return self._raw_file_names 128 | 129 | @property 130 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 131 | return self._processed_file_names 132 | 133 | @property 134 | def processed_paths(self) -> List[str]: 135 | return self._processed_paths 136 | 137 | def len(self) -> int: 138 | return len(self._raw_file_names) 139 | 140 | def get(self, idx) -> Data: 141 | data = torch.load(self.processed_paths[idx]) 142 | 143 | data['source'] = self._sources[idx] 144 | data['seq_id'] = str(data['seq_id']) 145 | 146 | if 'traffic_controls' in data.keys: del data.traffic_controls 147 | if 'turn_directions' in data.keys: del data.turn_directions 
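        # The remaining `del`s below continue stripping preprocessed attributes that this
        # pipeline presumably does not consume (assumption based on the model code),
        # keeping the collated batches lean.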
148 | if 'is_intersections' in data.keys: del data.is_intersections 149 | if 'city' in data.keys: del data.city 150 | if 'lane_rotate_angles' in data.keys: del data.lane_rotate_angles 151 | if 'lane_edge_index' in data.keys: del data.lane_edge_index 152 | if 'lane_edge_type' in data.keys: del data.lane_edge_type 153 | if 'lane_edge_index2_succ' in data.keys: del data.lane_edge_index2_succ 154 | if 'lane_edge_index2_pred' in data.keys: del data.lane_edge_index2_pred 155 | if 'lane_edge_index2_neigh' in data.keys: del data.lane_edge_index2_neigh 156 | # if isinstance(data.agent_index, torch.Tensor): data.agent_index = data.agent_index.item() 157 | # if isinstance(data.av_index, torch.Tensor): data.av_index = data.av_index.item() 158 | data.agent_index = torch.tensor(data.agent_index, dtype=torch.long) 159 | data.av_index = torch.tensor(data.av_index, dtype=torch.long) 160 | 161 | if data['source'] == 0: 162 | data.x = data.x/5 163 | 164 | if not self.is_gtabs: 165 | y_pad0 = torch.cat((torch.zeros(data.y.size(0),1,2), data.y), dim=1) 166 | data.y = y_pad0[:,1:] - y_pad0[:,:-1] 167 | if data['source'] == 0: 168 | data.y = data.y/5 169 | 170 | if 'category' in data.keys: 171 | data['category'] = data['category'].float().to(data['x'].device) 172 | data.padding_mask[~torch.isin(data['category'], torch.tensor(CATEGORY_INTEREST)), -self.max_fut:] = True 173 | del data.category 174 | 175 | if self.type == 'grid': 176 | _x = torch.zeros((data.x.size(0), self.max_past, data.x.size(-1)), dtype=data.x.dtype) 177 | if data.y is not None: 178 | _y = torch.zeros((data.y.size(0), self.max_fut, data.y.size(-1)), dtype=data.y.dtype) 179 | else: 180 | _y = data.y 181 | _bos_mask = torch.zeros((data.x.size(0), self.max_past), dtype=torch.bool) 182 | _padding_mask = torch.ones((data.x.size(0), self.max_past+self.max_fut), dtype=torch.bool) 183 | _positions = torch.zeros((data.x.size(0), self.max_past+self.max_fut, data.positions.size(-1)), dtype=data.positions.dtype) 184 | 185 | if data['source'] == 0: 186 | past_mask, fut_mask, tot_mask = self.mask_nus_past, self.mask_nus_fut, self.mask_nus_tot 187 | elif data['source'] == 1: 188 | past_mask, fut_mask, tot_mask = self.mask_Argo_past, self.mask_Argo_fut, self.mask_Argo_tot 189 | 190 | _x[:,past_mask] = data.x 191 | if data.y is not None: _y[:,fut_mask] = data.y 192 | _bos_mask[:, past_mask] = data.bos_mask 193 | _padding_mask[:,tot_mask] = data.padding_mask 194 | _positions[:,tot_mask] = data.positions 195 | 196 | data.x, data.y, data.bos_mask, data.padding_mask, data.positions = _x, _y, _bos_mask, _padding_mask, _positions 197 | 198 | elif self.type == 'continuous': 199 | raise NotImplementedError('continuous is not used for now') 200 | if data['source'] == 0: 201 | _x = torch.zeros((data.x.size(0), self.max_past, data.x.size(-1)), dtype=data.x.dtype) 202 | _y = torch.zeros((data.y.size(0), self.max_fut, data.y.size(-1)), dtype=data.y.dtype) 203 | _bos_mask = torch.zeros((data.x.size(0), self.max_past), dtype=torch.bool) 204 | _padding_mask = torch.ones((data.x.size(0), self.max_past+self.max_fut), dtype=torch.bool) 205 | _positions = torch.zeros((data.x.size(0), self.max_past+self.max_fut, data.positions.size(-1)), dtype=data.positions.dtype) 206 | 207 | _x[:,-5:] = data.x 208 | _y[:,:12] = data.y 209 | _bos_mask[:, -5:] = data.bos_mask 210 | _padding_mask[:,19-5+1:19+12+1] = data.padding_mask 211 | _positions[:,19-5+1:19+12+1] = data.positions 212 | 213 | data.x, data.y, data.bos_mask, data.padding_mask, data.positions = _x, _y, _bos_mask, _padding_mask, 
_positions
214 | 
215 |                 _ts_past, _ts_fut = torch.zeros((data.x.size(0), self.max_past), dtype=self.nus_ts_past.dtype), torch.zeros((data.y.size(0), self.max_fut), dtype=self.nus_ts_fut.dtype)
216 |                 _ts_past[:,-5:] = self.nus_ts_past
217 |                 _ts_fut[:,:12] = self.nus_ts_fut
218 |                 data['ts_past'], data['ts_fut'], data['ts_tot'] = _ts_past, _ts_fut, torch.cat((_ts_past, _ts_fut), dim=-1)
219 | 
220 |             elif data['source'] == 1:
221 |                 _ts_past, _ts_fut = torch.zeros((data.x.size(0), self.max_past), dtype=self.nus_ts_past.dtype), torch.zeros((data.x.size(0), self.max_fut), dtype=self.nus_ts_fut.dtype)
222 | 
223 |                 _ts_past[:] = self.Argo_ts_past
224 |                 _ts_fut[:] = self.Argo_ts_fut
225 |                 data['ts_past'], data['ts_fut'], data['ts_tot'] = _ts_past, _ts_fut, torch.cat((_ts_past, _ts_fut), dim=-1)
226 |             else:
227 |                 raise KeyError('source should be nuScenes(0) or Argoverse(1)')
228 | 
229 |         if self._split == 'train':
230 |             data = self.augment(data)
231 | 
232 |         return data
233 | 
234 |     def augment(self, data):
235 |         if self.random_flip:
236 |             if random.choice([0, 1]):
237 |                 data.x = data.x * torch.tensor([-1,1])
238 |                 data.y = data.y * torch.tensor([-1,1])
239 |                 data.positions = data.positions * torch.tensor([-1,1])
240 |                 theta_x = torch.cos(data.theta)
241 |                 theta_y = torch.sin(data.theta)
242 |                 data.theta = torch.atan2(theta_y, -1*theta_x)
243 |                 angle_x = torch.cos(data.rotate_angles)
244 |                 angle_y = torch.sin(data.rotate_angles)
245 |                 data.rotate_angles = torch.atan2(angle_y, -1*angle_x)
246 |                 data.lane_positions = data.lane_positions * torch.tensor([-1,1])
247 |                 data.lane_vectors = data.lane_vectors * torch.tensor([-1,1])
248 |                 data.lane_actor_vectors = data.lane_actor_vectors * torch.tensor([-1,1])
249 |             if random.choice([0, 1]):
250 |                 data.x = data.x * torch.tensor([1,-1])
251 |                 data.y = data.y * torch.tensor([1,-1])
252 |                 data.positions = data.positions * torch.tensor([1,-1])
253 |                 theta_x = torch.cos(data.theta)
254 |                 theta_y = torch.sin(data.theta)
255 |                 data.theta = torch.atan2(-1*theta_y, theta_x)
256 |                 angle_x = torch.cos(data.rotate_angles)
257 |                 angle_y = torch.sin(data.rotate_angles)
258 |                 data.rotate_angles = torch.atan2(-1*angle_y, angle_x)
259 |                 data.lane_positions = data.lane_positions * torch.tensor([1,-1])
260 |                 data.lane_vectors = data.lane_vectors * torch.tensor([1,-1])
261 |                 data.lane_actor_vectors = data.lane_actor_vectors * torch.tensor([1,-1])
262 | 
263 |         return data
264 | 
265 |     def _process(self):  # overrides the stub defined earlier in this class; preprocessing must already have run
266 |         if files_exist(self.processed_paths):  # pragma: no cover
267 |             print('Found processed files')
268 |             return
269 |         else:
270 |             raise FileNotFoundError('Both the nuScenes and Argoverse datasets must be preprocessed first')
271 | 
272 | 
273 | if __name__ == '__main__':
274 |     split = 'val'
275 |     spec_args = {'type': 'grid', 'nus': True, 'Argo': True}  # type must be 'grid' or 'continuous': 'grid' for the conventional models, 'continuous' for the Neural DE models
276 | 
277 |     A1D = nuArgoDataset(split, nu_root='data/nuScenes', Argo_root='data/argodataset', nu_dir='preprocessed/nuScenes_frm', Argo_dir='preprocessed/Argoverse_abs', spec_args=spec_args)
278 | 
279 |     from debug_util import viz_data_goal
280 | 
281 |     for data in A1D:
282 |         viz_data_goal(data, 'tmp', 20)
283 | 
--------------------------------------------------------------------------------
/models/encoders/enc_hivt_nusargo_grid.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from 
torch_geometric.data import Batch 7 | from torch_geometric.data import Data 8 | from torch_geometric.nn.conv import MessagePassing 9 | from torch_geometric.typing import Adj 10 | from torch_geometric.typing import OptTensor 11 | from torch_geometric.typing import Size 12 | from torch_geometric.utils import softmax 13 | from torch_geometric.utils import subgraph 14 | 15 | from models.utils.embedding import MultipleInputEmbedding 16 | from models.utils.embedding import SingleInputEmbedding 17 | from models.utils.util import DistanceDropEdge 18 | from models.utils.util import TemporalData 19 | from models.utils.util import init_weights 20 | 21 | 22 | class LocalEncoder(nn.Module): 23 | 24 | def __init__(self, 25 | **kwargs) -> None: 26 | super(LocalEncoder, self).__init__() 27 | 28 | for key, value in kwargs.items(): 29 | setattr(self, key, value) 30 | 31 | self.drop_edge = DistanceDropEdge(self.local_radius) 32 | self.aa_encoder = AAEncoder(historical_steps=self.historical_steps, 33 | node_dim=self.node_dim, 34 | edge_dim=self.edge_dim, 35 | embed_dim=self.embed_dim, 36 | num_heads=self.num_heads, 37 | dropout=self.dropout, 38 | parallel=self.parallel, 39 | input_diff=self.input_diff) 40 | self.temporal_encoder = TemporalEncoder(historical_steps=self.historical_steps, 41 | embed_dim=self.embed_dim, 42 | num_heads=self.num_heads, 43 | dropout=self.dropout, 44 | num_layers=self.num_temporal_layers) 45 | self.al_encoder = ALEncoder(node_dim=self.node_dim, 46 | edge_dim=self.edge_dim, 47 | embed_dim=self.embed_dim, 48 | num_heads=self.num_heads, 49 | dropout=self.dropout) 50 | 51 | self.apply(init_weights) 52 | 53 | def forward(self, data: TemporalData) -> torch.Tensor: 54 | # nus_batches = torch.where(data.source == 0)[0] 55 | # nus_mask = torch.isin(data.batch, nus_batches) 56 | # data.x[nus_mask] = data.x[nus_mask]/5 57 | 58 | for t in range(self.historical_steps): 59 | data[f'edge_index_{t}'], _ = subgraph(subset=~data['padding_mask'][:, t], edge_index=data.edge_index) 60 | data[f'edge_attr_{t}'] = \ 61 | data['positions'][data[f'edge_index_{t}'][0], t] - data['positions'][data[f'edge_index_{t}'][1], t] 62 | if self.parallel: 63 | snapshots = [None] * self.historical_steps 64 | for t in range(self.historical_steps): 65 | edge_index, edge_attr = self.drop_edge(data[f'edge_index_{t}'], data[f'edge_attr_{t}']) 66 | snapshots[t] = Data(x=data.x[:, t], edge_index=edge_index, edge_attr=edge_attr, 67 | num_nodes=data.num_nodes) 68 | batch = Batch.from_data_list(snapshots) 69 | out = self.aa_encoder(x=batch.x, t=None, edge_index=batch.edge_index, edge_attr=batch.edge_attr, 70 | bos_mask=data['bos_mask'], rotate_mat=data['rotate_mat']) 71 | out = out.view(self.historical_steps, out.shape[0] // self.historical_steps, -1) 72 | else: 73 | out = [None] * self.historical_steps 74 | for t in range(self.historical_steps): 75 | edge_index, edge_attr = self.drop_edge(data[f'edge_index_{t}'], data[f'edge_attr_{t}']) 76 | out[t] = self.aa_encoder(x=data.x[:, t], t=t, edge_index=edge_index, edge_attr=edge_attr, 77 | bos_mask=data['bos_mask'][:, t], rotate_mat=data['rotate_mat']) 78 | out = torch.stack(out) # [T, N, D] 79 | out = self.temporal_encoder(x=out, padding_mask=data['padding_mask'][:, : self.historical_steps]) 80 | edge_index, edge_attr = self.drop_edge(data['lane_actor_index'], data['lane_actor_vectors']) 81 | 82 | lane_len = (1-data['lane_paddings']).sum(-1) 83 | lane_start_pos = data['lane_positions'][torch.arange(data['lane_positions'].size(0)),0,:] 84 | lane_end_pos = 
data['lane_positions'][torch.arange(data['lane_positions'].size(0)),(lane_len-1).long(),:] 85 | lane_feat = lane_end_pos - lane_start_pos 86 | # is_intersections, turn_directions, traffic_controls = data['is_intersections'][:,0,0], data['turn_directions'][:,0,0], data['traffic_controls'][:,0,0] 87 | # out = self.al_encoder(x=(lane_feat, out), edge_index=edge_index, edge_attr=edge_attr, 88 | # is_intersections=is_intersections, turn_directions=turn_directions, 89 | # traffic_controls=traffic_controls, rotate_mat=data['rotate_mat']) 90 | out = self.al_encoder(x=(lane_feat, out), edge_index=edge_index, edge_attr=edge_attr, 91 | rotate_mat=data['rotate_mat']) 92 | return out 93 | 94 | 95 | class AAEncoder(MessagePassing): 96 | 97 | def __init__(self, 98 | historical_steps: int, 99 | node_dim: int, 100 | edge_dim: int, 101 | embed_dim: int, 102 | num_heads: int = 8, 103 | dropout: float = 0.1, 104 | parallel: bool = False, 105 | **kwargs) -> None: 106 | super(AAEncoder, self).__init__(aggr='add', node_dim=0, **kwargs) 107 | self.historical_steps = historical_steps 108 | self.embed_dim = embed_dim 109 | self.num_heads = num_heads 110 | self.parallel = parallel 111 | self.input_diff = kwargs['input_diff'] 112 | 113 | self.center_embed = SingleInputEmbedding(in_channel=node_dim, out_channel=embed_dim) 114 | self.nbr_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim], out_channel=embed_dim) 115 | self.lin_q = nn.Linear(embed_dim, embed_dim) 116 | self.lin_k = nn.Linear(embed_dim, embed_dim) 117 | self.lin_v = nn.Linear(embed_dim, embed_dim) 118 | self.lin_self = nn.Linear(embed_dim, embed_dim) 119 | self.attn_drop = nn.Dropout(dropout) 120 | self.lin_ih = nn.Linear(embed_dim, embed_dim) 121 | self.lin_hh = nn.Linear(embed_dim, embed_dim) 122 | self.out_proj = nn.Linear(embed_dim, embed_dim) 123 | self.proj_drop = nn.Dropout(dropout) 124 | self.norm1 = nn.LayerNorm(embed_dim) 125 | self.norm2 = nn.LayerNorm(embed_dim) 126 | self.mlp = nn.Sequential( 127 | nn.Linear(embed_dim, embed_dim * 4), 128 | nn.ReLU(inplace=True), 129 | nn.Dropout(dropout), 130 | nn.Linear(embed_dim * 4, embed_dim), 131 | nn.Dropout(dropout)) 132 | self.bos_token = nn.Parameter(torch.Tensor(historical_steps, embed_dim)) 133 | nn.init.normal_(self.bos_token, mean=0., std=.02) 134 | self.apply(init_weights) 135 | 136 | def forward(self, 137 | x: torch.Tensor, 138 | t: Optional[int], 139 | edge_index: Adj, 140 | edge_attr: torch.Tensor, 141 | bos_mask: torch.Tensor, 142 | rotate_mat: Optional[torch.Tensor] = None, 143 | size: Size = None) -> torch.Tensor: 144 | if self.parallel: 145 | if rotate_mat is None: 146 | center_embed = self.center_embed(x.view(self.historical_steps, x.shape[0] // self.historical_steps, -1)) 147 | else: 148 | center_embed = self.center_embed( 149 | torch.matmul(x.view(self.historical_steps, x.shape[0] // self.historical_steps, -1).unsqueeze(-2), 150 | rotate_mat.expand(self.historical_steps, *rotate_mat.shape)).squeeze(-2)) 151 | 152 | if self.input_diff: 153 | center_embed = torch.where(bos_mask.t().unsqueeze(-1), 154 | self.bos_token.unsqueeze(-2), 155 | center_embed) 156 | center_embed = center_embed.contiguous().view(x.shape[0], -1) 157 | 158 | else: 159 | if rotate_mat is None: 160 | center_embed = self.center_embed(x) 161 | else: 162 | center_embed = self.center_embed(torch.bmm(x.unsqueeze(-2), rotate_mat).squeeze(-2)) 163 | if self.input_diff: center_embed = torch.where(bos_mask.unsqueeze(-1), self.bos_token[t], center_embed) 164 | center_embed = center_embed + 
self._mha_block(self.norm1(center_embed), x, edge_index, edge_attr, rotate_mat, 165 | size) 166 | center_embed = center_embed + self._ff_block(self.norm2(center_embed)) 167 | return center_embed 168 | 169 | def message(self, 170 | edge_index: Adj, 171 | center_embed_i: torch.Tensor, 172 | x_j: torch.Tensor, 173 | edge_attr: torch.Tensor, 174 | rotate_mat: Optional[torch.Tensor], 175 | index: torch.Tensor, 176 | ptr: OptTensor, 177 | size_i: Optional[int]) -> torch.Tensor: 178 | if rotate_mat is None: 179 | nbr_embed = self.nbr_embed([x_j, edge_attr]) 180 | else: 181 | if self.parallel: 182 | center_rotate_mat = rotate_mat.repeat(self.historical_steps, 1, 1)[edge_index[1]] 183 | else: 184 | center_rotate_mat = rotate_mat[edge_index[1]] 185 | nbr_embed = self.nbr_embed([torch.bmm(x_j.unsqueeze(-2), center_rotate_mat).squeeze(-2), 186 | torch.bmm(edge_attr.unsqueeze(-2), center_rotate_mat).squeeze(-2)]) 187 | query = self.lin_q(center_embed_i).view(-1, self.num_heads, self.embed_dim // self.num_heads) 188 | key = self.lin_k(nbr_embed).view(-1, self.num_heads, self.embed_dim // self.num_heads) 189 | value = self.lin_v(nbr_embed).view(-1, self.num_heads, self.embed_dim // self.num_heads) 190 | scale = (self.embed_dim // self.num_heads) ** 0.5 191 | alpha = (query * key).sum(dim=-1) / scale 192 | alpha = softmax(alpha, index, ptr, size_i) 193 | alpha = self.attn_drop(alpha) 194 | return value * alpha.unsqueeze(-1) 195 | 196 | def update(self, 197 | inputs: torch.Tensor, 198 | center_embed: torch.Tensor) -> torch.Tensor: 199 | inputs = inputs.view(-1, self.embed_dim) 200 | gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(center_embed)) 201 | return inputs + gate * (self.lin_self(center_embed) - inputs) 202 | 203 | def _mha_block(self, 204 | center_embed: torch.Tensor, 205 | x: torch.Tensor, 206 | edge_index: Adj, 207 | edge_attr: torch.Tensor, 208 | rotate_mat: Optional[torch.Tensor], 209 | size: Size) -> torch.Tensor: 210 | center_embed = self.out_proj(self.propagate(edge_index=edge_index, x=x, center_embed=center_embed, 211 | edge_attr=edge_attr, rotate_mat=rotate_mat, size=size)) 212 | return self.proj_drop(center_embed) 213 | 214 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor: 215 | return self.mlp(x) 216 | 217 | 218 | class TemporalEncoder(nn.Module): 219 | 220 | def __init__(self, 221 | historical_steps: int, 222 | embed_dim: int, 223 | num_heads: int = 8, 224 | num_layers: int = 4, 225 | dropout: float = 0.1) -> None: 226 | super(TemporalEncoder, self).__init__() 227 | encoder_layer = TemporalEncoderLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout) 228 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers, 229 | norm=nn.LayerNorm(embed_dim)) 230 | self.padding_token = nn.Parameter(torch.Tensor(historical_steps, 1, embed_dim)) 231 | self.cls_token = nn.Parameter(torch.Tensor(1, 1, embed_dim)) 232 | self.pos_embed = nn.Parameter(torch.Tensor(historical_steps + 1, 1, embed_dim)) 233 | attn_mask = self.generate_square_subsequent_mask(historical_steps + 1) 234 | self.register_buffer('attn_mask', attn_mask) 235 | nn.init.normal_(self.padding_token, mean=0., std=.02) 236 | nn.init.normal_(self.cls_token, mean=0., std=.02) 237 | nn.init.normal_(self.pos_embed, mean=0., std=.02) 238 | self.apply(init_weights) 239 | 240 | def forward(self, 241 | x: torch.Tensor, 242 | padding_mask: torch.Tensor) -> torch.Tensor: 243 | x = torch.where(padding_mask.t().unsqueeze(-1), self.padding_token, x) 244 | expand_cls_token = 
self.cls_token.expand(-1, x.shape[1], -1) 245 | x = torch.cat((x, expand_cls_token), dim=0) 246 | x = x + self.pos_embed 247 | out = self.transformer_encoder(src=x, mask=self.attn_mask, src_key_padding_mask=None) 248 | return out[-1] # [N, D] 249 | 250 | @staticmethod 251 | def generate_square_subsequent_mask(seq_len: int) -> torch.Tensor: 252 | mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1) 253 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 254 | return mask 255 | 256 | 257 | class TemporalEncoderLayer(nn.Module): 258 | 259 | def __init__(self, 260 | embed_dim: int, 261 | num_heads: int = 8, 262 | dropout: float = 0.1) -> None: 263 | super(TemporalEncoderLayer, self).__init__() 264 | self.self_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout) 265 | self.linear1 = nn.Linear(embed_dim, embed_dim * 4) 266 | self.dropout = nn.Dropout(dropout) 267 | self.linear2 = nn.Linear(embed_dim * 4, embed_dim) 268 | self.norm1 = nn.LayerNorm(embed_dim) 269 | self.norm2 = nn.LayerNorm(embed_dim) 270 | self.dropout1 = nn.Dropout(dropout) 271 | self.dropout2 = nn.Dropout(dropout) 272 | 273 | def forward(self, 274 | src: torch.Tensor, 275 | src_mask: Optional[torch.Tensor] = None, 276 | src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: 277 | x = src 278 | x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) 279 | x = x + self._ff_block(self.norm2(x)) 280 | return x 281 | 282 | def _sa_block(self, 283 | x: torch.Tensor, 284 | attn_mask: Optional[torch.Tensor], 285 | key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor: 286 | x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0] 287 | return self.dropout1(x) 288 | 289 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor: 290 | x = self.linear2(self.dropout(F.relu_(self.linear1(x)))) 291 | return self.dropout2(x) 292 | 293 | 294 | class ALEncoder(MessagePassing): 295 | 296 | def __init__(self, 297 | node_dim: int, 298 | edge_dim: int, 299 | embed_dim: int, 300 | num_heads: int = 8, 301 | dropout: float = 0.1, 302 | **kwargs) -> None: 303 | super(ALEncoder, self).__init__(aggr='add', node_dim=0, **kwargs) 304 | self.embed_dim = embed_dim 305 | self.num_heads = num_heads 306 | 307 | self.lane_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim], out_channel=embed_dim) 308 | self.lin_q = nn.Linear(embed_dim, embed_dim) 309 | self.lin_k = nn.Linear(embed_dim, embed_dim) 310 | self.lin_v = nn.Linear(embed_dim, embed_dim) 311 | self.lin_self = nn.Linear(embed_dim, embed_dim) 312 | self.attn_drop = nn.Dropout(dropout) 313 | self.lin_ih = nn.Linear(embed_dim, embed_dim) 314 | self.lin_hh = nn.Linear(embed_dim, embed_dim) 315 | self.out_proj = nn.Linear(embed_dim, embed_dim) 316 | self.proj_drop = nn.Dropout(dropout) 317 | self.norm1 = nn.LayerNorm(embed_dim) 318 | self.norm2 = nn.LayerNorm(embed_dim) 319 | self.mlp = nn.Sequential( 320 | nn.Linear(embed_dim, embed_dim * 4), 321 | nn.ReLU(inplace=True), 322 | nn.Dropout(dropout), 323 | nn.Linear(embed_dim * 4, embed_dim), 324 | nn.Dropout(dropout)) 325 | self.is_intersection_embed = nn.Parameter(torch.Tensor(2, embed_dim)) 326 | self.turn_direction_embed = nn.Parameter(torch.Tensor(3, embed_dim)) 327 | self.traffic_control_embed = nn.Parameter(torch.Tensor(2, embed_dim)) 328 | nn.init.normal_(self.is_intersection_embed, mean=0., std=.02) 329 | nn.init.normal_(self.turn_direction_embed, mean=0., std=.02) 
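        # NOTE: these lane-semantic embeddings (intersection / turn direction / traffic control)
        # are kept from the original HiVT ALEncoder but receive no input in this mixed-domain
        # variant; the loader deletes those attributes and the corresponding message() arguments
        # are commented out below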
330 | nn.init.normal_(self.traffic_control_embed, mean=0., std=.02) 331 | self.apply(init_weights) 332 | 333 | def forward(self, 334 | x: Tuple[torch.Tensor, torch.Tensor], 335 | edge_index: Adj, 336 | edge_attr: torch.Tensor, 337 | is_intersections: torch.Tensor = None, 338 | turn_directions: torch.Tensor = None, 339 | traffic_controls: torch.Tensor = None, 340 | rotate_mat: Optional[torch.Tensor] = None, 341 | size: Size = None) -> torch.Tensor: 342 | x_lane, x_actor = x 343 | 344 | x_actor = x_actor + self._mha_block(self.norm1(x_actor), x_lane, edge_index, edge_attr, rotate_mat, size) 345 | x_actor = x_actor + self._ff_block(self.norm2(x_actor)) 346 | return x_actor 347 | 348 | def message(self, 349 | edge_index: Adj, 350 | x_i: torch.Tensor, 351 | x_j: torch.Tensor, 352 | edge_attr: torch.Tensor, 353 | # is_intersections_j, 354 | # turn_directions_j, 355 | # traffic_controls_j, 356 | rotate_mat: Optional[torch.Tensor], 357 | index: torch.Tensor, 358 | ptr: OptTensor, 359 | size_i: Optional[int]) -> torch.Tensor: 360 | if rotate_mat is None: 361 | x_j = self.lane_embed([x_j, edge_attr]) 362 | else: 363 | rotate_mat = rotate_mat[edge_index[1]] 364 | x_j = self.lane_embed([torch.bmm(x_j.unsqueeze(-2), rotate_mat).squeeze(-2), 365 | torch.bmm(edge_attr.unsqueeze(-2), rotate_mat).squeeze(-2)]) 366 | query = self.lin_q(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads) 367 | key = self.lin_k(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads) 368 | value = self.lin_v(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads) 369 | scale = (self.embed_dim // self.num_heads) ** 0.5 370 | alpha = (query * key).sum(dim=-1) / scale 371 | alpha = softmax(alpha, index, ptr, size_i) 372 | alpha = self.attn_drop(alpha) 373 | return value * alpha.unsqueeze(-1) 374 | 375 | def update(self, 376 | inputs: torch.Tensor, 377 | x: torch.Tensor) -> torch.Tensor: 378 | x_actor = x[1] 379 | inputs = inputs.view(-1, self.embed_dim) 380 | gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x_actor)) 381 | return inputs + gate * (self.lin_self(x_actor) - inputs) 382 | 383 | def _mha_block(self, 384 | x_actor: torch.Tensor, 385 | x_lane: torch.Tensor, 386 | edge_index: Adj, 387 | edge_attr: torch.Tensor, 388 | # is_intersections: torch.Tensor, 389 | # turn_directions: torch.Tensor, 390 | # traffic_controls: torch.Tensor, 391 | rotate_mat: Optional[torch.Tensor], 392 | size: Size) -> torch.Tensor: 393 | x_actor = self.out_proj(self.propagate(edge_index=edge_index, x=(x_lane, x_actor), edge_attr=edge_attr, 394 | rotate_mat=rotate_mat, size=size)) 395 | return self.proj_drop(x_actor) 396 | 397 | def _ff_block(self, x_actor: torch.Tensor) -> torch.Tensor: 398 | return self.mlp(x_actor) -------------------------------------------------------------------------------- /dataset/Argoverse/Argoverse_abs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from itertools import permutations 4 | from itertools import product 5 | from typing import Callable, Dict, List, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from pyquaternion import Quaternion 10 | import torch 11 | 12 | from multiprocessing import Process 13 | from multiprocessing import Pool 14 | from itertools import repeat 15 | 16 | from scipy.spatial.distance import cdist 17 | from shapely.geometry import Point 18 | from shapely.geometry.polygon import Polygon 19 | 20 | from torch_geometric.data import Data 21 | from 
torch_geometric.data import Dataset 22 | from torch_geometric.data.dataset import files_exist 23 | from tqdm import tqdm 24 | import pickle as pkl 25 | from shapely.geometry import LineString, Point 26 | import random 27 | 28 | import sys 29 | sys.path.append('/mnt/ssd2/frm_lightning_backup/') 30 | from models.utils.util import TemporalData 31 | from debug_util import * 32 | 33 | SPLIT_NAME = {'nuScenes': {'train': 'train', 'val': 'train_val', 'test': 'val', 'mini_train': 'mini_train', 'mini_val': 'mini_val'}, 34 | 'Argoverse': {'train': 'train', 'val': 'val', 'test': 'test_obs', 'sample': 'forecasting_sample'}} 35 | 36 | 37 | class ArgoverseDataset(Dataset): 38 | 39 | def __init__(self, 40 | split: str, 41 | root: str, 42 | process_dir: str, 43 | transform: Optional[Callable] = None, 44 | local_radius: float = 50, 45 | spec_args: Dict = None) -> None: 46 | self._split = split 47 | self._local_radius = local_radius 48 | for k,v in spec_args.items(): 49 | self.__setattr__(k, v) 50 | 51 | self._directory = SPLIT_NAME[self.dataset][split] 52 | self.root = root 53 | self.process_dir = process_dir 54 | 55 | raw_file_names_fn = os.path.join(self.process_dir, f'raw_processed_fns_{split}.pt') 56 | 57 | if os.path.isfile(raw_file_names_fn): 58 | with open(file=raw_file_names_fn, mode='rb') as f: 59 | self._raw_file_names, self._processed_file_names, self._processed_paths = pkl.load(f) 60 | else: 61 | self._raw_file_names = os.listdir(self.raw_dir) 62 | self._processed_file_names, self._processed_paths = [], [] 63 | for f in self._raw_file_names: 64 | processed_file_name = os.path.splitext(f)[0] + '.pt' 65 | processed_path = os.path.join(self.processed_dir, processed_file_name) 66 | self._processed_file_names.append(processed_file_name) 67 | self._processed_paths.append(processed_path) 68 | 69 | os.makedirs(self.process_dir, exist_ok=True) 70 | with open(file=raw_file_names_fn, mode='wb') as f: 71 | pkl.dump([self._raw_file_names, self._processed_file_names, self._processed_paths], f) 72 | 73 | super(ArgoverseDataset, self).__init__(root, transform=transform) 74 | 75 | def _download(self): 76 | return 77 | 78 | @property 79 | def raw_dir(self) -> str: 80 | return os.path.join(self.root, self._directory, 'data') 81 | 82 | @property 83 | def processed_dir(self) -> str: 84 | return os.path.join(self.process_dir, self._directory) 85 | 86 | # @property 87 | # def raw_file_names(self) -> Union[str, List[str], Tuple]: 88 | # return self._raw_file_names 89 | 90 | @property 91 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 92 | return self._processed_file_names 93 | 94 | @property 95 | def processed_paths(self) -> List[str]: 96 | return self._processed_paths 97 | 98 | def len(self) -> int: 99 | return len(self._raw_file_names) 100 | 101 | def get(self, idx) -> Data: 102 | data = torch.load(self.processed_paths[idx]) 103 | if self._split == 'train': 104 | data = self.augment(data) 105 | return data 106 | 107 | def augment(self, data): 108 | if self.random_flip: 109 | if random.choice([0, 1]): 110 | data.x = data.x * torch.tensor([-1,1]) 111 | data.y = data.y * torch.tensor([-1,1]) 112 | data.positions = data.positions * torch.tensor([-1,1]) 113 | theta_x = torch.cos(data.theta) 114 | theta_y = torch.sin(data.theta) 115 | data.theta = torch.atan2(theta_y, -1*theta_x) 116 | angle_x = torch.cos(data.rotate_angles) 117 | angle_y = torch.sin(data.rotate_angles) 118 | data.rotate_angles = torch.atan2(angle_y, -1*angle_x) 119 | # lane_angle_x = torch.cos(data.lane_rotate_angles) 120 | # 
lane_angle_y = torch.sin(data.lane_rotate_angles) 121 | # data.lane_rotate_angles = torch.atan2(lane_angle_y, -1*lane_angle_x) 122 | data.lane_positions = data.lane_positions * torch.tensor([-1,1]) 123 | data.lane_vectors = data.lane_vectors * torch.tensor([-1,1]) 124 | data.lane_actor_vectors = data.lane_actor_vectors * torch.tensor([-1,1]) 125 | if random.choice([0, 1]): 126 | data.x = data.x * torch.tensor([1,-1]) 127 | data.y = data.y * torch.tensor([1,-1]) 128 | data.positions = data.positions * torch.tensor([1,-1]) 129 | theta_x = torch.cos(data.theta) 130 | theta_y = torch.sin(data.theta) 131 | data.theta = torch.atan2(-1*theta_y, theta_x) 132 | angle_x = torch.cos(data.rotate_angles) 133 | angle_y = torch.sin(data.rotate_angles) 134 | data.rotate_angles = torch.atan2(-1*angle_y, angle_x) 135 | # lane_angle_x = torch.cos(data.lane_rotate_angles) 136 | # lane_angle_y = torch.sin(data.lane_rotate_angles) 137 | # data.lane_rotate_angles = torch.atan2(-1*lane_angle_y, lane_angle_x) 138 | data.lane_positions = data.lane_positions * torch.tensor([1,-1]) 139 | data.lane_vectors = data.lane_vectors * torch.tensor([1,-1]) 140 | data.lane_actor_vectors = data.lane_actor_vectors * torch.tensor([1,-1]) 141 | 142 | if self.random_rotate: 143 | rotate_scale = self.random_rotate_deg / 180 * torch.pi 144 | rotate_angle_pert = torch.clamp(torch.normal(torch.zeros_like(data.rotate_angles), torch.ones_like(data.rotate_angles)), -1, 1) * rotate_scale 145 | data.rotate_angles = data.rotate_angles + rotate_angle_pert 146 | data.rotate_angles = data.rotate_angles + 2*torch.pi*(data.rotate_angles < -torch.pi) - 2*torch.pi*(data.rotate_angles > torch.pi) 147 | 148 | 149 | return data 150 | 151 | def _process(self): 152 | 153 | if files_exist(self.processed_paths): # pragma: no cover 154 | print('Found processed files') 155 | return 156 | 157 | print('Processing...', file=sys.stderr) 158 | 159 | os.makedirs(self.processed_dir, exist_ok=True) 160 | self.process() 161 | 162 | print('Done!', file=sys.stderr) 163 | 164 | def process(self) -> None: 165 | 166 | from argoverse.map_representation.map_api import ArgoverseMap 167 | 168 | am = ArgoverseMap() 169 | 170 | self.process_argoverse(self._split, self._raw_file_names, am, self._local_radius) 171 | 172 | def process_argoverse(self, split, raw_file_names: str, am, radius: float) -> Dict: 173 | 174 | for raw_fn in tqdm(raw_file_names): 175 | df = pd.read_csv(os.path.join(self.raw_dir, raw_fn)) 176 | 177 | # filter out actors that are unseen during the historical time steps 178 | timestamps = list(np.sort(df['TIMESTAMP'].unique())) 179 | historical_timestamps = timestamps[: 20] 180 | ref_timestep = [timestamps[19]] 181 | historical_df = df[df['TIMESTAMP'].isin(ref_timestep)] 182 | actor_ids = list(historical_df['TRACK_ID'].unique()) 183 | df = df[df['TRACK_ID'].isin(actor_ids)] 184 | num_nodes = len(actor_ids) 185 | 186 | av_df = df[df['OBJECT_TYPE'] == 'AV'].iloc 187 | av_index = actor_ids.index(av_df[0]['TRACK_ID']) 188 | agent_df = df[df['OBJECT_TYPE'] == 'AGENT'].iloc 189 | agent_index = actor_ids.index(agent_df[0]['TRACK_ID']) 190 | city = df['CITY_NAME'].values[0] 191 | 192 | # make the scene centered at AV 193 | origin = torch.tensor([av_df[19]['X'], av_df[19]['Y']], dtype=torch.float) 194 | av_heading_vector = origin - torch.tensor([av_df[18]['X'], av_df[18]['Y']], dtype=torch.float) 195 | theta = torch.atan2(av_heading_vector[1], av_heading_vector[0]) 196 | rotate_mat = torch.tensor([[torch.cos(theta), -torch.sin(theta)], 197 | [torch.sin(theta), 
torch.cos(theta)]]) 198 | 199 | # initialization 200 | x = torch.zeros(num_nodes, 50, 2, dtype=torch.float) 201 | edge_index = torch.LongTensor(list(permutations(range(num_nodes), 2))).t().contiguous() 202 | padding_mask = torch.ones(num_nodes, 50, dtype=torch.bool) 203 | bos_mask = torch.zeros(num_nodes, 20, dtype=torch.bool) 204 | rotate_angles = torch.zeros(num_nodes, dtype=torch.float) 205 | 206 | for actor_id, actor_df in df.groupby('TRACK_ID'): 207 | node_idx = actor_ids.index(actor_id) 208 | node_steps = [timestamps.index(timestamp) for timestamp in actor_df['TIMESTAMP']] 209 | padding_mask[node_idx, node_steps] = False 210 | if padding_mask[node_idx, 19]: # make no predictions for actors that are unseen at the current time step 211 | padding_mask[node_idx, 20:] = True 212 | xy = torch.from_numpy(np.stack([actor_df['X'].values, actor_df['Y'].values], axis=-1)).float() 213 | x[node_idx, node_steps] = torch.matmul(xy - origin, rotate_mat) 214 | node_historical_steps = list(filter(lambda node_step: node_step < 20, node_steps)) 215 | if len(node_historical_steps) > 1: # calculate the heading of the actor (approximately) 216 | heading_vector = x[node_idx, node_historical_steps[-1]] - x[node_idx, node_historical_steps[-2]] 217 | rotate_angles[node_idx] = torch.atan2(heading_vector[1], heading_vector[0]) 218 | else: # make no predictions for the actor if the number of valid time steps is less than 2 219 | padding_mask[node_idx, 20:] = True 220 | 221 | # bos_mask is True if time step t is valid and time step t-1 is invalid 222 | bos_mask[:, 0] = ~padding_mask[:, 0] 223 | bos_mask[:, 1: 20] = padding_mask[:, : 19] & ~padding_mask[:, 1: 20] 224 | 225 | positions = x.clone() 226 | x[:, 20:] = torch.where((padding_mask[:, 19].unsqueeze(-1) | padding_mask[:, 20:]).unsqueeze(-1), 227 | torch.zeros(num_nodes, 30, 2), 228 | x[:, 20:] - x[:, 19].unsqueeze(-2)) 229 | x[:, 0: 20] = torch.where(padding_mask[:, : 20].unsqueeze(-1), 230 | torch.zeros(num_nodes, 20, 2), 231 | x[:, 0: 20] - x[:, 19].unsqueeze(1)) 232 | 233 | # get lane features at the current time step 234 | df_19 = df[df['TIMESTAMP'] == timestamps[19]] 235 | node_inds_19 = [actor_ids.index(actor_id) for actor_id in df_19['TRACK_ID']] 236 | node_positions_19 = torch.from_numpy(np.stack([df_19['X'].values, df_19['Y'].values], axis=-1)).float() 237 | lane_positions, lane_vectors, lane_lengths, is_intersections, turn_directions, traffic_controls = self.get_lane_features(am, node_inds_19, node_positions_19, origin, rotate_mat, city, radius) 238 | 239 | node_positions_goal = positions[:,-1] 240 | node_diff_goal = positions[:,-1] - positions[:,-2] 241 | node_goal_mask = ~padding_mask[:,-1] 242 | 243 | goal_idcs, has_goal = self.get_goal_lane(node_positions_goal, node_diff_goal, node_goal_mask, lane_positions, lane_vectors) 244 | 245 | node_inds_ref = torch.arange(num_nodes, dtype=torch.float)[~padding_mask[:,self.ref_time]] 246 | lane_positions, lane_vectors, lane_rotate_angles, lane_paddings, is_intersections, turn_directions, traffic_controls, lane_actor_index, lane_actor_vectors, lane_lengths, goal_idcs, has_goal = \ 247 | self.get_lane_tensors(node_inds_ref, node_positions_19, lane_positions, lane_vectors, lane_lengths, is_intersections, turn_directions, traffic_controls, goal_idcs, has_goal, origin, rotate_mat, rotate_angles, radius) 248 | 249 | y = None if split == 'test' else x[:, 20:] 250 | seq_id = os.path.splitext(os.path.basename(raw_fn))[0] 251 | 252 | processed = { 253 | 'x': x[:, : 20], # [N, 20, 2] 254 | 'positions': positions, # 
[N, 50, 2] 255 | 'edge_index': edge_index, # [2, N x N - 1] 256 | 'y': y, # [N, 30, 2] 257 | 'num_nodes': num_nodes, 258 | 'padding_mask': padding_mask, # [N, 50] 259 | 'bos_mask': bos_mask, # [N, 20] 260 | 'rotate_angles': rotate_angles, # [N] 261 | 'lane_positions': lane_positions, # [L, 6, 2] 262 | 'lane_vectors': lane_vectors, # [L, 2] 263 | 'lane_paddings': lane_paddings, # [L,6] 264 | 'lane_lengths': lane_lengths, 265 | 'is_intersections': is_intersections, # [L] 266 | 'turn_directions': turn_directions, # [L] 267 | 'traffic_controls': traffic_controls, # [L] 268 | 'lane_actor_index': lane_actor_index, # [2, E_{A-L}] 269 | 'lane_actor_vectors': lane_actor_vectors, # [E_{A-L}, 2] 270 | 'goal_idcs': goal_idcs, 271 | 'has_goal': has_goal, 272 | 'seq_id': int(seq_id), 273 | 'av_index': av_index, 274 | 'agent_index': agent_index, 275 | 'city': city, 276 | 'origin': origin.unsqueeze(0), 277 | 'theta': theta, 278 | } 279 | 280 | data = TemporalData(**processed) 281 | torch.save(data, os.path.join(self.processed_dir, seq_id+'.pt')) 282 | return 283 | 284 | 285 | def get_lane_features(self, am, 286 | node_inds: List[int], 287 | node_positions: torch.Tensor, 288 | origin: torch.Tensor, 289 | rotate_mat: torch.Tensor, 290 | city: str, 291 | radius: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, 292 | torch.Tensor]: 293 | radius = 80 294 | lane_positions, lane_vectors, lane_lengths, is_intersections, turn_directions, traffic_controls = [], [], [], [], [], [] 295 | lane_ids = set() 296 | for node_position in node_positions: 297 | lane_ids.update(am.get_lane_ids_in_xy_bbox(node_position[0], node_position[1], city, radius)) 298 | node_positions = torch.matmul(node_positions - origin, rotate_mat).float() 299 | 300 | for lane_id in lane_ids: 301 | lane_centerline = torch.from_numpy(am.get_lane_segment_centerline(lane_id, city)[:, : 2]).float() 302 | lane_centerline = torch.matmul(lane_centerline - origin, rotate_mat) 303 | 304 | is_intersection = am.lane_is_in_intersection(lane_id, city) 305 | turn_direction = am.get_lane_turn_direction(lane_id, city) 306 | traffic_control = am.lane_has_traffic_control_measure(lane_id, city) 307 | if turn_direction == 'NONE': 308 | turn_direction = 0 309 | elif turn_direction == 'LEFT': 310 | turn_direction = 1 311 | elif turn_direction == 'RIGHT': 312 | turn_direction = 2 313 | else: 314 | raise ValueError('turn direction is not valid') 315 | 316 | line = LineString(lane_centerline) 317 | total_length = line.length 318 | new_points = [] 319 | for i in range(int(total_length)): 320 | # get the point at the i-th distance along the line 321 | point = line.interpolate(i) 322 | new_points.append([point.x, point.y]) 323 | new_points = torch.tensor(new_points) 324 | 325 | if len(new_points) < 1: 326 | continue 327 | 328 | n_segments = int(np.ceil(len(new_points) / (self.lseg_len+1))) 329 | n_poses = int(np.ceil(len(new_points) / n_segments)) 330 | for n in range(n_segments): 331 | lane_segment = new_points[n * n_poses: (n+1) * n_poses] 332 | count = len(lane_segment) - 1 333 | if count > 0: 334 | lane_positions.append((lane_segment[1:] + lane_segment[:-1])/2) 335 | lane_vectors.append(lane_segment[1:] - lane_segment[:-1]) 336 | lane_lengths.append(count) 337 | is_intersections.append(is_intersection * torch.ones(count, dtype=torch.uint8)) 338 | turn_directions.append(turn_direction * torch.ones(count, dtype=torch.uint8)) 339 | traffic_controls.append(traffic_control * torch.ones(count, dtype=torch.uint8)) 340 | 341 | return 
lane_positions, lane_vectors, lane_lengths, is_intersections, turn_directions, traffic_controls 342 | 343 | def get_goal_lane(self, node_positions_goal, node_diff_goal, node_goal_mask, lane_positions, lane_vectors): 344 | lane_num = len(lane_positions) 345 | 346 | goal_idcs = [] 347 | has_goal = [] 348 | for nidx in range(node_positions_goal.size(0)): 349 | if not node_goal_mask[nidx]: 350 | goal_idcs.append(torch.zeros(lane_num)) 351 | has_goal.append(torch.zeros(lane_num)) 352 | else: 353 | query_pose = node_positions_goal[nidx] 354 | query_diff = node_diff_goal[nidx] 355 | query_angle = torch.atan2(query_diff[1], query_diff[0]) 356 | 357 | dist_vals = [] 358 | angle_diffs = [] 359 | for lidx in range(len(lane_positions)): 360 | lane_poses = lane_positions[lidx] 361 | lane_diffs = lane_vectors[lidx] 362 | 363 | distances = torch.norm(lane_poses - query_pose, p=2, dim=-1) 364 | dist_vals.append(torch.min(distances)) 365 | idx = torch.argmin(distances) 366 | 367 | lane_angle = torch.atan2(lane_diffs[idx,1], lane_diffs[idx,0]) 368 | angle_diff = torch.abs(normalize_angle(query_angle-lane_angle)) 369 | angle_diffs.append(angle_diff) 370 | 371 | if torch.norm(query_diff, p=2, dim=-1) < 0.1: # if difference is too small, angle is within noise. 372 | idcs_yaw = torch.arange(lane_num) 373 | else: 374 | idcs_yaw = torch.where(torch.tensor(angle_diffs) <= self.lseg_angle_thres*torch.pi/180)[0] 375 | idcs_dist = torch.where(torch.tensor(dist_vals) <= self.lseg_dist_thres)[0] 376 | idcs = np.intersect1d(idcs_dist, idcs_yaw) 377 | 378 | if len(idcs) > 0: 379 | assigned_node_id = idcs[int(torch.argmin(torch.tensor(dist_vals)[idcs]))] 380 | goal_ = torch.zeros(lane_num) 381 | goal_[assigned_node_id] = 1. 382 | goal_idcs.append(goal_) 383 | 384 | has_goal_ = torch.zeros(lane_num) 385 | has_goal_[assigned_node_id] = nidx+1 386 | has_goal.append(has_goal_) 387 | else: 388 | goal_idcs.append(torch.zeros(lane_num)) 389 | has_goal.append(torch.zeros(lane_num)) 390 | 391 | return torch.cat(goal_idcs,-1), torch.cat(has_goal,-1) 392 | 393 | def get_lane_tensors(self, node_inds: List[int], 394 | node_positions: torch.Tensor, # global coordi 395 | lane_positions: List[torch.Tensor], # local coordi 396 | lane_vectors: List[torch.Tensor], 397 | lane_lengths: List[int], 398 | is_intersections, turn_directions, traffic_controls, 399 | goal_idcs: torch.Tensor, 400 | has_goal, 401 | origin: torch.Tensor, 402 | rotate_mat: torch.Tensor, 403 | rotate_angles, 404 | radius: float): 405 | 406 | lane_positions_ = torch.zeros(len(lane_positions), self.lseg_len, 2) 407 | lane_vectors_ = torch.zeros(len(lane_positions), self.lseg_len, 2) 408 | lane_padding_ = torch.ones(len(lane_positions), self.lseg_len) 409 | is_intersections_ = torch.zeros(len(lane_positions), self.lseg_len, 1) 410 | turn_directions_ = torch.zeros(len(lane_positions), self.lseg_len, 1) 411 | traffic_controls_ = torch.zeros(len(lane_positions), self.lseg_len, 1) 412 | for lidx in range(len(lane_positions)): 413 | lane_length = lane_lengths[lidx] 414 | lane_positions_[lidx, :lane_length] = lane_positions[lidx] 415 | lane_vectors_[lidx, :lane_length] = lane_vectors[lidx] 416 | lane_padding_[lidx, :lane_length] = 0 417 | is_intersections_[lidx, :lane_length] = is_intersections[lidx].unsqueeze(-1) 418 | turn_directions_[lidx, :lane_length] = turn_directions[lidx].unsqueeze(-1) 419 | traffic_controls_[lidx, :lane_length] = traffic_controls[lidx].unsqueeze(-1) 420 | lane_rotate_angles_ = torch.arctan2(lane_vectors_[:,0,1], lane_vectors_[:,0,0]) 421 | 422 | 
node_positions = torch.matmul(node_positions - origin, rotate_mat)
423 |         lane_actor_index = torch.flip(torch.LongTensor(list(product(node_inds.int(), torch.arange(lane_vectors_.size(0))))).t().contiguous(), [0])
424 |         assert (torch.tensor(lane_lengths) == 0).float().sum() == 0, '0 length lanes are included'
425 |         lane_actor_vectors = \
426 |             lane_positions_[torch.arange(len(lane_positions)),(torch.tensor(lane_lengths)-1).long(),:].repeat(len(node_inds),1) - node_positions.repeat_interleave(lane_vectors_.size(0), dim=0)
427 |         actors_rotate_mat = torch.empty(node_inds.size(0),2,2)
428 |         sin_vals, cos_vals = torch.sin(rotate_angles), torch.cos(rotate_angles)
429 |         actors_rotate_mat[:,0,0] = cos_vals
430 |         actors_rotate_mat[:,0,1] = -sin_vals
431 |         actors_rotate_mat[:,1,0] = sin_vals
432 |         actors_rotate_mat[:,1,1] = cos_vals
433 |         lane_actor_vectors_norm = torch.bmm(lane_actor_vectors.unsqueeze(1), actors_rotate_mat[lane_actor_index[1]]).squeeze(1)
434 |         mask = (-20 < lane_actor_vectors_norm[:, 0])
435 |         # NOTE: only the leading term of this gate survived in the source dump; the remaining
436 |         # range conditions on lane_actor_vectors_norm were lost. The masked selection and the
437 |         # return below are reconstructed from the call-site signature in process_argoverse.
438 |         lane_actor_index = lane_actor_index[:, mask]
439 |         lane_actor_vectors = lane_actor_vectors[mask]
440 | 
441 |         return lane_positions_, lane_vectors_, lane_rotate_angles_, lane_padding_, is_intersections_, turn_directions_, traffic_controls_, lane_actor_index, lane_actor_vectors, lane_lengths, goal_idcs, has_goal
442 | 
443 | 
444 | def normalize_angle(angle):
445 |     while angle < -torch.pi:
446 |         angle += 2*torch.pi
447 |     while angle > torch.pi:
448 |         angle -= 2*torch.pi
449 |     return angle
450 | 
451 | if __name__ == '__main__':
452 |     split = 'val'
453 |     spec_args = {'dataset': 'Argoverse', 'n_jobs': 0, 't_h': 2, 't_f': 3, 'res': 10, 'ref_time':19, 'lseg_len': 10, 'lseg_angle_thres': 30, 'lseg_dist_thres': 2.5}
454 |     A1D = ArgoverseDataset(split, root='data/argodataset', process_dir='preprocessed/Argoverse', spec_args=spec_args)
--------------------------------------------------------------------------------
/models/encoders/enc_hivt_nusargo_sde_sep2.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch_geometric.data import Batch
7 | from torch_geometric.data import Data
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.typing import Adj
10 | from torch_geometric.typing import OptTensor
11 | from torch_geometric.typing import Size
12 | from torch_geometric.utils import softmax
13 | from torch_geometric.utils import subgraph
14 | 
15 | from models.utils.embedding import MultipleInputEmbedding
16 | from models.utils.embedding import SingleInputEmbedding
17 | from models.utils.util import DistanceDropEdge
18 | from models.utils.util import TemporalData
19 | from models.utils.util import init_weights
20 | from models.utils.ode_utils import get_timesteps, GRU_Unit
21 | from models.utils.sde_utils import SDiffeqSolver, SDEFunc
22 | import torchsde
23 | from models.utils.sdeint import sdeint, sdeint_dual
24 | 
25 | class LocalEncoderSDESepPara2(nn.Module):
26 | 
27 |     def __init__(self,
28 |                  **kwargs) -> None:
29 |         super(LocalEncoderSDESepPara2, self).__init__()
30 | 
31 |         for key, value in kwargs.items():
32 |             setattr(self, key, value)
33 | 
34 |         self.drop_edge = DistanceDropEdge(self.local_radius)
35 |         self.aa_encoder = AAEncoder(historical_steps=self.historical_steps,
36 |                                     node_dim=self.node_dim,
37 |                                     edge_dim=self.edge_dim,
38 |                                     embed_dim=self.embed_dim,
39 |                                     num_heads=self.num_heads,
40 |                                     dropout=self.dropout,
41 |                                     parallel=self.parallel)
42 | 
43 |         self.al_encoder = ALEncoder(node_dim=self.node_dim,
44 |                                     edge_dim=self.edge_dim,
45 |                                     embed_dim=self.embed_dim,
46 |                                     num_heads=self.num_heads,
47 |                                     dropout=self.dropout)
48 | 
49 |         self.gru_unit = GRU_Unit(self.embed_dim, self.embed_dim, n_units=self.embed_dim)
50 | 
51 |         sigma, theta, mu = 0.5, 1.0, 0.0
52 |         post_drift = FFunc(self.embed_dim, num_layers=self.sde_layers)
53 |         prior_drift = HFunc(theta=theta, mu=mu)
54 |         diffusion_nus = GFunc(self.embed_dim, num_layers=self.sde_layers, sigma=sigma)
55 |         diffusion_Argo2 = GFunc(self.embed_dim, num_layers=self.sde_layers, sigma=sigma)
56 |         self.lsde_func = LSDEFunc(f=post_drift, g_nus=diffusion_nus, g_Argo2=diffusion_Argo2, h=prior_drift, embed_dim=self.embed_dim)
57 |         self.lsde_func.noise_type, self.lsde_func.sde_type = 'diagonal', 'ito'
58 | 
59 |         self.real_label, self.fake_label = 0, 1
60 | 
61 |         self.hidden = nn.Parameter(torch.Tensor(self.embed_dim))
62 |         nn.init.normal_(self.hidden, mean=0., std=.02)
63 | 
64 |         self.apply(init_weights)
65 | 
66 |     def forward(self, data: TemporalData) -> torch.Tensor:
67 | 
68 |         lane_len = (1-data['lane_paddings']).sum(-1)
69 |         lane_start_pos = data['lane_positions'][torch.arange(data['lane_positions'].size(0)),0,:]
70 |         lane_end_pos = data['lane_positions'][torch.arange(data['lane_positions'].size(0)),(lane_len-1).long(),:]
71 |         lane_feat = lane_end_pos - lane_start_pos
72 | 
73 |         nus_batches = torch.where(data.source == 0)[0]
74 |         nus_mask = torch.isin(data.batch, nus_batches)
75 | 
76 |         actor_num, ts_obs, in_dim = data.x.shape
77 | 
78 |         prev_hidden = self.hidden.unsqueeze(0).repeat(actor_num+len(data['agent_index']),1)
79 | 
80 |         actors_pad = data['padding_mask']
81 |         actors_past_pad, actors_fut_pad = actors_pad[:,:self.ref_time+1], actors_pad[:,self.ref_time+1:]
82 |         actors_past_mask = ~actors_past_pad
83 | 
84 |         ######### Pre-compute AA encoding #########
85 | 
86 |         ## Append noise-perturbed copies of the agents, so the SDE diffusion learns to separate them from the in-distribution originals ##
87 | 
88 |         edge_from = data.edge_index[0][torch.isin(data.edge_index[1], data['agent_index'])]
89 |         edge_to = data.edge_index[1][torch.isin(data.edge_index[1], data['agent_index'])]
90 |         _, new_edge_to = torch.unique(edge_to, return_inverse=True)
91 |         new_edge_ = torch.stack((edge_from,new_edge_to+actor_num), 0)
92 |         new_edge = torch.cat((data.edge_index, new_edge_), -1)
93 | 
94 |         x_agent = data.x[data['agent_index']]
95 |         x_actors = torch.cat((data.x, x_agent+2*torch.randn_like(x_agent)), dim=0)
96 |         actors_pad = torch.cat((actors_pad, actors_pad[data['agent_index']]), dim=0)
97 |         actors_mask = ~actors_pad[:,:self.ref_time+1]
98 |         new_positions = torch.cat((data['positions'], data['positions'][data['agent_index']]), 0)
99 |         new_bos_mask = torch.cat((data['bos_mask'], data['bos_mask'][data['agent_index']]), 0)
100 |         new_rotate_mat = torch.cat((data['rotate_mat'], data['rotate_mat'][data['agent_index']]), 0)
101 |         new_agent_index = torch.cat((data['agent_index'], torch.arange(actor_num,actor_num+data['agent_index'].size(0)).to(data['agent_index'].device)))
102 | 
103 |         nus_mask = torch.cat((nus_mask,(data.source == 0)), dim=0)
104 | 
105 |         #####################################
106 | 
107 |         for t in range(self.historical_steps):
108 |             data[f'edge_index_{t}'], _ = subgraph(subset=~actors_pad[:, t], edge_index=new_edge)
109 |             data[f'edge_attr_{t}'] = \
110 |                 new_positions[data[f'edge_index_{t}'][0], t] - new_positions[data[f'edge_index_{t}'][1], t]
111 | 
112 |         if self.parallel:
113 |             snapshots = [None] * self.historical_steps
114 |             for t in range(self.historical_steps):
115 |                 edge_index, edge_attr = self.drop_edge(data[f'edge_index_{t}'], data[f'edge_attr_{t}'])
116 |                 snapshots[t] = Data(x=x_actors[:, t], edge_index=edge_index, edge_attr=edge_attr,
117 |                                     num_nodes=data.num_nodes+len(data['agent_index']))
118 |             batch = Batch.from_data_list(snapshots)
119 |             aa_out = self.aa_encoder(x=batch.x, t=None, edge_index=batch.edge_index, edge_attr=batch.edge_attr,
120 |                                      bos_mask=new_bos_mask, rotate_mat=new_rotate_mat)
121 |             aa_out = aa_out.view(self.historical_steps, aa_out.shape[0] // self.historical_steps, -1)
122 |         else:
123 |             raise NotImplementedError
124 | 
125 |         ###########################################
126 | 
127 | 
128 |         past_time_steps = torch.linspace(-self.max_past_t,0,self.historical_steps)
129 |         past_time_steps = -1*past_time_steps
130 |         time_points_iter = range(0, past_time_steps.size(-1))
131 |         if self.run_backwards:
132 |             prev_t, t_i = past_time_steps[-1] - 0.01, past_time_steps[-1]
133 |             time_points_iter = reversed(time_points_iter)
134 |         else:
135 |             prev_t, t_i = past_time_steps[0] - 0.01, past_time_steps[0]
136 | 
137 |         latent_ys = []
138 |         diffusions = []
139 | 
140 |         for idx, t in enumerate(time_points_iter):
141 | 
142 |             time_points = torch.tensor([prev_t, t_i])
143 | 
144 |             ############ DiffeqSolver unrolled inline ##############
145 |             first_point, time_steps_to_predict = prev_hidden, time_points
146 |             n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
147 | 
148 | 
149 |             pred_y, diff_noise = sdeint_dual(self.lsde_func, first_point, time_steps_to_predict, nus_mask, dt=self.minimum_step,
150 |                                              rtol = self.rtol, atol = self.atol, method = self.method)
151 |             pred_y = pred_y.permute(1,2,0)
152 | 
153 |             assert(pred_y.size()[0] == n_traj_samples)
154 |             assert(pred_y.size()[1] == n_traj)
155 | 
156 |             # ode_sol = pred_y[0].permute(0,2,1)
157 |             ode_sol = pred_y
158 |             ####################################################
159 | 
160 |             if torch.mean(torch.abs(ode_sol[:, :, 0] - prev_hidden)) >= 0.001:
161 |                 print("Error: first point of the SDE solution is not equal to the initial value")
162 |                 print(torch.mean(torch.abs(ode_sol[:, :, 0] - prev_hidden)))
163 |                 exit()
164 | 
165 |             yi_ode = ode_sol[:, :, -1]
166 |             xi = aa_out[t]
167 |             maski = actors_mask[:,t]
168 | 
169 |             yi = self.gru_unit(input_tensor=xi, h_cur=yi_ode, mask=maski).squeeze(0)
170 | 
171 |             diffusion_t = diff_noise[new_agent_index]  # diffusion magnitudes of the original agents and their noise-perturbed copies
172 | 
173 |             # return to iteration
174 |             prev_hidden = yi
175 |             if idx+1 < past_time_steps.size(-1):
176 |                 if self.run_backwards:
177 |                     prev_t, t_i = past_time_steps[t], past_time_steps[t-1]
178 |                 else:
179 |                     prev_t, t_i = past_time_steps[t], past_time_steps[t+1]
180 | 
181 |             latent_ys.append(yi)
182 |             diffusions.append(diffusion_t)
183 | 
184 |         latent_ys = torch.stack(latent_ys)[:,:-len(data['agent_index'])]
185 |         diffusions = torch.stack(diffusions)
186 | 
187 |         eos_idcs = self.ref_time - torch.argmax(data['bos_mask'].float(), dim=1)
188 |         out = latent_ys[eos_idcs, torch.arange(latent_ys.size(1)),:]
189 | 
190 |         agent_eos_idcs = eos_idcs[data['agent_index']]
191 |         diff_out = diffusions[agent_eos_idcs.repeat(2), torch.arange(diffusions.size(1))]
192 | 
193 | 
194 |         diffusions_in, diffusions_out = torch.chunk(diff_out,2,0)
195 |         in_labels = torch.full_like(diffusions_in, self.real_label)
196 |         out_labels = torch.full_like(diffusions_out, self.fake_label)
197 | 
198 |         edge_index, edge_attr = self.drop_edge(data['lane_actor_index'], data['lane_actor_vectors'])
199 |         out = self.al_encoder(x=(lane_feat, out), edge_index=edge_index, edge_attr=edge_attr,
200 |                               rotate_mat=data['rotate_mat'])
201 | 
202 |         return out, diffusions_in, diffusions_out, in_labels, out_labels
203 | 
204 |     def forward_ood(self, data: TemporalData) -> torch.Tensor:
205 | 
206 |         lane_len = (1-data['lane_paddings']).sum(-1)
207 |         lane_start_pos = data['lane_positions'][torch.arange(data['lane_positions'].size(0)),0,:]
208 |         lane_end_pos = data['lane_positions'][torch.arange(data['lane_positions'].size(0)),(lane_len-1).long(),:]
209 |         lane_feat = lane_end_pos - lane_start_pos
210 | 
211 |         nus_batches = torch.where(data.source == 0)[0]
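        # per-actor source mask (True for nuScenes actors): sdeint_dual uses it to route each
        # actor's latent state through its domain-specific diffusion network (g_nus vs. g_Argo2
        # in LSDEFunc.g below)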
212 |         nus_mask = torch.isin(data.batch, nus_batches)
213 | 
214 |         actor_num, ts_obs, in_dim = data.x.shape
215 | 
216 |         actors_pad = data['padding_mask']
217 |         actors_past_pad, actors_fut_pad = actors_pad[:,:self.ref_time+1], actors_pad[:,self.ref_time+1:]
218 |         actors_past_mask = ~actors_past_pad
219 | 
220 |         ######### Pre-compute AA encoding #########
221 | 
222 |         actors_mask = ~actors_pad[:,:self.ref_time+1]
223 | 
224 |         #####################################
225 | 
226 |         for t in range(self.historical_steps):
227 |             data[f'edge_index_{t}'], _ = subgraph(subset=~actors_pad[:, t], edge_index=data.edge_index)
228 |             data[f'edge_attr_{t}'] = \
229 |                 data['positions'][data[f'edge_index_{t}'][0], t] - data['positions'][data[f'edge_index_{t}'][1], t]
230 | 
231 |         if self.parallel:
232 |             snapshots = [None] * self.historical_steps
233 |             for t in range(self.historical_steps):
234 |                 edge_index, edge_attr = self.drop_edge(data[f'edge_index_{t}'], data[f'edge_attr_{t}'])
235 |                 snapshots[t] = Data(x=data.x[:, t], edge_index=edge_index, edge_attr=edge_attr,
236 |                                     num_nodes=data.num_nodes)
237 |             batch = Batch.from_data_list(snapshots)
238 |             aa_out = self.aa_encoder(x=batch.x, t=None, edge_index=batch.edge_index, edge_attr=batch.edge_attr,
239 |                                      bos_mask=data['bos_mask'], rotate_mat=data['rotate_mat'])
240 |             aa_out = aa_out.view(self.historical_steps, aa_out.shape[0] // self.historical_steps, -1)
241 |         else:
242 |             raise NotImplementedError
243 | 
244 |         ###########################################
245 | 
246 |         past_time_steps = torch.linspace(-self.max_past_t,0,self.historical_steps)
247 |         past_time_steps = -1*past_time_steps
248 | 
249 |         # eos_idcs = self.ref_time - torch.argmax(agent_past_mask.float(), dim=1)
250 |         eos_idcs = self.ref_time - torch.argmax(data['bos_mask'].float(), dim=1)
251 | 
252 |         eval_iter = 10  # number of stochastic SDE rollouts used to estimate the epistemic spread
253 | 
254 |         outs = []
255 |         for j in range(eval_iter):
256 | 
257 |             prev_hidden = torch.zeros((actor_num, self.embed_dim), device=data.x.device)
258 |             time_points_iter = range(0, past_time_steps.size(-1))
259 | 
260 |             if self.run_backwards:
261 |                 prev_t, t_i = past_time_steps[-1] - 0.01, past_time_steps[-1]
262 | 
263 |                 time_points_iter = reversed(time_points_iter)
264 |             else:
265 |                 prev_t, t_i = past_time_steps[0] - 0.01, past_time_steps[0]
266 | 
267 |             latent_ys = []
268 |             for idx, t in enumerate(time_points_iter):
269 |                 time_points = torch.tensor([prev_t, t_i])
270 |                 ############ DiffeqSolver unrolled inline ##############
271 |                 first_point, time_steps_to_predict = prev_hidden, time_points
272 |                 n_traj_samples, n_traj = first_point.size()[0], first_point.size()[1]
273 | 
274 |                 pred_y, diff_noise = sdeint_dual(self.lsde_func, first_point, time_steps_to_predict, nus_mask, dt=self.minimum_step,
275 |                                                  rtol = self.rtol, atol = self.atol, method = self.method)
276 |                 pred_y = pred_y.permute(1,2,0)
277 | 
278 |                 assert(pred_y.size()[0] == n_traj_samples)
279 |                 assert(pred_y.size()[1] == n_traj)
280 | 
281 |                 # ode_sol = pred_y[0].permute(0,2,1)
282 |                 ode_sol = pred_y
283 |                 ####################################################
284 | 
285 |                 if torch.mean(torch.abs(ode_sol[:, :, 0] - prev_hidden)) >= 0.001:
286 |                     print("Error: first point of the SDE solution is not equal to the initial value")
287 |                     print(torch.mean(torch.abs(ode_sol[:, :, 0] - prev_hidden)))
288 |                     exit()
289 | 
290 |                 yi_ode = ode_sol[:, :, -1]
291 |                 xi = aa_out[t]
292 |                 maski = actors_mask[:,t]
293 | 
294 |                 yi = self.gru_unit(input_tensor=xi, h_cur=yi_ode, mask=maski).squeeze(0)
295 | 
296 |                 # return to iteration
297 |                 prev_hidden = yi
298 |                 if idx+1 < past_time_steps.size(-1):
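                    # slide the [prev_t, t_i] integration window to the next observation gap
                    # (the index order is reversed when run_backwards is set)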
299 |                     if self.run_backwards:
300 |                         prev_t, t_i = past_time_steps[t], past_time_steps[t-1]
301 |                     else:
302 |                         prev_t, t_i = past_time_steps[t], past_time_steps[t+1]
303 | 
304 |                 latent_ys.append(yi)
305 | 
306 |             latent_ys = torch.stack(latent_ys)
307 | 
308 |             out = latent_ys[eos_idcs, torch.arange(latent_ys.size(1)),:]
309 |             outs.append(out)
310 | 
311 |         outs = torch.stack(outs)
312 |         actors_std = outs.std(0).mean(-1)  # per-actor spread across stochastic rollouts; high values flag out-of-distribution inputs
313 |         out = outs.mean(0)
314 | 
315 |         edge_index, edge_attr = self.drop_edge(data['lane_actor_index'], data['lane_actor_vectors'])
316 |         out = self.al_encoder(x=(lane_feat, out), edge_index=edge_index, edge_attr=edge_attr,
317 |                               rotate_mat=data['rotate_mat'])
318 | 
319 | 
320 |         # # Code for visualizing OOD vs. in-distribution actors
321 |         # num_samples2show = 10
322 | 
323 |         # past_ts = torch.linspace(-2, 0,21)
324 | 
325 |         # ood_actors = data['positions'][torch.topk(actors_std,num_samples2show,0,True)[1]][:,:21]
326 |         # ood_masks = ~data['padding_mask'][torch.topk(actors_std,num_samples2show,0,True)[1]][:,:21]
327 |         # import matplotlib.pyplot as plt
328 |         # for ai, pos in enumerate(ood_actors[:100]):
329 |         #     fig, ax = plt.subplots(1,1,figsize=(5,5))
330 |         #     pos_ = pos[ood_masks[ai]].detach().cpu()
331 |         #     ts_ = past_ts[ood_masks[ai]]
332 |         #     for i_, pos__ in enumerate(pos_):
333 |         #         ax.scatter(pos__[0], pos__[1], c='b')
334 |         #         ax.text(pos__[0], pos__[1], f'{ts_[i_].item():.1f}')
335 |         #     ax.scatter(pos_[-1,0], pos_[-1,1], c='r')
336 |         #     ax.set_aspect('equal')
337 |         #     plt.margins(0.5)
338 |         #     plt.savefig(f'tmp/tmp_ood/{ai}.jpg')
339 |         #     plt.close()
340 | 
341 |         # in_actors = data['positions'][torch.topk(actors_std,num_samples2show,0,False)[1]][:,:21]
342 |         # in_masks = ~data['padding_mask'][torch.topk(actors_std,num_samples2show,0,False)[1]][:,:21]
343 |         # import matplotlib.pyplot as plt
344 |         # for ai, pos in enumerate(in_actors[:100]):
345 |         #     fig, ax = plt.subplots(1,1,figsize=(5,5))
346 |         #     pos_ = pos[in_masks[ai]].detach().cpu()
347 |         #     ts_ = past_ts[in_masks[ai]]
348 |         #     for i_, pos__ in enumerate(pos_):
349 |         #         ax.scatter(pos__[0], pos__[1], c='b')
350 |         #         ax.text(pos__[0], pos__[1], f'{ts_[i_].item():.1f}')
351 |         #     ax.scatter(pos_[-1,0], pos_[-1,1], c='r')
352 |         #     ax.set_aspect('equal')
353 |         #     plt.margins(0.5)
354 |         #     plt.savefig(f'tmp/tmp_in/{ai}.jpg')
355 |         #     plt.close()
356 | 
357 |         # mask = ((actors_std>3.5).float() * (actors_std<4.5).float()).bool()
358 |         # in_actors = data['positions'][mask]
359 |         # in_masks = ~data['padding_mask'][mask]
360 |         # import matplotlib.pyplot as plt
361 |         # for ai, pos in enumerate(in_actors[:100]):
362 |         #     fig, ax = plt.subplots(1,1,figsize=(5,5))
363 |         #     pos_ = pos[in_masks[ai]].detach().cpu()
364 |         #     ax.scatter(pos_[-1,0], pos_[-1,1])
365 |         #     ax.plot(pos_[:,0], pos_[:,1])
366 |         #     ax.set_aspect('equal')
367 |         #     plt.savefig(f'tmp_norm/{ai}.jpg')
368 |         #     plt.close()
369 | 
370 |         return out, actors_std
371 | 
372 | class FFunc(nn.Module):
373 |     """Posterior drift."""
374 |     def __init__(self, embed_dim, num_layers=2):
375 |         super(FFunc, self).__init__()
376 |         net_list = []
377 |         net_list.append(nn.Linear(embed_dim+2, embed_dim))
378 |         for _ in range(num_layers):
379 |             net_list.append(nn.Tanh())
380 |             net_list.append(nn.Linear(embed_dim, embed_dim))
381 |         # self.net = nn.Sequential(
382 |         #     nn.Linear(embed_dim+2, embed_dim),
383 |         #     nn.Tanh(),
384 |         #     nn.Linear(embed_dim, embed_dim),
385 |         #     nn.Tanh(),
386 |         #     nn.Linear(embed_dim, embed_dim)
387 |         # )
388 |         self.net = nn.Sequential(*nn.ModuleList(net_list))
389 | 
390 |     def forward(self, t, y):
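        # the scalar solver time t is tiled per sample and embedded as (sin t, cos t), a
        # transformer-style positional encoding, so the learned drift can be time-inhomogeneous;
        # this matches the embed_dim+2 input width of self.net built in __init__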
391 | # if t.dim() == 0: 392 | # t = float(t) * torch.ones_like(y) 393 | # # Positional encoding in transformers; must use `t`, since the posterior is likely inhomogeneous. 394 | # inp = torch.cat((torch.sin(t), torch.cos(t), y), dim=-1) 395 | _t = torch.ones(y.size(0), 1) * float(t) 396 | _t = _t.to(y) 397 | inp = torch.cat((y,torch.sin(_t), torch.cos(_t)), dim=-1) 398 | return self.net(inp) 399 | 400 | 401 | class HFunc(nn.Module): 402 | """Prior drift""" 403 | def __init__(self, theta=1.0, mu=0.0): 404 | super(HFunc, self).__init__() 405 | self.theta = nn.Parameter(torch.tensor([[theta]]), requires_grad=False) 406 | self.mu = nn.Parameter(torch.tensor([[mu]]), requires_grad=False) 407 | 408 | def forward(self, t, y): 409 | return self.theta * (self.mu - y) 410 | 411 | 412 | class GFunc(nn.Module): 413 | """Diffusion""" 414 | def __init__(self, embed_dim, sigma=0.5, num_layers=2): 415 | super(GFunc, self).__init__() 416 | # self.sigma = nn.Parameter(torch.tensor([[sigma]]), requires_grad=False) 417 | 418 | net_list = [] 419 | net_list.append(nn.Linear(embed_dim+2, embed_dim)) 420 | for _ in range(num_layers-1): 421 | net_list.append(nn.Tanh()) 422 | net_list.append(nn.Linear(embed_dim, embed_dim)) 423 | net_list.append(nn.Tanh()) 424 | net_list.append(nn.Linear(embed_dim, 1)) 425 | 426 | self.net = nn.Sequential(*nn.ModuleList(net_list)) 427 | 428 | # self.net = nn.Sequential( 429 | # nn.Linear(embed_dim+2, embed_dim), 430 | # nn.Tanh(), 431 | # nn.Linear(embed_dim, embed_dim), 432 | # nn.Tanh(), 433 | # nn.Linear(embed_dim, 1) 434 | # ) 435 | 436 | def forward(self, t, y): 437 | _t = torch.ones(y.size(0), 1) * float(t) 438 | _t = _t.to(y) 439 | out = self.net(torch.cat((y, torch.sin(_t), torch.cos(_t)), dim=-1)) 440 | return torch.sigmoid(out) 441 | 442 | class LSDEFunc(torchsde.SDEIto): 443 | def __init__(self, f, g_nus, g_Argo2, h, embed_dim, order=1): 444 | super().__init__(noise_type="diagonal") 445 | self.order, self.intloss, self.sensitivity = order, None, None 446 | self.f_func, self.g_nus, self.g_argo, self.h_func = f, g_nus, g_Argo2, h 447 | self.fnfe, self.gnfe, self.hnfe = 0, 0, 0 448 | self.embed_dim = embed_dim 449 | 450 | 451 | def forward(self, s, x): 452 | pass 453 | 454 | def h(self, s, x): 455 | """ Prior drift 456 | :param s: 457 | :param x: 458 | """ 459 | self.hnfe += 1 460 | return self.h_func(t=s, y=x) 461 | 462 | def f(self, s, x): 463 | """Posterior drift. 464 | :param s: 465 | :param x: 466 | """ 467 | self.fnfe += 1 468 | return self.f_func(t=s, y=x) 469 | 470 | def g(self, s, x, nus_mask): 471 | """Diffusion. Dual diffusion network, separated by nus, argo2 472 | :param s: 473 | :param x: 474 | """ 475 | self.gnfe += 1 476 | # out = self.g_nus(t=s, y=x).repeat(1,self.embed_dim) 477 | out = torch.empty(x.size(0), self.embed_dim, device=x.device) 478 | out_0 = self.g_nus(t=s, y=x[nus_mask]) 479 | out_1 = self.g_argo(t=s, y=x[~nus_mask]) 480 | out[nus_mask,:] = out_0.repeat(1,self.embed_dim) 481 | out[~nus_mask,:] = out_1.repeat(1,self.embed_dim) 482 | return out 483 | 484 | # def g(self, s, x): 485 | # """Diffusion. 
497 | 
498 | class AAEncoder(MessagePassing):
499 | 
500 |     def __init__(self,
501 |                  historical_steps: int,
502 |                  node_dim: int,
503 |                  edge_dim: int,
504 |                  embed_dim: int,
505 |                  num_heads: int = 8,
506 |                  dropout: float = 0.1,
507 |                  parallel: bool = False,
508 |                  **kwargs) -> None:
509 |         super(AAEncoder, self).__init__(aggr='add', node_dim=0, **kwargs)
510 |         self.historical_steps = historical_steps
511 |         self.embed_dim = embed_dim
512 |         self.num_heads = num_heads
513 |         self.parallel = parallel
514 | 
515 |         self.center_embed = SingleInputEmbedding(in_channel=node_dim, out_channel=embed_dim)
516 |         self.nbr_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim], out_channel=embed_dim)
517 |         self.lin_q = nn.Linear(embed_dim, embed_dim)
518 |         self.lin_k = nn.Linear(embed_dim, embed_dim)
519 |         self.lin_v = nn.Linear(embed_dim, embed_dim)
520 |         self.lin_self = nn.Linear(embed_dim, embed_dim)
521 |         self.attn_drop = nn.Dropout(dropout)
522 |         self.lin_ih = nn.Linear(embed_dim, embed_dim)
523 |         self.lin_hh = nn.Linear(embed_dim, embed_dim)
524 |         self.out_proj = nn.Linear(embed_dim, embed_dim)
525 |         self.proj_drop = nn.Dropout(dropout)
526 |         self.norm1 = nn.LayerNorm(embed_dim)
527 |         self.norm2 = nn.LayerNorm(embed_dim)
528 |         self.mlp = nn.Sequential(
529 |             nn.Linear(embed_dim, embed_dim * 4),
530 |             nn.ReLU(inplace=True),
531 |             nn.Dropout(dropout),
532 |             nn.Linear(embed_dim * 4, embed_dim),
533 |             nn.Dropout(dropout))
534 |         self.bos_token = nn.Parameter(torch.Tensor(historical_steps, embed_dim))
535 |         nn.init.normal_(self.bos_token, mean=0., std=.02)
536 |         self.apply(init_weights)
537 | 
538 |     def forward(self,
539 |                 x: torch.Tensor,
540 |                 t: Optional[int],
541 |                 edge_index: Adj,
542 |                 edge_attr: torch.Tensor,
543 |                 bos_mask: torch.Tensor,
544 |                 rotate_mat: Optional[torch.Tensor] = None,
545 |                 size: Size = None) -> torch.Tensor:
546 |         if self.parallel:
547 |             if rotate_mat is None:
548 |                 center_embed = self.center_embed(x.view(self.historical_steps, x.shape[0] // self.historical_steps, -1))
549 |             else:
550 |                 center_embed = self.center_embed(
551 |                     torch.matmul(x.view(self.historical_steps, x.shape[0] // self.historical_steps, -1).unsqueeze(-2),
552 |                                  rotate_mat.expand(self.historical_steps, *rotate_mat.shape)).squeeze(-2))
553 | 
554 |             center_embed = torch.where(bos_mask.t().unsqueeze(-1),
555 |                                        self.bos_token.unsqueeze(-2),
556 |                                        center_embed).contiguous().view(x.shape[0], -1)
557 |         else:
558 |             if rotate_mat is None:
559 |                 center_embed = self.center_embed(x)
560 |             else:
561 |                 center_embed = self.center_embed(torch.bmm(x.unsqueeze(-2), rotate_mat).squeeze(-2))
562 |             center_embed = torch.where(bos_mask.unsqueeze(-1), self.bos_token[t], center_embed)
563 |         center_embed = center_embed + self._mha_block(self.norm1(center_embed), x, edge_index, edge_attr, rotate_mat,
564 |                                                       size)
565 |         center_embed = center_embed + self._ff_block(self.norm2(center_embed))
566 |         return center_embed
567 | 
568 |     def message(self,
569 |                 edge_index: Adj,
570 |                 center_embed_i: torch.Tensor,
571 |                 x_j: torch.Tensor,
572 |                 edge_attr: torch.Tensor,
573 |                 rotate_mat: Optional[torch.Tensor],
574 |                 index:
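# --- Editor's sketch (illustrative only, assuming the embedding utilities
# imported at the top of this file): the shape contract of AAEncoder in its
# non-parallel mode. `x` holds one 2-D motion vector per actor at step `t`,
# `edge_index` lists neighbor -> center pairs, and `rotate_mat` is each
# center actor's 2x2 rotation. All sizes below are hypothetical.
def _demo_aa_encoder(num_actors: int = 4, embed_dim: int = 64):
    import torch
    enc = AAEncoder(historical_steps=20, node_dim=2, edge_dim=2, embed_dim=embed_dim)
    x = torch.randn(num_actors, 2)                         # per-actor motion vectors
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])      # neighbor -> center pairs
    edge_attr = torch.randn(edge_index.size(1), 2)         # relative positions
    bos_mask = torch.zeros(num_actors, dtype=torch.bool)   # True where history begins
    rotate_mat = torch.randn(num_actors, 2, 2)
    return enc(x=x, t=0, edge_index=edge_index, edge_attr=edge_attr,
               bos_mask=bos_mask, rotate_mat=rotate_mat)   # [num_actors, embed_dim]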
torch.Tensor,
575 |                 ptr: OptTensor,
576 |                 size_i: Optional[int]) -> torch.Tensor:
577 |         if rotate_mat is None:
578 |             nbr_embed = self.nbr_embed([x_j, edge_attr])
579 |         else:
580 |             if self.parallel:
581 |                 center_rotate_mat = rotate_mat.repeat(self.historical_steps, 1, 1)[edge_index[1]]
582 |             else:
583 |                 center_rotate_mat = rotate_mat[edge_index[1]]
584 |             nbr_embed = self.nbr_embed([torch.bmm(x_j.unsqueeze(-2), center_rotate_mat).squeeze(-2),
585 |                                         torch.bmm(edge_attr.unsqueeze(-2), center_rotate_mat).squeeze(-2)])
586 |         query = self.lin_q(center_embed_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
587 |         key = self.lin_k(nbr_embed).view(-1, self.num_heads, self.embed_dim // self.num_heads)
588 |         value = self.lin_v(nbr_embed).view(-1, self.num_heads, self.embed_dim // self.num_heads)
589 |         scale = (self.embed_dim // self.num_heads) ** 0.5
590 |         alpha = (query * key).sum(dim=-1) / scale
591 |         alpha = softmax(alpha, index, ptr, size_i)
592 |         alpha = self.attn_drop(alpha)
593 |         return value * alpha.unsqueeze(-1)
594 | 
595 |     def update(self,
596 |                inputs: torch.Tensor,
597 |                center_embed: torch.Tensor) -> torch.Tensor:
598 |         inputs = inputs.view(-1, self.embed_dim)
599 |         gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(center_embed))
600 |         return inputs + gate * (self.lin_self(center_embed) - inputs)
601 | 
602 |     def _mha_block(self,
603 |                    center_embed: torch.Tensor,
604 |                    x: torch.Tensor,
605 |                    edge_index: Adj,
606 |                    edge_attr: torch.Tensor,
607 |                    rotate_mat: Optional[torch.Tensor],
608 |                    size: Size) -> torch.Tensor:
609 |         center_embed = self.out_proj(self.propagate(edge_index=edge_index, x=x, center_embed=center_embed,
610 |                                                     edge_attr=edge_attr, rotate_mat=rotate_mat, size=size))
611 |         return self.proj_drop(center_embed)
612 | 
613 |     def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
614 |         return self.mlp(x)
615 | 
616 | 
617 | class TemporalEncoder(nn.Module):
618 | 
619 |     def __init__(self,
620 |                  historical_steps: int,
621 |                  embed_dim: int,
622 |                  num_heads: int = 8,
623 |                  num_layers: int = 4,
624 |                  dropout: float = 0.1) -> None:
625 |         super(TemporalEncoder, self).__init__()
626 |         encoder_layer = TemporalEncoderLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
627 |         self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers,
628 |                                                          norm=nn.LayerNorm(embed_dim))
629 |         self.padding_token = nn.Parameter(torch.Tensor(historical_steps, 1, embed_dim))
630 |         self.cls_token = nn.Parameter(torch.Tensor(1, 1, embed_dim))
631 |         self.pos_embed = nn.Parameter(torch.Tensor(historical_steps + 1, 1, embed_dim))
632 |         attn_mask = self.generate_square_subsequent_mask(historical_steps + 1)
633 |         self.register_buffer('attn_mask', attn_mask)
634 |         nn.init.normal_(self.padding_token, mean=0., std=.02)
635 |         nn.init.normal_(self.cls_token, mean=0., std=.02)
636 |         nn.init.normal_(self.pos_embed, mean=0., std=.02)
637 |         self.apply(init_weights)
638 | 
639 |     def forward(self,
640 |                 x: torch.Tensor,
641 |                 padding_mask: torch.Tensor) -> torch.Tensor:
642 |         x = torch.where(padding_mask.t().unsqueeze(-1), self.padding_token, x)
643 |         expand_cls_token = self.cls_token.expand(-1, x.shape[1], -1)
644 |         x = torch.cat((x, expand_cls_token), dim=0)
645 |         x = x + self.pos_embed
646 |         out = self.transformer_encoder(src=x, mask=self.attn_mask, src_key_padding_mask=None)
647 |         return out[-1]  # [N, D]: output at the appended CLS position
648 | 
649 |     @staticmethod
650 |     def generate_square_subsequent_mask(seq_len: int) -> torch.Tensor:
651 |         mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1)
652 |         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
653 |         return mask
654 | 
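# --- Editor's note (illustrative): generate_square_subsequent_mask returns an
# additive attention mask, 0.0 on self/past positions and -inf on future ones,
# so each position (including the CLS token) attends causally. For seq_len=3:
#     tensor([[0., -inf, -inf],
#             [0.,   0., -inf],
#             [0.,   0.,   0.]])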
655 | 
656 | class TemporalEncoderLayer(nn.Module):
657 | 
658 |     def __init__(self,
659 |                  embed_dim: int,
660 |                  num_heads: int = 8,
661 |                  dropout: float = 0.1) -> None:
662 |         super(TemporalEncoderLayer, self).__init__()
663 |         self.self_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
664 |         self.linear1 = nn.Linear(embed_dim, embed_dim * 4)
665 |         self.dropout = nn.Dropout(dropout)
666 |         self.linear2 = nn.Linear(embed_dim * 4, embed_dim)
667 |         self.norm1 = nn.LayerNorm(embed_dim)
668 |         self.norm2 = nn.LayerNorm(embed_dim)
669 |         self.dropout1 = nn.Dropout(dropout)
670 |         self.dropout2 = nn.Dropout(dropout)
671 | 
672 |     def forward(self,
673 |                 src: torch.Tensor,
674 |                 src_mask: Optional[torch.Tensor] = None,
675 |                 src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
676 |         x = src
677 |         x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
678 |         x = x + self._ff_block(self.norm2(x))
679 |         return x
680 | 
681 |     def _sa_block(self,
682 |                   x: torch.Tensor,
683 |                   attn_mask: Optional[torch.Tensor],
684 |                   key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
685 |         x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
686 |         return self.dropout1(x)
687 | 
688 |     def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
689 |         x = self.linear2(self.dropout(F.relu_(self.linear1(x))))
690 |         return self.dropout2(x)
691 | 
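# --- Editor's sketch (illustrative only): the shape contract of
# TemporalEncoder. Inputs are time-major [T, N, D] with an [N, T] boolean
# padding mask; a CLS token is appended and its output returned, one vector
# per actor. This assumes the PyTorch version pinned in env.yml; very recent
# versions pass extra kwargs (e.g. `is_causal`) to custom encoder layers.
def _demo_temporal_encoder(historical_steps: int = 20, embed_dim: int = 64):
    import torch
    enc = TemporalEncoder(historical_steps=historical_steps, embed_dim=embed_dim)
    x = torch.randn(historical_steps, 3, embed_dim)                    # [T, N, D]
    padding_mask = torch.zeros(3, historical_steps, dtype=torch.bool)  # [N, T]
    return enc(x=x, padding_mask=padding_mask)                         # [N, D]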
692 | 
693 | class ALEncoder(MessagePassing):
694 | 
695 |     def __init__(self,
696 |                  node_dim: int,
697 |                  edge_dim: int,
698 |                  embed_dim: int,
699 |                  num_heads: int = 8,
700 |                  dropout: float = 0.1,
701 |                  **kwargs) -> None:
702 |         super(ALEncoder, self).__init__(aggr='add', node_dim=0, **kwargs)
703 |         self.embed_dim = embed_dim
704 |         self.num_heads = num_heads
705 | 
706 |         self.lane_embed = MultipleInputEmbedding(in_channels=[node_dim, edge_dim], out_channel=embed_dim)
707 |         self.lin_q = nn.Linear(embed_dim, embed_dim)
708 |         self.lin_k = nn.Linear(embed_dim, embed_dim)
709 |         self.lin_v = nn.Linear(embed_dim, embed_dim)
710 |         self.lin_self = nn.Linear(embed_dim, embed_dim)
711 |         self.attn_drop = nn.Dropout(dropout)
712 |         self.lin_ih = nn.Linear(embed_dim, embed_dim)
713 |         self.lin_hh = nn.Linear(embed_dim, embed_dim)
714 |         self.out_proj = nn.Linear(embed_dim, embed_dim)
715 |         self.proj_drop = nn.Dropout(dropout)
716 |         self.norm1 = nn.LayerNorm(embed_dim)
717 |         self.norm2 = nn.LayerNorm(embed_dim)
718 |         self.mlp = nn.Sequential(
719 |             nn.Linear(embed_dim, embed_dim * 4),
720 |             nn.ReLU(inplace=True),
721 |             nn.Dropout(dropout),
722 |             nn.Linear(embed_dim * 4, embed_dim),
723 |             nn.Dropout(dropout))
724 |         self.is_intersection_embed = nn.Parameter(torch.Tensor(2, embed_dim))  # defined but unused in this variant:
725 |         self.turn_direction_embed = nn.Parameter(torch.Tensor(3, embed_dim))   # the semantic inputs to forward()
726 |         self.traffic_control_embed = nn.Parameter(torch.Tensor(2, embed_dim))  # are accepted but ignored
727 |         nn.init.normal_(self.is_intersection_embed, mean=0., std=.02)
728 |         nn.init.normal_(self.turn_direction_embed, mean=0., std=.02)
729 |         nn.init.normal_(self.traffic_control_embed, mean=0., std=.02)
730 |         self.apply(init_weights)
731 | 
732 |     def forward(self,
733 |                 x: Tuple[torch.Tensor, torch.Tensor],
734 |                 edge_index: Adj,
735 |                 edge_attr: torch.Tensor,
736 |                 is_intersections: Optional[torch.Tensor] = None,
737 |                 turn_directions: Optional[torch.Tensor] = None,
738 |                 traffic_controls: Optional[torch.Tensor] = None,
739 |                 rotate_mat: Optional[torch.Tensor] = None,
740 |                 size: Size = None) -> torch.Tensor:
741 |         x_lane, x_actor = x
742 | 
743 |         x_actor = x_actor + self._mha_block(self.norm1(x_actor), x_lane, edge_index, edge_attr, rotate_mat, size)
744 |         x_actor = x_actor + self._ff_block(self.norm2(x_actor))
745 |         return x_actor
746 | 
747 |     def message(self,
748 |                 edge_index: Adj,
749 |                 x_i: torch.Tensor,
750 |                 x_j: torch.Tensor,
751 |                 edge_attr: torch.Tensor,
755 |                 rotate_mat: Optional[torch.Tensor],
756 |                 index: torch.Tensor,
757 |                 ptr: OptTensor,
758 |                 size_i: Optional[int]) -> torch.Tensor:
759 |         if rotate_mat is None:
760 |             x_j = self.lane_embed([x_j, edge_attr])
761 |         else:
762 |             rotate_mat = rotate_mat[edge_index[1]]
763 |             x_j = self.lane_embed([torch.bmm(x_j.unsqueeze(-2), rotate_mat).squeeze(-2),
764 |                                    torch.bmm(edge_attr.unsqueeze(-2), rotate_mat).squeeze(-2)])
765 |         query = self.lin_q(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
766 |         key = self.lin_k(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
767 |         value = self.lin_v(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
768 |         scale = (self.embed_dim // self.num_heads) ** 0.5
769 |         alpha = (query * key).sum(dim=-1) / scale
770 |         alpha = softmax(alpha, index, ptr, size_i)
771 |         alpha = self.attn_drop(alpha)
772 |         return value * alpha.unsqueeze(-1)
773 | 
774 |     def update(self,
775 |                inputs: torch.Tensor,
776 |                x: torch.Tensor) -> torch.Tensor:
777 |         x_actor = x[1]
778 |         inputs = inputs.view(-1, self.embed_dim)
779 |         gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x_actor))
780 |         return inputs + gate * (self.lin_self(x_actor) - inputs)
781 | 
782 |     def _mha_block(self,
783 |                    x_actor: torch.Tensor,
784 |                    x_lane: torch.Tensor,
785 |                    edge_index: Adj,
786 |                    edge_attr: torch.Tensor,
790 |                    rotate_mat: Optional[torch.Tensor],
791 |                    size: Size) -> torch.Tensor:
792 |         x_actor = self.out_proj(self.propagate(edge_index=edge_index, x=(x_lane, x_actor), edge_attr=edge_attr,
793 |                                                rotate_mat=rotate_mat, size=size))
794 |         return self.proj_drop(x_actor)
795 | 
796 |     def _ff_block(self, x_actor: torch.Tensor) -> torch.Tensor:
797 |         return self.mlp(x_actor)
--------------------------------------------------------------------------------
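The lane-actor encoder above fuses map context into the actor embeddings via bipartite lane-to-actor cross-attention. A minimal sketch of its calling convention, assuming the embedding utilities from models/utils are importable; all sizes and tensor contents below are illustrative:

```python
import torch

# hypothetical sizes; node_dim/edge_dim follow the 2-D map inputs used above
enc = ALEncoder(node_dim=2, edge_dim=2, embed_dim=64, num_heads=8)

x_lane = torch.randn(10, 2)        # raw lane segment vectors
x_actor = torch.randn(4, 64)       # actor embeddings from the SDE/temporal encoder
edge_index = torch.tensor([[0, 1, 2],   # source: lane indices
                           [0, 0, 3]])  # target: actor indices
edge_attr = torch.randn(3, 2)      # lane-to-actor relative positions
rotate_mat = torch.randn(4, 2, 2)  # per-actor rotation matrices

out = enc(x=(x_lane, x_actor), edge_index=edge_index,
          edge_attr=edge_attr, rotate_mat=rotate_mat)  # [4, 64]
```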