├── runner ├── _init_path.py ├── scripts │ ├── dist_test.sh │ └── dist_train.sh ├── utils │ ├── starter │ │ ├── data.py │ │ ├── network.py │ │ └── config_parser.py │ ├── submission.py │ ├── tester.py │ └── eval.py ├── train.py └── cfgs │ └── waymo │ ├── trajflow+100_percent_data.yaml │ └── trajflow+20_percent_data.yaml ├── assets └── trajflow_overview.png ├── setup ├── requirements.txt └── setup_trajflow.py ├── .gitignore ├── trajflow ├── mtr_ops │ ├── attention │ │ ├── __init__.py │ │ ├── src │ │ │ ├── attention_api.cpp │ │ │ ├── attention_func.h │ │ │ ├── attention_func_v2.h │ │ │ ├── attention_value_computation_kernel.cu │ │ │ ├── attention_weight_computation_kernel.cu │ │ │ ├── attention_func.cpp │ │ │ └── attention_func_v2.cpp │ │ ├── attention_utils.py │ │ └── attention_utils_v2.py │ └── knn │ │ ├── src │ │ ├── knn_api.cpp │ │ ├── knn_gpu.h │ │ ├── knn.cpp │ │ └── knn_gpu.cu │ │ └── knn_utils.py ├── models │ ├── __init__.py │ ├── layers │ │ ├── polyline_encoder.py │ │ ├── transformer │ │ │ └── transformer_encoder_layer.py │ │ └── common_layers.py │ ├── denoising_decoder │ │ ├── build_network.py │ │ ├── compute_loss.py │ │ └── decoder_utils.py │ ├── dmt_model.py │ └── context_encoder │ │ └── mtr_encoder.py ├── utils │ ├── init_objective.py │ ├── mtr_loss_utils.py │ ├── denoising_data_rescale.py │ └── motion_utils.py ├── datasets │ ├── __init__.py │ ├── waymo │ │ └── waymo_types.py │ └── dataset.py └── config.py ├── LICENSE └── README.md /runner/_init_path.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../') 3 | -------------------------------------------------------------------------------- /assets/trajflow_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSL-Lab/TrajFlow/HEAD/assets/trajflow_overview.png -------------------------------------------------------------------------------- /setup/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | wandb 3 | easydict 4 | pyyaml 5 | scikit-image 6 | tqdm 7 | einops 8 | scikit-learn 9 | matplotlib 10 | accelerate 11 | einops 12 | ema-pytorch 13 | gitpython 14 | protobuf===3.20.3 15 | typing-extensions===4.5.0 16 | GitPython -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | cuda_toolkits/ 3 | tmp/ 4 | cuda*.run 5 | */version.py 6 | 7 | **__pycache__** 8 | *.pkl 9 | *.so 10 | *.npz 11 | *.stat 12 | output/ 13 | build 14 | data/ 15 | **egg-info** 16 | *.pdf 17 | *.png 18 | *.html 19 | .idea 20 | *sandbox 21 | *sif 22 | helper*.sh 23 | *slurm* -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly copy-paste from https://github.com/dvlab-research/DeepVision3D/blob/master/EQNet/eqnet/ops/attention 3 | """ 4 | 5 | from . import attention_utils 6 | from . import attention_utils_v2 7 | 8 | __all__ = { 9 | 'v1': attention_utils, 10 | 'v2': attention_utils_v2, 11 | } -------------------------------------------------------------------------------- /trajflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from .context_encoder.mtr_encoder import MTREncoder 9 | from .denoising_decoder.denoising_decoder import DenoisingDecoder 10 | 11 | -------------------------------------------------------------------------------- /runner/scripts/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | NGPUS=$1 5 | PY_ARGS=${@:2} 6 | 7 | 8 | while true 9 | do 10 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 11 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 12 | if [ "${status}" != "0" ]; then 13 | break; 14 | fi 15 | done 16 | echo $PORT 17 | 18 | torchrun --nproc_per_node=${NGPUS} --rdzv_endpoint=localhost:${PORT} test.py --launcher pytorch ${PY_ARGS} 19 | -------------------------------------------------------------------------------- /runner/scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | NGPUS=$1 5 | PY_ARGS=${@:2} 6 | 7 | 8 | while true 9 | do 10 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 11 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 12 | if [ "${status}" != "0" ]; then 13 | break; 14 | fi 15 | done 16 | echo $PORT 17 | 18 | torchrun --nproc_per_node=${NGPUS} --rdzv_endpoint=localhost:${PORT} train.py --launcher pytorch ${PY_ARGS} 19 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/knn/src/knn_api.cpp: -------------------------------------------------------------------------------- 1 | // Motion Transformer (MTR): Motion Forecasting Transformer with Global Intention Localization and Local Movement Refinement 2 | // Written by Shaoshuai Shi 3 | // All Rights Reserved 4 | 5 | 6 | #include 7 | #include 8 | 9 | #include "knn_gpu.h" 10 | 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("knn_batch", &knn_batch, "knn_batch"); 14 | m.def("knn_batch_mlogk", &knn_batch_mlogk, "knn_batch_mlogk"); 15 | } 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Qi Yan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /trajflow/utils/init_objective.py: -------------------------------------------------------------------------------- 1 | from .denoising_data_rescale import shift_data_to_normalize, shift_data_to_denormalize 2 | 3 | 4 | def prepare_denoiser_data(batch_dict, data_rescale, device): 5 | """ 6 | Retrieve and prepare the relevant data for the denoising model. 7 | """ 8 | bs = sum(batch_dict['batch_sample_count']) 9 | 10 | # normalize GT trajectory 11 | gt_traj_metric = batch_dict['input_dict']['center_gt_trajs'][..., :2].to(device) # [B, T, 2] 12 | gt_traj_mask = batch_dict['input_dict']['center_gt_trajs_mask'].to(device) # [B, T] 13 | gt_traj_normalized = shift_data_to_normalize(gt_traj_metric, gt_traj_mask.bool(), data_rescale) # [B, T, 2] 14 | 15 | 16 | """update data dict""" 17 | denoiser_dict = { 18 | 'gt_traj_metric': gt_traj_metric, # [B, T, 2] 19 | 'gt_traj_normalized': gt_traj_normalized, # [B, T, 2] 20 | 'gt_traj_mask': gt_traj_mask, # [B, T] 21 | } 22 | batch_dict['denoiser_dict'] = denoiser_dict 23 | 24 | return batch_dict 25 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/knn/src/knn_gpu.h: -------------------------------------------------------------------------------- 1 | // Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | // Published at NeurIPS 2022 3 | // Written by Li Jiang, Shaoshuai Shi 4 | // All Rights Reserved 5 | 6 | 7 | #ifndef KNN_H 8 | #define KNN_H 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | // #include 15 | 16 | 17 | void knn_batch(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k); 18 | void knn_batch_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); 19 | 20 | void knn_batch_mlogk(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k); 21 | void knn_batch_mlogk_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream); 22 | 23 | 24 | #endif 25 | -------------------------------------------------------------------------------- /runner/utils/starter/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from trajflow.datasets import build_dataloader 9 | 10 | 11 | def init_dataloader(cfg, logger): 12 | # Use per-GPU batch size if in distributed mode, otherwise use total batch size. 
13 | if cfg.OPT.DIST_TRAIN: 14 | train_batch_size = cfg.OPT.BATCH_SIZE_PER_GPU 15 | else: 16 | train_batch_size = cfg.OPT.TOTAL_GPUS * cfg.OPT.BATCH_SIZE_PER_GPU 17 | 18 | train_set, train_loader, train_sampler = build_dataloader( 19 | dataset_cfg=cfg.DATA_CONFIG, batch_size=train_batch_size, 20 | dist=cfg.OPT.DIST_TRAIN, workers=cfg.OPT.WORKERS, 21 | logger=logger, training=True, testing=False, inter_pred=False) 22 | 23 | if cfg.OPT.DIST_TRAIN: 24 | test_batch_size = cfg.OPT.BATCH_SIZE_PER_GPU * 4 25 | else: 26 | test_batch_size = train_batch_size * 2 # or adjust as needed for non-DDP 27 | 28 | test_set, test_loader, test_sampler = build_dataloader( 29 | dataset_cfg=cfg.DATA_CONFIG, batch_size=test_batch_size, 30 | dist=cfg.OPT.DIST_TRAIN, workers=cfg.OPT.WORKERS, 31 | logger=logger, training=False, testing=False, inter_pred=False) 32 | 33 | return train_set, train_loader, train_sampler, test_set, test_loader, test_sampler 34 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_func.h" 5 | #include "attention_func_v2.h" 6 | 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("attention_weight_computation_wrapper", &attention_weight_computation_wrapper, 10 | "attention weight computation forward."); 11 | m.def("attention_weight_computation_grad_wrapper", &attention_weight_computation_grad_wrapper, 12 | "attention weight computation backward."); 13 | m.def("attention_value_computation_wrapper", &attention_value_computation_wrapper, 14 | "attention result computation forward."); 15 | m.def("attention_value_computation_grad_wrapper", &attention_value_computation_grad_wrapper, 16 | "attention result computation backward."); 17 | 18 | m.def("attention_weight_computation_wrapper_v2", &attention_weight_computation_wrapper_v2, 19 | "attention weight computation forward."); 20 | m.def("attention_weight_computation_grad_wrapper_v2", &attention_weight_computation_grad_wrapper_v2, 21 | "attention weight computation backward."); 22 | m.def("attention_value_computation_wrapper_v2", &attention_value_computation_wrapper_v2, 23 | "attention result computation forward."); 24 | m.def("attention_value_computation_grad_wrapper_v2", &attention_value_computation_grad_wrapper_v2, 25 | "attention result computation backward."); 26 | } -------------------------------------------------------------------------------- /trajflow/mtr_ops/knn/src/knn.cpp: -------------------------------------------------------------------------------- 1 | // Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | // Published at NeurIPS 2022 3 | // Written by Li Jiang, Shaoshuai Shi 4 | // All Rights Reserved 5 | 6 | 7 | #include "knn_gpu.h" 8 | 9 | // input xyz: (n, 3), float 10 | // input query_xyz: (m, 3), float 11 | // input batch_idxs: (n), int 12 | // input query_batch_offsets: (B + 1), int, offsets[-1] = m 13 | // output idx: (n, k), int 14 | void knn_batch(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ 15 | 16 | const float *query_xyz = query_xyz_tensor.data(); 17 | const float *xyz = xyz_tensor.data(); 18 | const int *batch_idxs = batch_idxs_tensor.data(); 19 | const int *query_batch_offsets = query_batch_offsets_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t 
stream = at::cuda::getCurrentCUDAStream(); 23 | 24 | knn_batch_cuda(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); 25 | } 26 | 27 | 28 | void knn_batch_mlogk(at::Tensor xyz_tensor, at::Tensor query_xyz_tensor, at::Tensor batch_idxs_tensor, at::Tensor query_batch_offsets_tensor, at::Tensor idx_tensor, int n, int m, int k){ 29 | 30 | const float *query_xyz = query_xyz_tensor.data(); 31 | const float *xyz = xyz_tensor.data(); 32 | const int *batch_idxs = batch_idxs_tensor.data(); 33 | const int *query_batch_offsets = query_batch_offsets_tensor.data(); 34 | int *idx = idx_tensor.data(); 35 | 36 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 37 | 38 | knn_batch_mlogk_cuda(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx, stream); 39 | } 40 | -------------------------------------------------------------------------------- /trajflow/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import torch 14 | from torch.utils.data import DataLoader 15 | from trajflow.utils import common_utils 16 | 17 | from .waymo.waymo_dataset import WaymoDataset 18 | 19 | 20 | __all__ = { 21 | 'WaymoDataset': WaymoDataset, 22 | } 23 | 24 | 25 | def build_dataloader(dataset_cfg, batch_size, dist, workers=4, 26 | logger=None, 27 | training=True, testing=False, inter_pred=False): 28 | 29 | dataset = __all__[dataset_cfg.DATASET]( 30 | dataset_cfg=dataset_cfg, 31 | training=training, 32 | testing=testing, 33 | inter_pred=inter_pred, 34 | logger=logger, 35 | ) 36 | 37 | if dist: 38 | if training: 39 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 40 | else: 41 | rank, world_size = common_utils.get_dist_info() 42 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, world_size, rank, shuffle=False) 43 | else: 44 | sampler = None 45 | 46 | drop_last = dataset_cfg.get('DATALOADER_DROP_LAST', False) and training 47 | dataloader = DataLoader( 48 | dataset, batch_size=batch_size, pin_memory=True, num_workers=workers, 49 | shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch, 50 | drop_last=drop_last, sampler=sampler, timeout=0, 51 | ) 52 | 53 | return dataset, dataloader, sampler 54 | -------------------------------------------------------------------------------- /trajflow/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | from pathlib import Path 14 | 15 | import yaml 16 | from easydict import EasyDict 17 | 18 | 19 | def _merge_new_config(config, new_config): 20 | if '_BASE_CONFIG_' in new_config: 21 | with open(new_config['_BASE_CONFIG_'], 'r') as f: 22 | try: 23 | yaml_config = yaml.load(f, Loader=yaml.FullLoader) 24 | except: 25 | yaml_config = yaml.load(f) 26 | config.update(EasyDict(yaml_config)) 27 | 28 | for key, val in new_config.items(): 29 | if not isinstance(val, dict): 30 | config[key] = val 31 | continue 32 | if key not in config: 33 | config[key] = EasyDict() 34 | _merge_new_config(config[key], val) 35 | 36 | return config 37 | 38 | 39 | def log_config_to_file(cfg, pre='cfg', logger=None): 40 | for key, val in cfg.items(): 41 | if isinstance(cfg[key], EasyDict): 42 | logger.info('--- %s.%s = edict() ---' % (pre, key)) 43 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger) 44 | continue 45 | logger.info('%s.%s: %s' % (pre, key, val)) 46 | 47 | 48 | def cfg_from_yaml_file(cfg_file, config): 49 | with open(cfg_file, 'r') as f: 50 | try: 51 | new_config = yaml.load(f, Loader=yaml.FullLoader) 52 | except: 53 | new_config = yaml.load(f) 54 | 55 | _merge_new_config(config=config, new_config=new_config) 56 | 57 | return config 58 | 59 | 60 | def init_cfg(): 61 | cfg = EasyDict() 62 | cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve() 63 | cfg.LOCAL_RANK = 0 64 | return cfg 65 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/knn/knn_utils.py: -------------------------------------------------------------------------------- 1 | # Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | # Published at NeurIPS 2022 3 | # Written by Li Jiang, Shaoshuai Shi 4 | # All Rights Reserved 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Function 10 | 11 | from . 
import knn_cuda 12 | 13 | 14 | class KNNBatch(Function): 15 | @staticmethod 16 | def forward(ctx, xyz, query_xyz, batch_idxs, query_batch_offsets, k): 17 | ''' 18 | :param ctx: 19 | :param xyz: (n, 3) float 20 | :param query_xyz: (m, 3), float 21 | :param batch_idxs: (n) int 22 | :param query_batch_offsets: (B+1) int, offsets[-1] = m 23 | :param k: int 24 | :return: idx (n, k) 25 | ''' 26 | 27 | n = xyz.size(0) 28 | m = query_xyz.size(0) 29 | assert k <= m 30 | assert xyz.is_contiguous() and xyz.is_cuda 31 | assert query_xyz.is_contiguous() and query_xyz.is_cuda 32 | assert batch_idxs.is_contiguous() and batch_idxs.is_cuda 33 | assert query_batch_offsets.is_contiguous() and query_batch_offsets.is_cuda 34 | 35 | idx = torch.cuda.IntTensor(n, k).zero_() 36 | 37 | knn_cuda.knn_batch(xyz, query_xyz, batch_idxs, query_batch_offsets, idx, n, m, k) 38 | 39 | return idx 40 | 41 | @staticmethod 42 | def backward(ctx, a=None): 43 | return None, None, None, None, None 44 | 45 | 46 | knn_batch = KNNBatch.apply 47 | 48 | 49 | class KNNBatchMlogK(Function): 50 | @staticmethod 51 | def forward(ctx, xyz, query_xyz, batch_idxs, query_batch_offsets, k): 52 | ''' 53 | :param ctx: 54 | :param xyz: (n, 3) float 55 | :param query_xyz: (m, 3), float 56 | :param batch_idxs: (n) int 57 | :param query_batch_offsets: (B+1) int, offsets[-1] = m 58 | :param k: int 59 | :return: idx (n, k) 60 | ''' 61 | 62 | n = xyz.size(0) 63 | m = query_xyz.size(0) 64 | # assert k <= m 65 | assert xyz.is_contiguous() and xyz.is_cuda 66 | assert query_xyz.is_contiguous() and query_xyz.is_cuda 67 | assert batch_idxs.is_contiguous() and batch_idxs.is_cuda 68 | assert query_batch_offsets.is_contiguous() and query_batch_offsets.is_cuda 69 | assert k <= 128 70 | idx = torch.cuda.IntTensor(n, k).zero_() 71 | 72 | knn_cuda.knn_batch_mlogk(xyz, query_xyz, batch_idxs, query_batch_offsets, idx, n, m, k) 73 | 74 | return idx 75 | 76 | @staticmethod 77 | def backward(ctx, a=None): 78 | return None, None, None, None, None 79 | 80 | knn_batch_mlogk = KNNBatchMlogK.apply 81 | -------------------------------------------------------------------------------- /trajflow/datasets/waymo/waymo_types.py: -------------------------------------------------------------------------------- 1 | # Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | # Published at NeurIPS 2022 3 | # Written by Shaoshuai Shi 4 | # All Rights Reserved 5 | 6 | 7 | object_type = { 8 | 0: 'TYPE_UNSET', 9 | 1: 'TYPE_VEHICLE', 10 | 2: 'TYPE_PEDESTRIAN', 11 | 3: 'TYPE_CYCLIST', 12 | 4: 'TYPE_OTHER' 13 | } 14 | 15 | lane_type = { 16 | 0: 'TYPE_UNDEFINED', 17 | 1: 'TYPE_FREEWAY', 18 | 2: 'TYPE_SURFACE_STREET', 19 | 3: 'TYPE_BIKE_LANE' 20 | } 21 | 22 | road_line_type = { 23 | 0: 'TYPE_UNKNOWN', 24 | 1: 'TYPE_BROKEN_SINGLE_WHITE', 25 | 2: 'TYPE_SOLID_SINGLE_WHITE', 26 | 3: 'TYPE_SOLID_DOUBLE_WHITE', 27 | 4: 'TYPE_BROKEN_SINGLE_YELLOW', 28 | 5: 'TYPE_BROKEN_DOUBLE_YELLOW', 29 | 6: 'TYPE_SOLID_SINGLE_YELLOW', 30 | 7: 'TYPE_SOLID_DOUBLE_YELLOW', 31 | 8: 'TYPE_PASSING_DOUBLE_YELLOW' 32 | } 33 | 34 | road_edge_type = { 35 | 0: 'TYPE_UNKNOWN', 36 | # // Physical road boundary that doesn't have traffic on the other side (e.g., 37 | # // a curb or the k-rail on the right side of a freeway). 38 | 1: 'TYPE_ROAD_EDGE_BOUNDARY', 39 | # // Physical road boundary that separates the car from other traffic 40 | # // (e.g. a k-rail or an island). 
41 | 2: 'TYPE_ROAD_EDGE_MEDIAN' 42 | } 43 | 44 | polyline_type = { 45 | # for lane 46 | 'TYPE_UNDEFINED': -1, 47 | 'TYPE_FREEWAY': 1, 48 | 'TYPE_SURFACE_STREET': 2, 49 | 'TYPE_BIKE_LANE': 3, 50 | 51 | # for roadline 52 | 'TYPE_UNKNOWN': -1, 53 | 'TYPE_BROKEN_SINGLE_WHITE': 6, 54 | 'TYPE_SOLID_SINGLE_WHITE': 7, 55 | 'TYPE_SOLID_DOUBLE_WHITE': 8, 56 | 'TYPE_BROKEN_SINGLE_YELLOW': 9, 57 | 'TYPE_BROKEN_DOUBLE_YELLOW': 10, 58 | 'TYPE_SOLID_SINGLE_YELLOW': 11, 59 | 'TYPE_SOLID_DOUBLE_YELLOW': 12, 60 | 'TYPE_PASSING_DOUBLE_YELLOW': 13, 61 | 62 | # for roadedge 63 | 'TYPE_ROAD_EDGE_BOUNDARY': 15, 64 | 'TYPE_ROAD_EDGE_MEDIAN': 16, 65 | 66 | # for stopsign 67 | 'TYPE_STOP_SIGN': 17, 68 | 69 | # for crosswalk 70 | 'TYPE_CROSSWALK': 18, 71 | 72 | # for speed bump 73 | 'TYPE_SPEED_BUMP': 19, 74 | 75 | # for driveway 76 | 'TYPE_DRIVEWAY': 20 77 | } 78 | 79 | 80 | signal_state = { 81 | 0: 'LANE_STATE_UNKNOWN', 82 | 83 | # // States for traffic signals with arrows. 84 | 1: 'LANE_STATE_ARROW_STOP', 85 | 2: 'LANE_STATE_ARROW_CAUTION', 86 | 3: 'LANE_STATE_ARROW_GO', 87 | 88 | # // Standard round traffic signals. 89 | 4: 'LANE_STATE_STOP', 90 | 5: 'LANE_STATE_CAUTION', 91 | 6: 'LANE_STATE_GO', 92 | 93 | # // Flashing light signals. 94 | 7: 'LANE_STATE_FLASHING_STOP', 95 | 8: 'LANE_STATE_FLASHING_CAUTION' 96 | } 97 | 98 | signal_state_to_id = {} 99 | for key, val in signal_state.items(): 100 | signal_state_to_id[val] = key -------------------------------------------------------------------------------- /setup/setup_trajflow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import os 14 | import subprocess 15 | 16 | from setuptools import find_packages, setup 17 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 18 | 19 | FILE_PATH = os.path.dirname(os.path.abspath(__file__)) 20 | PROJ_DIR = os.path.dirname(FILE_PATH) 21 | os.chdir(PROJ_DIR) 22 | 23 | 24 | def get_git_commit_number(): 25 | if not os.path.exists('.git'): 26 | return '0000000' 27 | 28 | cmd_out = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE) 29 | git_commit_number = cmd_out.stdout.decode('utf-8')[:7] 30 | return git_commit_number 31 | 32 | 33 | def make_cuda_ext(name, module, sources): 34 | cuda_ext = CUDAExtension( 35 | name='%s.%s' % (module, name), 36 | sources=[os.path.join(*module.split('.'), src) for src in sources] 37 | ) 38 | return cuda_ext 39 | 40 | 41 | def write_version_to_file(version, target_file): 42 | with open(target_file, 'w') as f: 43 | print('__version__ = "%s"' % version, file=f) 44 | 45 | 46 | if __name__ == '__main__': 47 | version = '0.0.0+%s' % get_git_commit_number() 48 | write_version_to_file(version, 'trajflow/version.py') 49 | 50 | setup( 51 | name='TrajFlow', 52 | version=version, 53 | description='TrajFlow: Multi-modal Motion Prediction via Flow Matching', 54 | author='Qi Yan, Brian Zhang, Yutong Zhang, Daniel Yang, Joshua White, Di Chen, Jiachao Liu, Langechuan Liu, Binnan Zhuang, Shaoshuai Shi, Renjie Liao', 55 | license='Apache License 2.0', 56 | packages=find_packages(exclude=['runner', 'data', 'output']), 57 | cmdclass={ 58 | 'build_ext': BuildExtension, 59 | }, 60 | ext_modules=[ 61 | make_cuda_ext( 62 | name='knn_cuda', 63 | module='trajflow.mtr_ops.knn', 64 | sources=[ 65 | 'src/knn.cpp', 66 | 'src/knn_gpu.cu', 67 | 'src/knn_api.cpp', 68 | ], 69 | ), 70 | make_cuda_ext( 71 | name='attention_cuda', 72 | module='trajflow.mtr_ops.attention', 73 | sources=[ 74 | 'src/attention_api.cpp', 75 | 'src/attention_func_v2.cpp', 76 | 'src/attention_func.cpp', 77 | 'src/attention_value_computation_kernel_v2.cu', 78 | 'src/attention_value_computation_kernel.cu', 79 | 'src/attention_weight_computation_kernel_v2.cu', 80 | 'src/attention_weight_computation_kernel.cu', 81 | ], 82 | ), 83 | ], 84 | ) 85 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_func.h: -------------------------------------------------------------------------------- 1 | #ifndef _ATTENTION_FUNC_H 2 | #define _ATTENTION_FUNC_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void attention_weight_computation_launcher( 11 | int b, int total_query_num, int local_size, 12 | int total_key_num, int nhead, int hdim, 13 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 14 | const int *index_pair, 15 | const float *query_features, const float* key_features, 16 | float *output); 17 | 18 | 19 | int attention_weight_computation_wrapper( 20 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 21 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 22 | at::Tensor index_pair, at::Tensor 
query_features, at::Tensor key_features, 23 | at::Tensor output); 24 | 25 | 26 | void attention_weight_computation_grad_launcher( 27 | int b, int total_query_num, int local_size, 28 | int total_key_num, int nhead, int hdim, 29 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 30 | const int *index_pair, 31 | const float *query_features, const float* key_features, 32 | float *grad_out, float* grad_query_features, float* grad_key_features); 33 | 34 | 35 | int attention_weight_computation_grad_wrapper( 36 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 37 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 38 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 39 | at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features); 40 | 41 | 42 | void attention_value_computation_launcher( 43 | int b, int total_query_num, int local_size, 44 | int total_key_num, int nhead, int hdim, 45 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 46 | const int *index_pair, 47 | const float *attn_weight, const float* value_features, 48 | float *output); 49 | 50 | 51 | int attention_value_computation_wrapper( 52 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 53 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 54 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 55 | at::Tensor output); 56 | 57 | 58 | void attention_value_computation_grad_launcher( 59 | int b, int total_query_num, int local_size, 60 | int total_key_num, int nhead, int hdim, 61 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 62 | const int *index_pair, 63 | const float *attn_weight, const float* value_features, 64 | float *grad_out, float* grad_attn_weight, float* grad_value_features); 65 | 66 | 67 | int attention_value_computation_grad_wrapper( 68 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 69 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 70 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 71 | at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features); 72 | 73 | #endif -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_func_v2.h: -------------------------------------------------------------------------------- 1 | #ifndef _ATTENTION_FUNC_V2_H 2 | #define _ATTENTION_FUNC_V2_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void attention_weight_computation_launcher_v2( 11 | int b, int total_query_num, int local_size, 12 | int total_key_num, int nhead, int hdim, 13 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 14 | const int *index_pair, 15 | const float *query_features, const float* key_features, 16 | float *output); 17 | 18 | 19 | int attention_weight_computation_wrapper_v2( 20 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 21 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 22 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 23 | at::Tensor output); 24 | 25 | 26 | void attention_weight_computation_grad_launcher_v2( 27 | int b, int total_query_num, int local_size, 28 | 
int total_key_num, int nhead, int hdim, 29 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 30 | const int *index_pair, 31 | const float *query_features, const float* key_features, 32 | float *grad_out, float* grad_query_features, float* grad_key_features); 33 | 34 | 35 | int attention_weight_computation_grad_wrapper_v2( 36 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 37 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 38 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 39 | at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features); 40 | 41 | 42 | void attention_value_computation_launcher_v2( 43 | int b, int total_query_num, int local_size, 44 | int total_key_num, int nhead, int hdim, 45 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 46 | const int *index_pair, 47 | const float *attn_weight, const float* value_features, 48 | float *output); 49 | 50 | 51 | int attention_value_computation_wrapper_v2( 52 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 53 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 54 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 55 | at::Tensor output); 56 | 57 | 58 | void attention_value_computation_grad_launcher_v2( 59 | int b, int total_query_num, int local_size, 60 | int total_key_num, int nhead, int hdim, 61 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 62 | const int *index_pair, 63 | const float *attn_weight, const float* value_features, 64 | float *grad_out, float* grad_attn_weight, float* grad_value_features); 65 | 66 | 67 | int attention_value_computation_grad_wrapper_v2( 68 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 69 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 70 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 71 | at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features); 72 | 73 | #endif -------------------------------------------------------------------------------- /trajflow/models/layers/polyline_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Shaoshuai Shi. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | ##################################################################################### 7 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 8 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 9 | #################################################################################### 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from .common_layers import build_mlps 15 | 16 | 17 | class PointNetPolylineEncoder(nn.Module): 18 | def __init__(self, in_channels, hidden_dim, num_layers=3, num_pre_layers=1, out_channels=None): 19 | super().__init__() 20 | self.pre_mlps = build_mlps( 21 | c_in=in_channels, 22 | mlp_channels=[hidden_dim] * num_pre_layers, 23 | ret_before_act=False 24 | ) 25 | self.mlps = build_mlps( 26 | c_in=hidden_dim * 2, 27 | mlp_channels=[hidden_dim] * (num_layers - num_pre_layers), 28 | ret_before_act=False 29 | ) 30 | 31 | if out_channels is not None: 32 | self.out_mlps = build_mlps( 33 | c_in=hidden_dim, mlp_channels=[hidden_dim, out_channels], 34 | ret_before_act=True, without_norm=True 35 | ) 36 | else: 37 | self.out_mlps = None 38 | 39 | def forward(self, polylines, polylines_mask): 40 | """ 41 | Args: 42 | polylines (batch_size, num_polylines, num_points_each_polylines, C): 43 | polylines_mask (batch_size, num_polylines, num_points_each_polylines): 44 | 45 | Returns: 46 | """ 47 | batch_size, num_polylines, num_points_each_polylines, C = polylines.shape 48 | 49 | # pre-mlp 50 | polylines_feature_valid = self.pre_mlps(polylines[polylines_mask]) # (N, C) 51 | polylines_feature = polylines.new_zeros(batch_size, num_polylines, num_points_each_polylines, polylines_feature_valid.shape[-1]) 52 | polylines_feature[polylines_mask] = polylines_feature_valid 53 | 54 | # get global feature 55 | pooled_feature = polylines_feature.max(dim=2)[0] 56 | polylines_feature = torch.cat((polylines_feature, pooled_feature[:, :, None, :].repeat(1, 1, num_points_each_polylines, 1)), dim=-1) 57 | 58 | # mlp 59 | polylines_feature_valid = self.mlps(polylines_feature[polylines_mask]) 60 | feature_buffers = polylines_feature.new_zeros(batch_size, num_polylines, num_points_each_polylines, polylines_feature_valid.shape[-1]) 61 | feature_buffers[polylines_mask] = polylines_feature_valid 62 | 63 | # max-pooling 64 | feature_buffers = feature_buffers.max(dim=2)[0] # (batch_size, num_polylines, C) 65 | 66 | # out-mlp 67 | if self.out_mlps is not None: 68 | valid_mask = (polylines_mask.sum(dim=-1) > 0) 69 | feature_buffers_valid = self.out_mlps(feature_buffers[valid_mask]) # (N, C) 70 | feature_buffers = feature_buffers.new_zeros(batch_size, num_polylines, feature_buffers_valid.shape[-1]) 71 | feature_buffers[valid_mask] = feature_buffers_valid 72 | return feature_buffers 73 | -------------------------------------------------------------------------------- /trajflow/utils/mtr_loss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Shaoshuai Shi. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | ##################################################################################### 7 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 8 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 9 | #################################################################################### 10 | 11 | 12 | import torch 13 | 14 | 15 | def nll_loss_gmm_direct(pred_scores, pred_trajs, gt_trajs, gt_valid_mask, pre_nearest_mode_idxs=None, 16 | timestamp_loss_weight=None, use_square_gmm=False, log_std_range=(-1.609, 5.0), rho_limit=0.5): 17 | """ 18 | GMM Loss for Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 19 | Written by Shaoshuai Shi 20 | 21 | Args: 22 | pred_scores (batch_size, num_modes): 23 | pred_trajs (batch_size, num_modes, num_timestamps, 5 or 3) 24 | gt_trajs (batch_size, num_timestamps, 2): 25 | gt_valid_mask (batch_size, num_timestamps): 26 | timestamp_loss_weight (num_timestamps): 27 | """ 28 | if use_square_gmm: 29 | assert pred_trajs.shape[-1] == 3 30 | else: 31 | assert pred_trajs.shape[-1] == 5 32 | 33 | batch_size = pred_scores.shape[0] 34 | 35 | if pre_nearest_mode_idxs is not None: 36 | nearest_mode_idxs = pre_nearest_mode_idxs 37 | else: 38 | distance = (pred_trajs[:, :, :, 0:2] - gt_trajs[:, None, :, :]).norm(dim=-1) 39 | distance = (distance * gt_valid_mask[:, None, :]).sum(dim=-1) 40 | 41 | nearest_mode_idxs = distance.argmin(dim=-1) 42 | nearest_mode_bs_idxs = torch.arange(batch_size).type_as(nearest_mode_idxs) # (batch_size, 2) 43 | 44 | nearest_trajs = pred_trajs[nearest_mode_bs_idxs, nearest_mode_idxs] # (batch_size, num_timestamps, 5) 45 | res_trajs = gt_trajs - nearest_trajs[:, :, 0:2] # (batch_size, num_timestamps, 2) 46 | dx = res_trajs[:, :, 0] 47 | dy = res_trajs[:, :, 1] 48 | 49 | if use_square_gmm: 50 | log_std1 = log_std2 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]) 51 | std1 = std2 = torch.exp(log_std1) # (0.2m to 150m) 52 | rho = torch.zeros_like(log_std1) 53 | else: 54 | log_std1 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1]) 55 | log_std2 = torch.clip(nearest_trajs[:, :, 3], min=log_std_range[0], max=log_std_range[1]) 56 | std1 = torch.exp(log_std1) # (0.2m to 150m) 57 | std2 = torch.exp(log_std2) # (0.2m to 150m) 58 | rho = torch.clip(nearest_trajs[:, :, 4], min=-rho_limit, max=rho_limit) 59 | 60 | gt_valid_mask = gt_valid_mask.type_as(pred_scores) 61 | if timestamp_loss_weight is not None: 62 | gt_valid_mask = gt_valid_mask * timestamp_loss_weight[None, :] 63 | 64 | # -log(a^-1 * e^b) = log(a) - b 65 | reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2) # (batch_size, num_timestamps) 66 | reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * ((dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2)) # (batch_size, num_timestamps) 67 | 68 | reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * gt_valid_mask).sum(dim=-1) 69 | 70 | return reg_loss, nearest_mode_idxs -------------------------------------------------------------------------------- /runner/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import torch 9 | from torch.nn.parallel import DistributedDataParallel 10 | 11 | from runner.utils.starter.config_parser import init_basics 12 | from runner.utils.starter.data import init_dataloader 13 | from runner.utils.starter.network import init_network, init_optimizer, init_ema_helper, load_checkpoint, init_scheduler 14 | from runner.utils.eval import repeat_eval_ckpt 15 | from runner.utils.trainer import train_model 16 | 17 | 18 | def main(): 19 | """ 20 | Main function. 21 | """ 22 | 23 | """Init""" 24 | args, cfg, logger, wb_log = init_basics() 25 | 26 | 27 | """build dataloader""" 28 | _, train_loader, train_sampler, _, test_loader, _ = init_dataloader(cfg, logger) 29 | 30 | 31 | """build model""" 32 | model, denoiser = init_network(cfg, logger) 33 | 34 | 35 | """build optimizer""" 36 | optimizer = init_optimizer(model, cfg.OPT) 37 | ema_helper = init_ema_helper(model, cfg.OPT, logger) 38 | 39 | 40 | """load model checkpoint""" 41 | it, start_epoch, last_epoch = load_checkpoint(model, optimizer, ema_helper, logger, 42 | ckpt_path=args.ckpt, ckpt_dir=cfg.SAVE_DIR.CKPT_DIR) 43 | 44 | 45 | """build scheduler""" 46 | # build after loading ckpt since the optimizer may be changed 47 | scheduler = init_scheduler(optimizer, cfg.OPT, total_epochs=cfg.OPT.NUM_EPOCHS - start_epoch, 48 | total_iters_each_epoch=len(train_loader), last_epoch=last_epoch) 49 | 50 | 51 | """adapt to distributed training""" 52 | if cfg.OPT.DIST_TRAIN: 53 | denoiser = DistributedDataParallel(denoiser, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()], find_unused_parameters=True) 54 | 55 | 56 | """start training""" 57 | logger.info('**********************Start training %s/%s(%s)**********************' 58 | % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag)) 59 | 60 | train_model(denoiser, optimizer, scheduler, train_loader, ema_helper, cfg, 61 | start_epoch, it, logger, wb_log, 62 | train_sampler=train_sampler, test_loader=test_loader, 63 | ckpt_save_interval=args.ckpt_save_interval, ckpt_save_time_interval=args.ckpt_save_time_interval, 64 | max_ckpt_save_num=args.max_ckpt_save_num, logger_iter_interval=args.logger_iter_interval) 65 | 66 | logger.info('**********************End training %s/%s(%s)**********************\n\n\n' 67 | % (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag)) 68 | 69 | 70 | """start evaluation""" 71 | logger.info('**********************Start evaluation %s/%s(%s)**********************' % 72 | (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag)) 73 | 74 | # tweak eval settings# 75 | args.start_epoch = max(args.epochs, 0) # Only evaluate the last 10 epochs 76 | cfg.DATA_CONFIG.SAMPLE_INTERVAL.val = 1 # Evaluate all samples 77 | args.interactive = False # do not run interactive evaluation 78 | args.submit = False # do not generate submission files 79 | repeat_eval_ckpt(denoiser.module if cfg.OPT.DIST_TRAIN else denoiser, test_loader, cfg, args, logger, 80 | args_ema_coef=None) 81 | 82 | logger.info('**********************End evaluation %s/%s(%s)**********************' % 83 | (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag)) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /runner/cfgs/waymo/trajflow+100_percent_data.yaml: -------------------------------------------------------------------------------- 1 | DATA_CONFIG: 2 | DATASET: WaymoDataset 3 | OBJECT_TYPE: &object_type ['TYPE_VEHICLE', 'TYPE_PEDESTRIAN', 'TYPE_CYCLIST'] 4 | 5 | DATA_ROOT: 'data/waymo' 6 | SPLIT_DIR: { 7 | 'train': 
'processed_scenarios_training', 8 | 'eval': 'processed_scenarios_validation', 9 | 'inter_eval': 'processed_scenarios_validation_interactive', 10 | 'test': 'processed_scenarios_testing', 11 | 'inter_test': 'processed_scenarios_testing_interactive' 12 | } 13 | 14 | INFO_FILE: { 15 | 'train': 'processed_scenarios_training_infos.pkl', 16 | 'eval': 'processed_scenarios_val_infos.pkl', 17 | 'inter_eval': 'processed_scenarios_val_inter_infos.pkl', 18 | 'test': 'processed_scenarios_test_infos.pkl', 19 | 'inter_test': 'processed_scenarios_test_inter_infos.pkl', 20 | } 21 | 22 | SAMPLE_INTERVAL: { 23 | 'train': 1, 24 | 'eval': 10, 25 | "inter_eval": 10, 26 | 'test': 10, 27 | 'inter_test': 10 28 | } 29 | 30 | INFO_FILTER_DICT: 31 | filter_info_by_object_type: *object_type 32 | 33 | POINT_SAMPLED_INTERVAL: 1 34 | NUM_POINTS_EACH_POLYLINE: 20 35 | VECTOR_BREAK_DIST_THRESH: 1.0 36 | 37 | NUM_OF_SRC_POLYLINES: 768 38 | CENTER_OFFSET_OF_MAP: &center_offset [30.0, 0] 39 | 40 | DATA_RESCALE: linear 41 | 42 | MODEL_DMT: 43 | CONTEXT_ENCODER: 44 | # following vanilla MTR configurations 45 | NAME: MTREncoder 46 | 47 | NUM_OF_ATTN_NEIGHBORS: 16 48 | NUM_INPUT_ATTR_AGENT: 29 49 | NUM_INPUT_ATTR_MAP: 9 50 | 51 | NUM_CHANNEL_IN_MLP_AGENT: 256 52 | NUM_CHANNEL_IN_MLP_MAP: 64 53 | NUM_LAYER_IN_MLP_AGENT: 3 54 | NUM_LAYER_IN_MLP_MAP: 5 55 | NUM_LAYER_IN_PRE_MLP_MAP: 3 56 | 57 | D_MODEL: 256 58 | NUM_ATTN_LAYERS: 6 59 | NUM_ATTN_HEAD: 8 60 | DROPOUT_OF_ATTN: 0.1 61 | 62 | USE_LOCAL_ATTN: True 63 | 64 | DMT: 65 | # denoising transformer network configurations 66 | D_QUERY: 512 # dimension of query token 67 | D_OBJ: 256 # dimension of context object token 68 | D_MAP: 256 # dimension of context map token 69 | 70 | DEPTH: 6 # number of transformer layers 71 | HEADS: 8 # number of attention heads 72 | DROPOUT: 0.1 # dropout rate 73 | 74 | NUM_QUERY: 64 # number of query tokens 75 | 76 | NUM_FUTURE_FRAMES: 80 77 | 78 | OBJECT_TYPE: *object_type 79 | CENTER_OFFSET_OF_MAP: *center_offset 80 | 81 | NUM_BASE_MAP_POLYLINES: 256 82 | NUM_WAYPOINT_MAP_POLYLINES: 128 83 | 84 | LOSS_WEIGHTS: { 85 | 'cls': 1.0, 86 | 'reg': 1.0, 87 | 'vel': 0.5, 88 | 'pl': 0.1 89 | } 90 | 91 | NUM_FUTURE_FRAMES: 80 92 | NUM_MOTION_MODES: 6 93 | 94 | INTENTION_POINTS_FILE: data/waymo/cluster_64_center_dict.pkl 95 | 96 | DENOISING: 97 | # denoising objective configurations 98 | FM: 99 | ### FM params ### 100 | SAMPLING_STEPS: 1 101 | OBJECTIVE: pred_data 102 | T_SCHEDULE: uniform 103 | ### FM params ### 104 | 105 | ### General params ### 106 | TIED_NOISE: True # use tied noise vector in case of multi-path diffusion 107 | CTC_LOSS: True # use cross-time consistency loss for data prediction objective 108 | ### General params ### 109 | 110 | OPT: 111 | # optimization hyperparameters 112 | BATCH_SIZE_PER_GPU: 10 113 | NUM_EPOCHS: 30 114 | 115 | OPTIMIZER: AdamW 116 | LR: 0.0001 117 | WEIGHT_DECAY: 0.01 118 | 119 | SCHEDULER: lambdaLR 120 | DECAY_STEP_LIST: [22, 24, 26, 28] 121 | LR_DECAY: 0.5 122 | LR_CLIP: 0.000001 123 | 124 | EMA_COEF: [0.999] 125 | 126 | GRAD_NORM_CLIP: 1000.0 127 | -------------------------------------------------------------------------------- /runner/cfgs/waymo/trajflow+20_percent_data.yaml: -------------------------------------------------------------------------------- 1 | DATA_CONFIG: 2 | DATASET: WaymoDataset 3 | OBJECT_TYPE: &object_type ['TYPE_VEHICLE', 'TYPE_PEDESTRIAN', 'TYPE_CYCLIST'] 4 | 5 | DATA_ROOT: 'data/waymo' 6 | SPLIT_DIR: { 7 | 'train': 'processed_scenarios_training', 8 | 'eval': 'processed_scenarios_validation', 9 | 
'inter_eval': 'processed_scenarios_validation_interactive', 10 | 'test': 'processed_scenarios_testing', 11 | 'inter_test': 'processed_scenarios_testing_interactive' 12 | } 13 | 14 | INFO_FILE: { 15 | 'train': 'processed_scenarios_training_infos.pkl', 16 | 'eval': 'processed_scenarios_val_infos.pkl', 17 | 'inter_eval': 'processed_scenarios_val_inter_infos.pkl', 18 | 'test': 'processed_scenarios_test_infos.pkl', 19 | 'inter_test': 'processed_scenarios_test_inter_infos.pkl', 20 | } 21 | 22 | SAMPLE_INTERVAL: { 23 | 'train': 5, 24 | 'eval': 10, 25 | "inter_eval": 10, 26 | 'test': 10, 27 | 'inter_test': 10 28 | } 29 | 30 | INFO_FILTER_DICT: 31 | filter_info_by_object_type: *object_type 32 | 33 | POINT_SAMPLED_INTERVAL: 1 34 | NUM_POINTS_EACH_POLYLINE: 20 35 | VECTOR_BREAK_DIST_THRESH: 1.0 36 | 37 | NUM_OF_SRC_POLYLINES: 768 38 | CENTER_OFFSET_OF_MAP: &center_offset [30.0, 0] 39 | 40 | DATA_RESCALE: linear 41 | 42 | MODEL_DMT: 43 | CONTEXT_ENCODER: 44 | # following vanilla MTR configurations 45 | NAME: MTREncoder 46 | 47 | NUM_OF_ATTN_NEIGHBORS: 16 48 | NUM_INPUT_ATTR_AGENT: 29 49 | NUM_INPUT_ATTR_MAP: 9 50 | 51 | NUM_CHANNEL_IN_MLP_AGENT: 256 52 | NUM_CHANNEL_IN_MLP_MAP: 64 53 | NUM_LAYER_IN_MLP_AGENT: 3 54 | NUM_LAYER_IN_MLP_MAP: 5 55 | NUM_LAYER_IN_PRE_MLP_MAP: 3 56 | 57 | D_MODEL: 256 58 | NUM_ATTN_LAYERS: 6 59 | NUM_ATTN_HEAD: 8 60 | DROPOUT_OF_ATTN: 0.1 61 | 62 | USE_LOCAL_ATTN: True 63 | 64 | DMT: 65 | # denoising transformer network configurations 66 | D_QUERY: 512 # dimension of query token 67 | D_OBJ: 256 # dimension of context object token 68 | D_MAP: 256 # dimension of context map token 69 | 70 | DEPTH: 6 # number of transformer layers 71 | HEADS: 8 # number of attention heads 72 | DROPOUT: 0.1 # dropout rate 73 | 74 | NUM_QUERY: 64 # number of query tokens 75 | 76 | NUM_FUTURE_FRAMES: 80 77 | 78 | OBJECT_TYPE: *object_type 79 | CENTER_OFFSET_OF_MAP: *center_offset 80 | 81 | NUM_BASE_MAP_POLYLINES: 256 82 | NUM_WAYPOINT_MAP_POLYLINES: 128 83 | 84 | LOSS_WEIGHTS: { 85 | 'cls': 1.0, 86 | 'reg': 1.0, 87 | 'vel': 0.5, 88 | 'pl': 0.1 89 | } 90 | 91 | NUM_FUTURE_FRAMES: 80 92 | NUM_MOTION_MODES: 6 93 | 94 | INTENTION_POINTS_FILE: data/waymo/cluster_64_center_dict.pkl 95 | 96 | DENOISING: 97 | # denoising objective configurations 98 | FM: 99 | ### FM params ### 100 | SAMPLING_STEPS: 1 101 | OBJECTIVE: pred_data 102 | T_SCHEDULE: uniform 103 | ### FM params ### 104 | 105 | ### General params ### 106 | TIED_NOISE: True # use tied noise vector in case of multi-path diffusion 107 | CTC_LOSS: True # use cross-time consistency loss for data prediction objective 108 | ### General params ### 109 | 110 | OPT: 111 | # optimization hyperparameters 112 | BATCH_SIZE_PER_GPU: 10 113 | NUM_EPOCHS: 30 114 | 115 | OPTIMIZER: AdamW 116 | LR: 0.0001 117 | WEIGHT_DECAY: 0.01 118 | 119 | SCHEDULER: lambdaLR 120 | DECAY_STEP_LIST: [22, 24, 26, 28] 121 | LR_DECAY: 0.5 122 | LR_CLIP: 0.000001 123 | 124 | EMA_COEF: [0.999] 125 | 126 | GRAD_NORM_CLIP: 1000.0 127 | -------------------------------------------------------------------------------- /runner/utils/submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | from waymo_open_dataset.protos.motion_submission_pb2 import * 4 | 5 | 6 | def traj_serialize(trajectories, scores, object_ids): 7 | scored_obj_trajs = [] 8 | for i in range(trajectories.shape[0]): 9 | center_x, center_y = trajectories[i, 4::5, 0], trajectories[i, 4::5, 1] 10 | traj = Trajectory(center_x=center_x, center_y=center_y) 11 | 
object_traj = ScoredTrajectory(confidence=scores[i], trajectory=traj) 12 | scored_obj_trajs.append(object_traj) 13 | return SingleObjectPrediction(trajectories=scored_obj_trajs, object_id=object_ids) 14 | 15 | 16 | def serialize_single_scenario(scenario_list): 17 | single_prediction_list = [] 18 | scenario_id = scenario_list[0]['scenario_id'] 19 | # Assert all scenario_ids match once 20 | assert all(d['scenario_id'] == scenario_id for d in scenario_list) 21 | for single_dict in scenario_list: 22 | sc_id = single_dict['scenario_id'] 23 | single_prediction = traj_serialize(single_dict['pred_trajs'], 24 | single_dict['pred_scores'], single_dict['object_id']) 25 | single_prediction_list.append(single_prediction) 26 | prediction_set = PredictionSet(predictions=single_prediction_list) 27 | return ChallengeScenarioPredictions(scenario_id=scenario_id, single_predictions=prediction_set) 28 | 29 | 30 | def joint_serialize_single_scenario(scenario_list): 31 | assert len(scenario_list)==2 32 | scenario_id = scenario_list[0]['scenario_id'] 33 | joint_score = scenario_list[0]['pred_scores'] 34 | full_scored_trajs = [] 35 | for j in range(6): 36 | object_trajs = [] 37 | for i in range(2): 38 | center_x = scenario_list[i]['pred_trajs'][j, 4::5, 0] 39 | center_y = scenario_list[i]['pred_trajs'][j, 4::5, 1] 40 | traj = Trajectory(center_x=center_x, center_y=center_y) 41 | score_traj = ObjectTrajectory(object_id=scenario_list[i]['object_id'], trajectory=traj) 42 | object_trajs.append(score_traj) 43 | full_scored_trajs.append(ScoredJointTrajectory(trajectories=object_trajs, confidence=joint_score[j])) 44 | joint_prediction = JointPrediction(joint_trajectories=full_scored_trajs) 45 | return ChallengeScenarioPredictions(scenario_id=scenario_id, joint_prediction=joint_prediction) 46 | 47 | 48 | def serialize_single_batch(final_pred_dicts, joint_pred=False): 49 | ret_scenarios = [] 50 | for b in range(len(final_pred_dicts)): 51 | scenario_list = final_pred_dicts[b] 52 | if joint_pred: 53 | scenario_preds = joint_serialize_single_scenario(scenario_list) 54 | else: 55 | scenario_preds = serialize_single_scenario(scenario_list) 56 | ret_scenarios.append(scenario_preds) 57 | return ret_scenarios 58 | 59 | 60 | def save_submission_file(scenerio_predictions, inter_pred, save_dir, save_name, submission_info, logger): 61 | submission_type = 2 if inter_pred else 1 62 | 63 | submission = MotionChallengeSubmission( 64 | account_name=submission_info['account_name'], 65 | unique_method_name=submission_info['unique_method_name'], 66 | authors=submission_info['authors'], 67 | affiliation=submission_info['affiliation'], 68 | submission_type=submission_type, 69 | scenario_predictions=scenerio_predictions, 70 | uses_lidar_data=submission_info['uses_lidar_data'], 71 | uses_camera_data=submission_info['uses_camera_data'], 72 | uses_public_model_pretraining=submission_info['uses_public_model_pretraining'], 73 | public_model_names=submission_info['public_model_names'], 74 | num_model_parameters=submission_info['num_model_parameters'] 75 | ) 76 | 77 | os.makedirs(save_dir, exist_ok=True) 78 | proto_path = os.path.join(save_dir, f"{save_name}.proto") 79 | tar_path = os.path.join(save_dir, f"{save_name}.gz") 80 | 81 | with open(proto_path, "wb") as f: 82 | f.write(submission.SerializeToString()) 83 | 84 | with tarfile.open(tar_path, "w:gz") as tar: 85 | tar.add(proto_path) 86 | 87 | os.remove(proto_path) 88 | 89 | logger.info("Submission file saved to {:s} with {:d} trajectories...".format(tar_path, len(scenerio_predictions))) 90 | 
-------------------------------------------------------------------------------- /trajflow/models/denoising_decoder/build_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import copy 14 | import pickle 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | from trajflow.models.layers.common_layers import build_mlps 20 | from trajflow.config import init_cfg 21 | from trajflow.models.layers.transformer.dmt_decoder_layer import DMTDecoderLayer 22 | 23 | 24 | def build_in_proj_layer(d_input, d_model, d_obj, d_map): 25 | in_proj_center_obj = build_mlps(c_in=d_input, mlp_channels=[d_model] * 2, ret_before_act=True, without_norm=True) 26 | in_proj_obj = build_mlps(c_in=d_input, mlp_channels=[d_obj] * 2, ret_before_act=True, without_norm=True) 27 | in_proj_map = build_mlps(c_in=d_input, mlp_channels=[d_map] * 2, ret_before_act=True, without_norm=True) 28 | return in_proj_center_obj, in_proj_obj, in_proj_map 29 | 30 | 31 | def build_transformer_decoder(d_tf, nhead, dropout, num_decoder_layers): 32 | decoder_layer = DMTDecoderLayer(d_model=d_tf, nhead=nhead, dim_feedforward=d_tf * 4, 33 | dropout=dropout, activation="relu", normalize_before=False, 34 | use_concat_pe_ca=True, normalization_type='layer_norm', bias=True, 35 | qk_norm=False, adaLN=False) 36 | decoder_layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_decoder_layers)]) 37 | return decoder_layers 38 | 39 | 40 | def build_dense_future_prediction_layers(hidden_dim, d_obj, num_future_frames): 41 | obj_pos_encoding_layer = build_mlps(c_in=2, mlp_channels=[hidden_dim, hidden_dim, hidden_dim], ret_before_act=True, without_norm=True) 42 | dense_future_head = build_mlps(c_in=hidden_dim + d_obj, mlp_channels=[hidden_dim, hidden_dim, num_future_frames * 7], ret_before_act=True) 43 | future_traj_mlps = build_mlps(c_in=4 * num_future_frames, mlp_channels=[hidden_dim, hidden_dim, hidden_dim], ret_before_act=True, without_norm=True) 44 | traj_fusion_mlps = build_mlps(c_in=hidden_dim + d_obj, mlp_channels=[hidden_dim, hidden_dim, d_obj], ret_before_act=True, without_norm=True) 45 | return obj_pos_encoding_layer, dense_future_head, future_traj_mlps, traj_fusion_mlps 46 | 47 | 48 | def build_motion_query(d_model, model_cfg): 49 | _init_cfg = init_cfg() 50 | intention_points_file = _init_cfg.ROOT_DIR / model_cfg.INTENTION_POINTS_FILE 51 | with open(intention_points_file, 'rb') as f: 52 | intention_points_dict = pickle.load(f) 53 | intention_points = {} 54 | for cur_type in model_cfg.OBJECT_TYPE: 55 | cur_intention_points = intention_points_dict[cur_type] 56 | cur_intention_points = torch.from_numpy(cur_intention_points).float().view(-1, 2).cuda() 57 | intention_points[cur_type] = cur_intention_points 58 | 59 | intention_query_mlps = build_mlps(c_in=d_model, mlp_channels=[d_model, d_model], ret_before_act=True) 60 | return intention_points, intention_query_mlps 61 | 62 | 
63 | def build_motion_head(d_model, map_d_model, hidden_size, num_future_frames, num_decoder_layers): 64 | temp_layer = build_mlps(c_in=d_model * 2 + map_d_model, mlp_channels=[d_model, d_model], ret_before_act=True) 65 | query_feature_fusion_layers = nn.ModuleList([copy.deepcopy(temp_layer) for _ in range(num_decoder_layers)]) 66 | 67 | motion_reg_head = build_mlps(c_in=d_model, mlp_channels=[hidden_size, hidden_size, num_future_frames * 7], ret_before_act=True) 68 | motion_cls_head = build_mlps(c_in=d_model, mlp_channels=[hidden_size, hidden_size, 1], ret_before_act=True) 69 | 70 | motion_reg_heads = nn.ModuleList([copy.deepcopy(motion_reg_head) for _ in range(num_decoder_layers)]) 71 | motion_cls_heads = nn.ModuleList([copy.deepcopy(motion_cls_head) for _ in range(num_decoder_layers)]) 72 | return query_feature_fusion_layers, motion_reg_heads, motion_cls_heads 73 | 74 | -------------------------------------------------------------------------------- /trajflow/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | # Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | # Published at NeurIPS 2022 3 | # Written by Shaoshuai Shi 4 | # All Rights Reserved 5 | 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data as torch_data 10 | import trajflow.utils.common_utils as common_utils 11 | 12 | 13 | class DatasetTemplate(torch_data.Dataset): 14 | def __init__(self, dataset_cfg=None, training=True, testing=False, inter_pred=False, logger=None): 15 | super().__init__() 16 | self.dataset_cfg = dataset_cfg 17 | self.training = training 18 | self.testing = testing 19 | self.inter_pred = inter_pred 20 | self.logger = logger 21 | 22 | @property 23 | def mode(self): 24 | if self.training: 25 | return 'train' 26 | else: 27 | if self.inter_pred: 28 | if self.testing: 29 | return 'inter_test' 30 | else: 31 | return 'inter_eval' 32 | else: 33 | if self.testing: 34 | return 'test' 35 | else: 36 | return 'eval' 37 | 38 | def merge_all_iters_to_one_epoch(self, merge=True, epochs=None): 39 | if merge: 40 | self._merge_all_iters_to_one_epoch = True 41 | self.total_epochs = epochs 42 | else: 43 | self._merge_all_iters_to_one_epoch = False 44 | 45 | def __len__(self): 46 | raise NotImplementedError 47 | 48 | def __getitem__(self, index): 49 | raise NotImplementedError 50 | 51 | def collate_batch(self, batch_list): 52 | """ 53 | Args: 54 | batch_list: 55 | scenario_id: (num_center_objects) 56 | track_index_to_predict (num_center_objects): 57 | 58 | obj_trajs (num_center_objects, num_objects, num_timestamps, num_attrs): 59 | obj_trajs_mask (num_center_objects, num_objects, num_timestamps): 60 | map_polylines (num_center_objects, num_polylines, num_points_each_polyline, 9): [x, y, z, dir_x, dir_y, dir_z, global_type, pre_x, pre_y] 61 | map_polylines_mask (num_center_objects, num_polylines, num_points_each_polyline) 62 | 63 | obj_trajs_pos: (num_center_objects, num_objects, num_timestamps, 3) 64 | obj_trajs_last_pos: (num_center_objects, num_objects, 3) 65 | obj_types: (num_objects) 66 | obj_ids: (num_objects) 67 | 68 | center_objects_world: (num_center_objects, 10) [cx, cy, cz, dx, dy, dz, heading, vel_x, vel_y, valid] 69 | center_objects_type: (num_center_objects) 70 | center_objects_id: (num_center_objects) 71 | 72 | obj_trajs_future_state (num_center_objects, num_objects, num_future_timestamps, 4): [x, y, vx, vy] 73 | obj_trajs_future_mask (num_center_objects, num_objects, num_future_timestamps): 74 | center_gt_trajs 
(num_center_objects, num_future_timestamps, 4): [x, y, vx, vy] 75 | center_gt_trajs_mask (num_center_objects, num_future_timestamps): 76 | center_gt_final_valid_idx (num_center_objects): the final valid timestamp in num_future_timestamps 77 | """ 78 | batch_size = len(batch_list) 79 | key_to_list = {} 80 | for key in batch_list[0].keys(): 81 | key_to_list[key] = [batch_list[bs_idx][key] for bs_idx in range(batch_size)] 82 | 83 | input_dict = {} 84 | for key, val_list in key_to_list.items(): 85 | 86 | if key in ['obj_trajs', 'obj_trajs_mask', 'map_polylines', 'map_polylines_mask', 'map_polylines_center', 87 | 'obj_trajs_pos', 'obj_trajs_last_pos', 'obj_trajs_future_state', 'obj_trajs_future_mask']: 88 | val_list = [torch.from_numpy(x) for x in val_list] 89 | input_dict[key] = common_utils.merge_batch_by_padding_2nd_dim(val_list) 90 | elif key in ['scenario_id', 'obj_types', 'obj_ids', 'center_objects_type', 'center_objects_id']: 91 | input_dict[key] = np.concatenate(val_list, axis=0) 92 | else: 93 | val_list = [torch.from_numpy(x) for x in val_list] 94 | input_dict[key] = torch.cat(val_list, dim=0) 95 | 96 | batch_sample_count = [len(x['track_index_to_predict']) for x in batch_list] 97 | batch_dict = {'batch_size': batch_size, 'input_dict': input_dict, 'batch_sample_count': batch_sample_count} 98 | return batch_dict 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrajFlow: Multi-modal Motion Prediction via Flow Matching 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-red)](https://www.arxiv.org/abs/2506.08541) 4 | [![Project Webpage](https://img.shields.io/badge/Project_Page-Website-green?logo=googlechrome&logoColor=white)](https://traj-flow.github.io/) 5 | 6 | The official PyTorch implementation of IROS'25 paper named "TrajFlow: Multi-modal Motion Prediction via Flow Matching". 7 | 8 | ## Overview 9 | 10 | ![TrajFlow diagram](assets/trajflow_overview.png) 11 | We propose a new flow matching framework to predict multi-modal trajectories on the large-scale Waymo Open Motion Dataset. 12 | 13 | ## Install Python Environment 14 | 15 | **Step 1:** Create a python environment 16 | 17 | ```bash 18 | conda create --name trajflow python=3.10 -y 19 | conda activate trajflow 20 | ``` 21 | 22 | Please note that we use `python=3.10` mainly for compatibility with the `waymo-open-dataset-tf-2-12-0` package, which is required for metrics evaluation. 
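Before moving on, you can optionally confirm that the new environment is active and running the expected interpreter (a minimal check of our own, not part of the original setup; adjust the version tuple if you deliberately use a different Python):

```bash
# optional: verify the conda environment is active and on Python 3.10
python -c "import sys; assert sys.version_info[:2] == (3, 10), sys.version; print('python 3.10 check pass')"
```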
23 | 24 | **Step 2:** Install the required packages 25 | 26 | ```bash 27 | # install pytorch 28 | ## [cuda 11.8] 29 | # conda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 pytorch-cuda=11.8 -c pytorch -c nvidia 30 | pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118 31 | 32 | # install waymo helper 33 | pip install waymo-open-dataset-tf-2-12-0 34 | 35 | # install other packages 36 | pip install -r setup/requirements.txt 37 | ``` 38 | 39 | **Step 3:** Compile CUDA code 40 | 41 | ```bash 42 | conda install git 43 | conda install -c conda-forge ninja 44 | 45 | # we use compute nodes with CUDA 11.8 46 | python setup/setup_trajflow.py develop 47 | 48 | # if you don't have CUDA 11.8 installed, you can use the following command to install it 49 | # ref: https://stackoverflow.com/questions/67483626/setup-tensorflow-2-4-on-ubuntu-20-04-with-gpu-without-sudo 50 | mkdir -p cuda_toolkits/cuda-11.8 && mkdir -p tmp 51 | wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 52 | bash ./cuda_11.8.0_520.61.05_linux.run --silent --tmpdir=$(pwd)/tmp --toolkit --toolkitpath=$(pwd)/cuda_toolkits/cuda-11.8 53 | rm -rf tmp 54 | export CUDA11=$(pwd)/cuda_toolkits/cuda-11.8 55 | export PATH=$CUDA11/bin:$PATH 56 | export LD_LIBRARY_PATH=$CUDA11/lib64:$CUDA11/extras/CUPTI/lib64:$LD_LIBRARY_PATH 57 | export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX" 58 | python setup/setup_trajflow.py develop 59 | ``` 60 | 61 | Finally, run the following for sanity check: 62 | 63 | ```bash 64 | python -c "import torch; import trajflow; print(torch.__version__, trajflow.__file__, 'pytorch sanity check pass'); " 65 | python -c "from waymo_open_dataset.metrics.ops import py_metrics_ops; print('waymo metrics sanity check pass'); " 66 | ``` 67 | 68 | ## Waymo Dataset Preparation 69 | 70 | **Step 1:** Download Waymo Open Motion Dataset `v1.3.0` from the [official website](https://waymo.com/open/download/) at `waymo_open_dataset_motion_v_1_3_0/uncompressed/scenario`, and organize the data as follows: 71 | 72 | ```bash 73 | ├── data 74 | │ ├── waymo 75 | │ │ ├── scenario 76 | │ │ │ ├──training 77 | │ │ │ ├──validation 78 | │ │ │ ├──testing 79 | ├── ... 80 | ``` 81 | 82 | **Step 2:** Preprocess the dataset: 83 | 84 | ```bash 85 | cd trajflow/datasets/waymo 86 | python data_preprocess.py ../../../data/waymo/scenario/ ../../../data/waymo 87 | ``` 88 | 89 | The processed data will be saved to `data/waymo/` directory as follows: 90 | 91 | ```bash 92 | ├── data 93 | │ ├── waymo 94 | │ │ ├── processed_scenarios_training 95 | │ │ ├── processed_scenarios_validation 96 | │ │ ├── processed_scenarios_testing 97 | │ │ ├── processed_scenarios_validation_interactive 98 | │ │ ├── processed_scenarios_testing_interactive 99 | │ │ ├── processed_scenarios_training_infos.pkl 100 | │ │ ├── processed_scenarios_val_infos.pkl 101 | │ │ ├── processed_scenarios_test_infos.pkl 102 | │ │ ├── processed_scenarios_val_inter_infos.pkl 103 | │ │ ├── processed_scenarios_test_inter_infos.pkl 104 | ├── ... 105 | ``` 106 | 107 | We use the clustering result from [MTR](https://github.com/sshaoshuai/MTR) for intention points, which is saved in `data/waymo/cluster_64_center_dict.pkl`. 
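If you want to sanity-check the clustering file before training, the sketch below inspects it; we assume (following the MTR convention and the loader in `build_motion_query`) that it is a pickled dict mapping each object type to a NumPy array of 64 intention points of shape `(64, 2)` — adjust the path if your data directory differs:

```bash
# optional: inspect the intention-point clusters used as motion queries
python -c "import pickle; d = pickle.load(open('data/waymo/cluster_64_center_dict.pkl', 'rb')); print({k: v.shape for k, v in d.items()})"
```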
108 | 109 | ## Training and Evaluation 110 | 111 | ```bash 112 | ## setup wandb credentials 113 | wandb login 114 | 115 | ## training 116 | cd runner 117 | bash scripts/dist_train.sh 4 --cfg_file cfgs/waymo/trajflow+100_percent_data.yaml --epoch 40 --batch_size 80 --extra_tag trajflow --max_ckpt_save_num 100 --ckpt_save_interval 1 118 | 119 | ## evaluation 120 | ### validation set 121 | python test.py --batch_size 64 --extra_tag=$(hostname) \ 122 | --ckpt ${PATH_TO_CKPT} \ 123 | --val --full_eval 124 | 125 | ### testing set (for submission) 126 | python test.py --batch_size 64 --extra_tag=$(hostname) \ 127 | --ckpt ${PATH_TO_CKPT} \ 128 | --test --full_eval --submit --email ${EMAIL} --method_nm ${METHOD_NAME} 129 | ``` 130 | 131 | ## Acknowledgment and Contact 132 | 133 | We would like to thank the [MTR](https://github.com/sshaoshuai/MTR) and [BeTopNet](https://github.com/OpenDriveLab/BeTop) repositories for their open-source codebase. 134 | 135 | If you have any questions, please contact [Qi Yan](mailto:qi.yan@ece.ubc.ca). 136 | 137 | ## Citation 138 | 139 | If you find this work useful, please consider citing: 140 | 141 | ```bibtex 142 | @article{yan2025trajflow, 143 | title={TrajFlow: Multi-modal Motion Prediction via Flow Matching}, 144 | author={Yan, Qi and Zhang, Brian and Zhang, Yutong and Yang, Daniel and White, Joshua and Chen, Di and Liu, Jiachao and Liu, Langechuan and Zhuang, Binnan and Shi, Shaoshuai and others}, 145 | journal={arXiv preprint arXiv:2506.08541}, 146 | year={2025} 147 | } 148 | ``` 149 | -------------------------------------------------------------------------------- /trajflow/utils/denoising_data_rescale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | META_INFO = { 13 | # 'range_x': [-75, 300], # from val GT 14 | # 'range_y': [-120, 125], # from val GT 15 | # 'range_x': [-240, 335], # from train GT 16 | # 'range_y': [-130, 135], # from train GT 17 | # 'range_x': [-250.0, 350.0], # overall GT 18 | # 'range_y': [-150.0, 150.0], # overall GT 19 | 'range_x': [-10.0, 170.0], # 0.1% - 99.9% percentile 20 | 'range_y': [-60.0, 60.0], # 0.1% - 99.9% percentile 21 | # 'range_x': [0.0, 2.0], # hack, not really changing the range 22 | # 'range_y': [0.0, 2.0], # hack, not really changing the range 23 | 24 | 'sqrt_x_coef': 15.0, 25 | 'sqrt_y_coef': 15.0, 26 | 'cbrt_x_coef': 8.0, 27 | 'cbrt_y_coef': 8.0, 28 | 29 | 'sqrt_x_offset': -0.3, 30 | 'sqrt_y_offset': 0.0, 31 | } 32 | 33 | 34 | def shift_data_to_normalize(traj_data, traj_mask, data_rescale='sqrt', meta_info=META_INFO): 35 | """ 36 | @param traj_data: [N, T, 2] or [N, T, 3] or [N, X, Y, 3] 37 | @param tarj_mask: [N, T] 38 | @param meta_info: dict 39 | @param data_rescale: str 40 | """ 41 | traj_data = traj_data.clone() 42 | assert traj_data.size(-1) == 2 or traj_data.size(-1) == 3 # the third dimension for z-coord is not changed if it exists 43 | 44 | if traj_data.size(-1) == 3: 45 | traj_data_z = traj_data[..., 2].clone() 46 | if traj_mask is None: 47 | traj_mask = torch.ones_like(traj_data[..., 0]).bool() 48 | else: 49 | if len(traj_mask.shape) in [len(traj_data.shape) - 1, len(traj_data.shape)]: 50 | pass 51 | else: 52 | breakpoint() 53 | ori_padded_data = traj_data[torch.logical_not(traj_mask)] 54 | 55 | if data_rescale == 'linear': 56 | min_x, max_x = meta_info['range_x'] 57 | min_y, max_y = meta_info['range_y'] 58 | 59 | traj_data[..., 0] = (traj_data[..., 0] - min_x) / (max_x - min_x) * 2 - 1 60 | traj_data[..., 1] = (traj_data[..., 1] - min_y) / (max_y - min_y) * 2 - 1 61 | elif data_rescale == 'sqrt': 62 | traj_data = torch.abs(traj_data).sqrt() * torch.sign(traj_data) # [N, T, 2] 63 | traj_data[..., 0] = traj_data[..., 0] / meta_info['sqrt_x_coef'] + meta_info['sqrt_x_offset'] # x-coord 64 | traj_data[..., 1] = traj_data[..., 1] / meta_info['sqrt_y_coef'] + meta_info['sqrt_y_offset'] # y-coord 65 | elif data_rescale == 'cbrt': 66 | traj_data = torch.abs(traj_data).pow(1/3) * torch.sign(traj_data) 67 | traj_data[..., 0] = traj_data[..., 0] / meta_info['cbrt_x_coef'] # x-coord 68 | traj_data[..., 1] = traj_data[..., 1] / meta_info['cbrt_y_coef'] # y-coord 69 | elif data_rescale == 'log_center': 70 | # center the data around 0, then take log 71 | min_x, max_x = meta_info['range_x'] 72 | min_y, max_y = meta_info['range_y'] 73 | traj_data[..., 0] = traj_data[..., 0] - (min_x + max_x) / 2 # de-mean x 74 | traj_data[..., 1] = traj_data[..., 1] - (min_y + max_y) / 2 # de-mean y 75 | 76 | traj_data = torch.log(traj_data.abs() + 1) * torch.sign(traj_data) 77 | traj_data[..., 0] /= np.log(max_x - (min_x + max_x) / 2 + 1) 78 | traj_data[..., 1] /= np.log(max_y - (min_y + max_y) / 2 + 1) 79 | 80 | if traj_data.size(-1) == 3: 81 | traj_data[..., 2] = traj_data_z 82 | traj_data[torch.logical_not(traj_mask)] = ori_padded_data 83 | return traj_data 84 | 85 | 86 | def shift_data_to_denormalize(traj_data, traj_mask, data_rescale='sqrt', meta_info=META_INFO): 87 | """ 88 | @param traj_data: [N, T, 2] 89 | @param tarj_mask: [N, T] 90 | @param meta_info: dict 91 | @param data_rescale: str 92 | """ 93 | traj_data = traj_data.clone() 94 | assert traj_data.size(-1) == 2 95 | if traj_mask is None: 96 | traj_mask = torch.ones_like(traj_data[..., 
0]).bool() 97 | else: 98 | traj_mask = traj_mask.unsqueeze(1).expand(-1, traj_data.size(1), -1) if len(traj_data.shape) == 4 else traj_mask 99 | 100 | flag_apply_mask = torch.logical_not(traj_mask).sum() 101 | if flag_apply_mask: 102 | ori_pad_val = traj_data[torch.logical_not(traj_mask)].unique() 103 | # if len(ori_pad_val) != 1: 104 | # breakpoint() 105 | assert len(ori_pad_val) == 1 106 | ori_pad_val = ori_pad_val[0] 107 | 108 | if data_rescale == 'linear': 109 | min_x, max_x = meta_info['range_x'] 110 | min_y, max_y = meta_info['range_y'] 111 | traj_data[..., 0] = (traj_data[..., 0] + 1) / 2 * (max_x - min_x) + min_x 112 | traj_data[..., 1] = (traj_data[..., 1] + 1) / 2 * (max_y - min_y) + min_y 113 | elif data_rescale == 'sqrt': 114 | traj_data[..., 0] = (traj_data[..., 0] - meta_info['sqrt_x_offset']) * meta_info['sqrt_x_coef'] 115 | traj_data[..., 1] = (traj_data[..., 1] - meta_info['sqrt_y_offset']) * meta_info['sqrt_y_coef'] 116 | traj_data = traj_data.abs().pow(2) * traj_data.sign() 117 | elif data_rescale == 'cbrt': 118 | traj_data[..., 0] = traj_data[..., 0] * meta_info['cbrt_x_coef'] 119 | traj_data[..., 1] = traj_data[..., 1] * meta_info['cbrt_y_coef'] 120 | traj_data = traj_data.abs().pow(3) * traj_data.sign() 121 | elif data_rescale == 'log_center': 122 | min_x, max_x = meta_info['range_x'] 123 | min_y, max_y = meta_info['range_y'] 124 | traj_data[..., 0] = traj_data[..., 0] * np.log(max_x - (min_x + max_x) / 2 + 1) 125 | traj_data[..., 1] = traj_data[..., 1] * np.log(max_y - (min_y + max_y) / 2 + 1) 126 | 127 | traj_data = (torch.exp(traj_data.abs()) - 1) * traj_data.sign() 128 | traj_data[..., 0] += (min_x + max_x) / 2 129 | traj_data[..., 1] += (min_y + max_y) / 2 130 | 131 | if flag_apply_mask: 132 | traj_data[torch.logical_not(traj_mask)] = ori_pad_val 133 | 134 | return traj_data 135 | 136 | -------------------------------------------------------------------------------- /trajflow/models/layers/transformer/transformer_encoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | """ 14 | Reference: https://github.com/dvlab-research/DeepVision3D/blob/master/EQNet/eqnet/transformer/multi_head_attention.py 15 | """ 16 | 17 | from typing import Optional 18 | 19 | import torch 20 | from torch import nn, Tensor 21 | import torch.nn.functional as F 22 | from .multi_head_attention import MultiheadAttention 23 | try: 24 | from .multi_head_attention_local import MultiheadAttentionLocal 25 | except: 26 | import os 27 | print("{:s} Fail to import MultiheadAttentionLocal module at {:s}. 
CUDA availability: {} {:s}".format('-' * 20, os.path.basename(__file__), torch.cuda.is_available(), '-' * 20)) 28 | 29 | 30 | def _get_activation_fn(activation): 31 | """Return an activation function given a string""" 32 | if activation == "relu": 33 | return F.relu 34 | if activation == "gelu": 35 | return F.gelu 36 | if activation == "glu": 37 | return F.glu 38 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 39 | 40 | 41 | class TransformerEncoderLayer(nn.Module): 42 | 43 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 44 | activation="relu", normalize_before=False, use_local_attn=False): 45 | super().__init__() 46 | self.use_local_attn = use_local_attn 47 | 48 | if self.use_local_attn: 49 | self.self_attn = MultiheadAttentionLocal(d_model, nhead, dropout=dropout) 50 | else: 51 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 52 | 53 | ### pytorch official implementation of MultiheadAttention ### 54 | # these two implementations should result in the same outcome 55 | # from torch.nn.modules.activation import MultiheadAttention as MultiheadAttentionOffcial 56 | # self.self_attn = MultiheadAttentionOffcial(d_model, nhead, dropout=dropout) 57 | ### pytorch official implementation of MultiheadAttention ### 58 | 59 | # Implementation of Feedforward model 60 | self.linear1 = nn.Linear(d_model, dim_feedforward) 61 | self.dropout = nn.Dropout(dropout) 62 | self.linear2 = nn.Linear(dim_feedforward, d_model) 63 | 64 | self.norm1 = nn.LayerNorm(d_model) 65 | self.norm2 = nn.LayerNorm(d_model) 66 | self.dropout1 = nn.Dropout(dropout) 67 | self.dropout2 = nn.Dropout(dropout) 68 | 69 | self.activation = _get_activation_fn(activation) 70 | self.normalize_before = normalize_before 71 | 72 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 73 | return tensor if pos is None else tensor + pos 74 | 75 | def forward_post(self, 76 | src, 77 | src_mask: Optional[Tensor] = None, 78 | src_key_padding_mask: Optional[Tensor] = None, 79 | pos: Optional[Tensor] = None, 80 | index_pair=None, 81 | query_batch_cnt=None, 82 | key_batch_cnt=None, 83 | index_pair_batch=None): 84 | q = k = self.with_pos_embed(src, pos) 85 | if self.use_local_attn: 86 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 87 | key_padding_mask=src_key_padding_mask, 88 | index_pair=index_pair, query_batch_cnt=query_batch_cnt, 89 | key_batch_cnt=key_batch_cnt, index_pair_batch=index_pair_batch)[0] 90 | else: 91 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 92 | src = src + self.dropout1(src2) 93 | src = self.norm1(src) 94 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 95 | src = src + self.dropout2(src2) 96 | src = self.norm2(src) 97 | return src 98 | 99 | def forward_pre(self, src, 100 | src_mask: Optional[Tensor] = None, 101 | src_key_padding_mask: Optional[Tensor] = None, 102 | pos: Optional[Tensor] = None, 103 | index_pair=None, 104 | query_batch_cnt=None, 105 | key_batch_cnt=None, 106 | index_pair_batch=None): 107 | src2 = self.norm1(src) 108 | q = k = self.with_pos_embed(src2, pos) 109 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 110 | key_padding_mask=src_key_padding_mask, 111 | index_pair=index_pair, query_batch_cnt=query_batch_cnt, 112 | key_batch_cnt=key_batch_cnt, index_pair_batch=index_pair_batch)[0] 113 | src = src + self.dropout1(src2) 114 | src2 = self.norm2(src) 115 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 116 | 
src = src + self.dropout2(src2) 117 | return src 118 | 119 | def forward(self, src, 120 | src_mask: Optional[Tensor] = None, 121 | src_key_padding_mask: Optional[Tensor] = None, 122 | pos: Optional[Tensor] = None, 123 | # for local-attn 124 | index_pair=None, 125 | query_batch_cnt=None, 126 | key_batch_cnt=None, 127 | index_pair_batch=None): 128 | if self.normalize_before: 129 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos, 130 | index_pair=index_pair, query_batch_cnt=query_batch_cnt, 131 | key_batch_cnt=key_batch_cnt, index_pair_batch=index_pair_batch) 132 | return self.forward_post(src, src_mask, src_key_padding_mask, pos, 133 | index_pair=index_pair, query_batch_cnt=query_batch_cnt, 134 | key_batch_cnt=key_batch_cnt, index_pair_batch=index_pair_batch) -------------------------------------------------------------------------------- /trajflow/mtr_ops/knn/src/knn_gpu.cu: -------------------------------------------------------------------------------- 1 | // Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 2 | // Published at NeurIPS 2022 3 | // Written by Li Jiang, Shaoshuai Shi 4 | // All Rights Reserved 5 | 6 | 7 | #include "knn_gpu.h" 8 | 9 | #include 10 | #include 11 | #define THREADS_PER_BLOCK 256 12 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 13 | 14 | __global__ void knn_batch_cuda_(int n, int m, int k, const float *__restrict__ xyz, const float *__restrict__ query_xyz, const int *__restrict__ batch_idxs, const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { 15 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 16 | if (pt_idx >= n) return; 17 | 18 | xyz += pt_idx * 3; 19 | idx += pt_idx * k; 20 | 21 | float ox = xyz[0]; 22 | float oy = xyz[1]; 23 | float oz = xyz[2]; 24 | 25 | float best[100]; 26 | int besti[100]; 27 | for(int i = 0; i < k; i++){ 28 | best[i] = 1e20; 29 | besti[i] = -1; 30 | } 31 | 32 | int batch_idx = batch_idxs[pt_idx]; 33 | int start = query_batch_offsets[batch_idx]; 34 | int end = query_batch_offsets[batch_idx + 1]; 35 | 36 | for (int i = start; i < end; ++i) { 37 | float x = query_xyz[i * 3 + 0]; 38 | float y = query_xyz[i * 3 + 1]; 39 | float z = query_xyz[i * 3 + 2]; 40 | float d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); 41 | for(int p = 0; p < k; p++){ 42 | if(d2 < best[p]){ 43 | for(int q = k - 1; q > p; q--){ 44 | best[q] = best[q - 1]; 45 | besti[q] = besti[q - 1]; 46 | } 47 | best[p] = d2; 48 | besti[p] = i - start; 49 | break; 50 | } 51 | } 52 | } 53 | 54 | for(int i = 0; i < k; i++){ 55 | idx[i] = besti[i]; 56 | } 57 | } 58 | 59 | 60 | __global__ void knn_batch_mlogk_cuda_(int n, int m, int k, const float *__restrict__ xyz, const float *__restrict__ query_xyz, const int *__restrict__ batch_idxs, const int *__restrict__ query_batch_offsets, int *__restrict__ idx) { 61 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 62 | if (pt_idx >= n) return; 63 | 64 | xyz += pt_idx * 3; 65 | idx += pt_idx * k; 66 | 67 | float ox = xyz[0]; 68 | float oy = xyz[1]; 69 | float oz = xyz[2]; 70 | 71 | float best[150]; 72 | int besti[150]; 73 | 74 | int heap_len = 0; 75 | 76 | for(int i = 0; i <= k; i++){ 77 | best[i] = 1e20; 78 | besti[i] = -1; 79 | } 80 | 81 | int batch_idx = batch_idxs[pt_idx]; 82 | int start = query_batch_offsets[batch_idx]; 83 | int end = query_batch_offsets[batch_idx + 1]; 84 | int temp_i; 85 | float temp_f; 86 | 87 | for (int i = start; i < end; ++i) { 88 | float x = query_xyz[i * 3 + 0]; 89 | float y = query_xyz[i * 3 + 1]; 90 | float z = query_xyz[i * 3 + 2]; 
91 | float d2 = (ox - x) * (ox - x) + (oy - y) * (oy - y) + (oz - z) * (oz - z); 92 | 93 | if (heap_len < k){ 94 | heap_len++; 95 | best[heap_len] = d2; 96 | besti[heap_len] = i - start; 97 | int cur_idx = heap_len, fa_idx = cur_idx >> 1; 98 | 99 | while (fa_idx > 0){ 100 | if (best[cur_idx] < best[fa_idx]) break; 101 | 102 | temp_i = besti[cur_idx]; besti[cur_idx] = besti[fa_idx]; besti[fa_idx] = temp_i; 103 | temp_f = best[cur_idx]; best[cur_idx] = best[fa_idx]; best[fa_idx] = temp_f; 104 | cur_idx = fa_idx; 105 | fa_idx = cur_idx >> 1; 106 | } 107 | } 108 | else{ 109 | if (d2 > best[1]) continue; 110 | best[1] = d2; besti[1] = i - start; 111 | 112 | int cur_idx = 1, son_idx; 113 | while (cur_idx <= k){ 114 | son_idx = cur_idx << 1; 115 | if (son_idx > k) break; 116 | if (son_idx + 1 <= k && best[son_idx] < best[son_idx + 1]){ 117 | son_idx++; 118 | } 119 | 120 | if (son_idx <= k && best[cur_idx] < best[son_idx]){ 121 | temp_i = besti[cur_idx]; besti[cur_idx] = besti[son_idx]; besti[son_idx] = temp_i; 122 | temp_f = best[cur_idx]; best[cur_idx] = best[son_idx]; best[son_idx] = temp_f; 123 | } 124 | else break; 125 | cur_idx = son_idx; 126 | } 127 | } 128 | } 129 | 130 | for(int i = 1; i <= k; i++){ 131 | idx[i - 1] = besti[i]; 132 | } 133 | // delete [] best; 134 | // delete [] besti; 135 | } 136 | 137 | 138 | 139 | 140 | void knn_batch_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { 141 | // param xyz: (n, 3), float 142 | // param query_xyz: (m, 3), float 143 | // param batch_idxs: (n), int 144 | // param query_batch_offsets: (B + 1), int, offsets[-1] = m 145 | // param idx: (n, k), int 146 | 147 | cudaError_t err; 148 | 149 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) 150 | dim3 threads(THREADS_PER_BLOCK); 151 | 152 | knn_batch_cuda_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); 153 | // cudaDeviceSynchronize(); // for using printf in kernel function 154 | 155 | err = cudaGetLastError(); 156 | if (cudaSuccess != err) { 157 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 158 | exit(-1); 159 | } 160 | } 161 | 162 | 163 | void knn_batch_mlogk_cuda(int n, int m, int k, const float *xyz, const float *query_xyz, const int *batch_idxs, const int *query_batch_offsets, int *idx, cudaStream_t stream) { 164 | // param xyz: (n, 3), float 165 | // param query_xyz: (m, 3), float 166 | // param batch_idxs: (n), int 167 | // param query_batch_offsets: (B + 1), int, offsets[-1] = m 168 | // param idx: (n, k), int 169 | 170 | cudaError_t err; 171 | 172 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row) 173 | dim3 threads(THREADS_PER_BLOCK); 174 | 175 | knn_batch_mlogk_cuda_<<>>(n, m, k, xyz, query_xyz, batch_idxs, query_batch_offsets, idx); 176 | // cudaDeviceSynchronize(); // for using printf in kernel function 177 | 178 | err = cudaGetLastError(); 179 | if (cudaSuccess != err) { 180 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 181 | exit(-1); 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /trajflow/models/denoising_decoder/compute_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | 18 | from trajflow.utils.mtr_loss_utils import nll_loss_gmm_direct 19 | 20 | 21 | class LossBuffer: 22 | def __init__(self, t_min, t_max, num_time_steps): 23 | """ 24 | Initialize the LossBuffer with the specified number of denoising levels. 25 | """ 26 | self.t_min = t_min 27 | self.t_max = t_max 28 | self.num_time_steps = num_time_steps 29 | self.t_interval = np.linspace(t_min, t_max, num_time_steps) 30 | self.loss_data = [[] for _ in range(self.num_time_steps)] 31 | self.last_epoch = -1 32 | 33 | def record_loss(self, t, loss, epoch_id): 34 | """ 35 | Record the loss for a specific denoising level. 36 | @param t: [B] the denoising level. 37 | @param loss: [B] the loss value. 38 | """ 39 | 40 | flag_reset = False 41 | if epoch_id != self.last_epoch: 42 | self.last_epoch = epoch_id 43 | self.reset() 44 | flag_reset = epoch_id > 0 45 | 46 | if isinstance(t, torch.Tensor): 47 | t = t.cpu().numpy() 48 | if isinstance(loss, torch.Tensor): 49 | loss = loss.cpu().numpy() 50 | 51 | idx = np.digitize(t, self.t_interval) - 1 52 | for i, l in zip(idx, loss): 53 | self.loss_data[i].append(l) 54 | 55 | return flag_reset 56 | 57 | def reset(self): 58 | """ 59 | Reset the loss data for a new epoch. 60 | """ 61 | self.loss_data = [[] for _ in range(self.num_time_steps)] 62 | 63 | def get_average_loss(self): 64 | """ 65 | To be used for plotting a histogram of denoising level vs. average loss for the last epoch. 66 | """ 67 | avg_loss_per_level = [np.mean(l) if len(l) > 0 else -1.0 for l in self.loss_data] 68 | dict_loss_per_level = {t: l for t, l in zip(self.t_interval, avg_loss_per_level)} 69 | return dict_loss_per_level 70 | 71 | 72 | def first_occurrence_mask_fast(x): 73 | # x: [B, N] 74 | B, N = x.size() 75 | mask = torch.zeros_like(x, dtype=torch.bool) 76 | for i in range(B): 77 | # torch.unique with sorted=False preserves the order of appearance. 78 | _, first_indices = torch.unique(x[i], sorted=False, return_inverse=True) 79 | mask[i, first_indices] = True 80 | return mask 81 | 82 | 83 | def plackett_luce_loss(logits, preference_argsort): 84 | """ 85 | Compute the Plackett-Luce loss for a batch of samples. 86 | @params logits: [B, N], predicted logits (unnormalized scores) for each item 87 | @params preference_argsort: [B, N], the preference order of the items (from the best to the worst) 88 | Note: ranks_idx must be distinct and in the range [0, N-1] 89 | """ 90 | 91 | # Reorder logits according to ranks, from the best to the worst 92 | # z[r_1], z[r_2], ..., z[r_N], level of preference: r_1 > r_2 > ... 
> r_N 93 | ordered_logits = torch.gather(logits, dim=1, index=preference_argsort) # [B, N] 94 | 95 | # Compute cumulative log-sum-exp 96 | cumulative_log_sum_exp = torch.logcumsumexp(ordered_logits, dim=-1) # [B, N] 97 | 98 | # Compute the loss 99 | log_probs = ordered_logits - cumulative_log_sum_exp # z[r_i] - log(sum(exp(z[r_i:]))) for all i 100 | loss = -log_probs 101 | 102 | # Check the uniqueness of the ranks, ignore repeated ranks 103 | loss_mask = first_occurrence_mask_fast(preference_argsort) # [B, N] 104 | loss = (loss * loss_mask.float()).mean(dim=-1) # [B] 105 | return loss 106 | 107 | 108 | def get_dense_future_prediction_loss(forward_ret_dict, wb_pre_tag='', wb_dict=None, disp_dict=None): 109 | obj_trajs_future_state = forward_ret_dict['obj_trajs_future_state'].cuda() 110 | obj_trajs_future_mask = forward_ret_dict['obj_trajs_future_mask'].cuda() 111 | pred_dense_trajs = forward_ret_dict['pred_dense_trajs'] 112 | assert pred_dense_trajs.shape[-1] == 7 113 | assert obj_trajs_future_state.shape[-1] == 4 114 | 115 | pred_dense_trajs_gmm, pred_dense_trajs_vel = pred_dense_trajs[:, :, :, 0:5], pred_dense_trajs[:, :, :, 5:7] 116 | 117 | loss_reg_vel = F.l1_loss(pred_dense_trajs_vel, obj_trajs_future_state[:, :, :, 2:4], reduction='none') 118 | loss_reg_vel = (loss_reg_vel * obj_trajs_future_mask[:, :, :, None]).sum(dim=-1).sum(dim=-1) 119 | 120 | num_center_objects, num_objects, num_timestamps, _ = pred_dense_trajs.shape 121 | fake_scores = pred_dense_trajs.new_zeros((num_center_objects, num_objects)).view(-1, 1) # (num_center_objects * num_objects, 1) 122 | 123 | temp_pred_trajs = pred_dense_trajs_gmm.contiguous().view(num_center_objects * num_objects, 1, num_timestamps, 5) 124 | temp_gt_idx = torch.zeros(num_center_objects * num_objects).cuda().long() # (num_center_objects * num_objects) 125 | temp_gt_trajs = obj_trajs_future_state[:, :, :, 0:2].contiguous().view(num_center_objects * num_objects, num_timestamps, 2) 126 | temp_gt_trajs_mask = obj_trajs_future_mask.view(num_center_objects * num_objects, num_timestamps) 127 | loss_reg_gmm, _ = nll_loss_gmm_direct( 128 | pred_scores=fake_scores, pred_trajs=temp_pred_trajs, gt_trajs=temp_gt_trajs, gt_valid_mask=temp_gt_trajs_mask, 129 | pre_nearest_mode_idxs=temp_gt_idx, 130 | timestamp_loss_weight=None, use_square_gmm=False, 131 | ) 132 | loss_reg_gmm = loss_reg_gmm.view(num_center_objects, num_objects) 133 | 134 | loss_reg = loss_reg_vel + loss_reg_gmm 135 | 136 | obj_valid_mask = obj_trajs_future_mask.sum(dim=-1) > 0 137 | 138 | loss_reg = (loss_reg * obj_valid_mask.float()).sum(dim=-1) / torch.clamp_min(obj_valid_mask.sum(dim=-1), min=1.0) 139 | loss_reg = loss_reg.mean() 140 | 141 | if wb_dict is None: 142 | wb_dict = {} 143 | if disp_dict is None: 144 | disp_dict = {} 145 | 146 | wb_dict[f'{wb_pre_tag}loss_dense_prediction'] = loss_reg.item() 147 | return loss_reg, wb_dict, disp_dict 148 | -------------------------------------------------------------------------------- /trajflow/models/layers/common_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import math 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class VecWeightNorm(nn.Module): 19 | """ 20 | Weight normalization module. 21 | """ 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward(self, in_tensor): 26 | return nn.functional.normalize(in_tensor, p=2, dim=-1) 27 | 28 | 29 | def build_mlps(c_in, mlp_channels=None, ret_before_act=False, without_norm=False, layer_norm=False, weight_norm=False): 30 | layers = [] 31 | num_layers = len(mlp_channels) 32 | 33 | for k in range(num_layers): 34 | if k + 1 == num_layers and ret_before_act: 35 | layers.append(nn.Linear(c_in, mlp_channels[k], bias=True)) 36 | else: 37 | if without_norm: 38 | layers.extend([nn.Linear(c_in, mlp_channels[k], bias=True), nn.ReLU()]) 39 | else: 40 | if layer_norm: 41 | layers.extend([nn.Linear(c_in, mlp_channels[k], bias=False), nn.LayerNorm(mlp_channels[k]), nn.ReLU()]) 42 | else: 43 | layers.extend([nn.Linear(c_in, mlp_channels[k], bias=False), nn.BatchNorm1d(mlp_channels[k]), nn.ReLU()]) 44 | c_in = mlp_channels[k] 45 | 46 | if weight_norm: 47 | layers = [VecWeightNorm()] + layers 48 | 49 | return nn.Sequential(*layers) 50 | 51 | 52 | def gen_sineembed_for_position(pos_tensor, hidden_dim=256): 53 | """Mostly copy-paste from https://github.com/IDEA-opensource/DAB-DETR/ 54 | """ 55 | # n_query, bs, _ = pos_tensor.size() 56 | # sineembed_tensor = torch.zeros(n_query, bs, 256) 57 | half_hidden_dim = hidden_dim // 2 58 | scale = 2 * math.pi 59 | dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device) 60 | dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim) 61 | x_embed = pos_tensor[:, :, 0] * scale 62 | y_embed = pos_tensor[:, :, 1] * scale 63 | pos_x = x_embed[:, :, None] / dim_t 64 | pos_y = y_embed[:, :, None] / dim_t 65 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 66 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 67 | if pos_tensor.size(-1) == 2: 68 | pos = torch.cat((pos_y, pos_x), dim=2) 69 | elif pos_tensor.size(-1) == 4: 70 | w_embed = pos_tensor[:, :, 2] * scale 71 | pos_w = w_embed[:, :, None] / dim_t 72 | pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) 73 | 74 | h_embed = pos_tensor[:, :, 3] * scale 75 | pos_h = h_embed[:, :, None] / dim_t 76 | pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) 77 | 78 | pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) 79 | else: 80 | raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) 81 | return pos 82 | 83 | 84 | class TimestepEmbedder(nn.Module): 85 | """ 86 | Embeds scalar timesteps into vector representations. 
87 | borrowed from DiT 88 | """ 89 | def __init__(self, hidden_size, frequency_embedding_size=256, bias=True, max_period=None): 90 | super().__init__() 91 | self.mlp = nn.Sequential( 92 | nn.Linear(frequency_embedding_size, hidden_size, bias=bias), 93 | nn.SiLU(), 94 | nn.Linear(hidden_size, hidden_size, bias=bias), 95 | ) 96 | self.frequency_embedding_size = frequency_embedding_size 97 | self.max_period = max_period 98 | 99 | @staticmethod 100 | def timestep_embedding(t, dim, max_period=10000): 101 | """ 102 | Create sinusoidal timestep embeddings. 103 | :param t: a 1-D Tensor of N indices, one per batch element. 104 | These may be fractional. 105 | :param dim: the dimension of the output. 106 | :param max_period: controls the minimum frequency of the embeddings. 107 | :return: an (N, D) Tensor of positional embeddings. 108 | """ 109 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 110 | half = dim // 2 111 | freqs = torch.exp( 112 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 113 | ).to(device=t.device) 114 | args = t[:, None].float() * freqs[None] 115 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 116 | if dim % 2: 117 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 118 | return embedding 119 | 120 | def forward(self, t): 121 | _max_period = self.max_period if self.max_period is not None else 10000 122 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size, max_period=_max_period) 123 | t_emb = self.mlp(t_freq) 124 | return t_emb 125 | 126 | 127 | class FlexIdentity(nn.Module): 128 | """ 129 | Flexibly applies an identity function to the input. 130 | borrowed from DiT 131 | """ 132 | def __init__(self, constant_output = None): 133 | super().__init__() 134 | self.constant_output = constant_output 135 | 136 | def forward(self, x, *args, **kwargs): 137 | if self.constant_output is None: 138 | return x 139 | else: 140 | return torch.empty_like(x).fill_(self.constant_output) 141 | 142 | 143 | class LlamaRMSNorm(nn.Module): 144 | """ 145 | Simplified RMSNorm layer used in Llama 3.1 backbone. 146 | """ 147 | def __init__(self, hidden_size, eps=1e-5): 148 | """ 149 | LlamaRMSNorm is equivalent to T5LayerNorm 150 | """ 151 | super().__init__() 152 | self.weight = nn.Parameter(torch.ones(hidden_size)) 153 | self.variance_epsilon = eps 154 | 155 | def forward(self, hidden_states): 156 | input_dtype = hidden_states.dtype 157 | hidden_states = hidden_states.to(torch.float32) 158 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 159 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 160 | return self.weight * hidden_states.to(input_dtype) 161 | 162 | def extra_repr(self): 163 | return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" 164 | 165 | -------------------------------------------------------------------------------- /runner/utils/starter/network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import os 9 | import copy 10 | import torch 11 | import torch.optim.lr_scheduler as lr_sched 12 | from glob import glob 13 | from ema_pytorch import EMA 14 | from trajflow.models.dmt_model import DenoisingMotionTransformer 15 | from trajflow.denoising.flow_matching import FlowMatcher 16 | 17 | 18 | def init_network(cfg, logger): 19 | """ 20 | Initialize the networks. 21 | """ 22 | 23 | # build model 24 | model = DenoisingMotionTransformer(config=cfg, logger=logger) 25 | if cfg.OPT.DIST_TRAIN: 26 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 27 | model.to(cfg.DEVICE) 28 | 29 | # Build diffusion objective 30 | denoising_cfg = cfg.MODEL_DMT.DENOISING 31 | 32 | denoiser = FlowMatcher( 33 | model=model, 34 | sampling_timesteps=denoising_cfg.FM.SAMPLING_STEPS, 35 | objective=denoising_cfg.FM.OBJECTIVE, 36 | t_schedule=denoising_cfg.FM.T_SCHEDULE, 37 | logger=logger, 38 | data_rescale=cfg.DATA_CONFIG.DATA_RESCALE, 39 | ckpt_dir=cfg.SAVE_DIR.CKPT_DIR, 40 | model_cfg=copy.deepcopy(cfg.MODEL_DMT) 41 | ) 42 | denoiser.to(cfg.DEVICE) 43 | 44 | return model, denoiser 45 | 46 | 47 | def init_optimizer(model, opt_cfg): 48 | """Initialize optimizer.""" 49 | optimizers = { 50 | 'Adam': torch.optim.Adam, 51 | 'AdamW': torch.optim.AdamW 52 | } 53 | 54 | if opt_cfg.OPTIMIZER not in optimizers: 55 | raise NotImplementedError(f"Optimizer {opt_cfg.OPTIMIZER} not implemented.") 56 | 57 | return optimizers[opt_cfg.OPTIMIZER]( 58 | model.parameters(), 59 | lr=opt_cfg.LR, 60 | weight_decay=opt_cfg.get('WEIGHT_DECAY', 0) 61 | ) 62 | 63 | 64 | def init_ema_helper(model, opt_cfg, logger): 65 | """Setup exponential moving average training helper.""" 66 | ema_coef = opt_cfg.EMA_COEF 67 | 68 | # Determine if EMA should be used 69 | if isinstance(ema_coef, float): 70 | flag_ema = ema_coef < 1 71 | ema_coef = [ema_coef] if flag_ema else None 72 | elif isinstance(ema_coef, list): 73 | flag_ema = True 74 | else: 75 | flag_ema = False 76 | ema_coef = None 77 | 78 | if not flag_ema: 79 | logger.info("Exponential moving average is OFF.") 80 | return None 81 | 82 | # Create EMA helpers 83 | ema_helper = [ 84 | EMA(model=model, beta=coef, update_every=1, update_after_step=0, inv_gamma=1, power=1) 85 | for coef in sorted(ema_coef) 86 | ] 87 | logger.info(f"Exponential moving average is ON. Coefficient: {ema_coef}") 88 | return ema_helper 89 | 90 | 91 | def load_checkpoint(model, optimizer, ema_helper, logger, ckpt_path, ckpt_dir): 92 | """ 93 | Load checkpoint if it is possible. 94 | """ 95 | start_epoch = it = 0 96 | last_epoch = -1 97 | 98 | if ckpt_path is not None: 99 | # Load checkpoint from specified path 100 | it, start_epoch = model.load_params(ckpt_path, optimizer=optimizer, ema_helper=ema_helper) 101 | last_epoch = start_epoch + 1 102 | return it, start_epoch, last_epoch 103 | 104 | # Load latest checkpoint from directory 105 | ckpt_list = glob(os.path.join(ckpt_dir, '*.pth')) 106 | if not ckpt_list: 107 | logger.info("No checkpoint found. 
Training from scratch.") 108 | return it, start_epoch, last_epoch 109 | 110 | # Find and load the latest valid checkpoint 111 | ckpt_list.sort(key=os.path.getmtime) 112 | for ckpt_file in reversed(ckpt_list): 113 | if os.path.basename(ckpt_file) == 'best_model.pth': 114 | continue 115 | try: 116 | ckpt_state = torch.load(ckpt_file, map_location=torch.device('cpu')) 117 | it, start_epoch = model.load_params(ckpt_file, ckpt_state=ckpt_state, optimizer=optimizer, ema_helper=ema_helper) 118 | last_epoch = start_epoch + 1 119 | break 120 | except: 121 | continue 122 | 123 | return it, start_epoch, last_epoch 124 | 125 | 126 | def init_scheduler(optimizer, opt_cfg, total_epochs, total_iters_each_epoch, last_epoch): 127 | """Initialize learning rate scheduler.""" 128 | scheduler_type = opt_cfg.get('SCHEDULER', None) 129 | total_iterations = total_epochs * total_iters_each_epoch 130 | 131 | if scheduler_type == 'cosine': 132 | # Cosine annealing with linear warmup 133 | warmup_iterations = max(1, int(total_iterations * 0.05)) 134 | warmup_scheduler = lr_sched.LambdaLR( 135 | optimizer, 136 | lambda step: max(opt_cfg.LR_CLIP / opt_cfg.LR, step / warmup_iterations) 137 | ) 138 | cosine_scheduler = lr_sched.CosineAnnealingLR( 139 | optimizer, 140 | T_max=total_iterations - warmup_iterations, 141 | eta_min=opt_cfg.LR_CLIP 142 | ) 143 | scheduler = lr_sched.SequentialLR( 144 | optimizer, 145 | schedulers=[warmup_scheduler, cosine_scheduler], 146 | milestones=[warmup_iterations] 147 | ) 148 | 149 | elif scheduler_type == 'lambdaLR': 150 | # LambdaLR with decay steps 151 | decay_step_list = opt_cfg.get('DECAY_STEP_LIST', [22, 24, 26, 28]) 152 | if len(decay_step_list) == 1 and decay_step_list[0] == -1: 153 | decay_step_list = [22, 24, 26, 28] 154 | 155 | decay_steps = [x * total_iters_each_epoch for x in decay_step_list] 156 | 157 | def lr_lambda(cur_step): 158 | cur_decay = 1 159 | for decay_step in decay_steps: 160 | if cur_step >= decay_step: 161 | cur_decay *= opt_cfg.LR_DECAY 162 | return max(cur_decay, opt_cfg.LR_CLIP / opt_cfg.LR) 163 | 164 | scheduler = lr_sched.LambdaLR(optimizer, lr_lambda) 165 | 166 | elif scheduler_type == 'linearLR': 167 | # LinearLR 168 | scheduler = lr_sched.LinearLR( 169 | optimizer, 170 | start_factor=1.0, 171 | end_factor=opt_cfg.LR_CLIP / opt_cfg.LR, 172 | total_iters=total_iterations, 173 | last_epoch=last_epoch 174 | ) 175 | 176 | elif scheduler_type == 'constant': 177 | # Constant learning rate 178 | scheduler = lr_sched.LambdaLR(optimizer, lambda x: 1.0, last_epoch=last_epoch) 179 | 180 | else: 181 | raise NotImplementedError(f"Unsupported scheduler: {scheduler_type}") 182 | 183 | # Handle last_epoch for schedulers that don't support it properly 184 | if last_epoch > 0 and scheduler_type in ['cosine', 'lambdaLR']: 185 | for _ in range(last_epoch * total_iters_each_epoch): 186 | scheduler.step() 187 | 188 | return scheduler 189 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_value_computation_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Transformer function helper function. 
3 | Written by tomztyang, 4 | 2021/08/23 5 | */ 6 | 7 | #include 8 | #include 9 | 10 | #include "attention_func.h" 11 | 12 | #define THREADS_PER_BLOCK 256 13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 14 | // #define DEBUG 15 | 16 | 17 | __global__ void attention_value_computation_forward( 18 | int b, int total_query_num, int local_size, 19 | int total_key_num, int nhead, int hdim, 20 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 21 | const int *index_pair, 22 | const float *attn_weight, const float* value_features, 23 | float *output) { 24 | // dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 25 | // params query_batch_cnt: [b] 26 | // params key_batch_cnt: [b] 27 | // params index_pair_batch: [total_query_num] 28 | // params index_pair: [total_query_num, local_size] 29 | // params attn_weight: [total_query_num, local_size, nhead] 30 | // params value_features: [total_key_num, nhead, hdim] 31 | // params output: [total_query_num, nhead, hdim] 32 | 33 | int index = blockIdx.x * blockDim.x + threadIdx.x; 34 | int head_idx = blockIdx.y; 35 | int hdim_idx = blockIdx.z; 36 | if (index >= total_query_num * local_size || 37 | head_idx >= nhead || 38 | hdim_idx >= hdim) return; 39 | 40 | if (index_pair[index] == -1){ 41 | // Ignore index. 42 | return; 43 | } 44 | 45 | int query_idx = index / local_size; 46 | int batch_idx = index_pair_batch[query_idx]; 47 | int key_start_idx = 0; 48 | for (int i = 0; i < batch_idx; i++){ 49 | key_start_idx += key_batch_cnt[i]; 50 | } 51 | 52 | // 1. Obtain value features. 53 | key_start_idx += index_pair[index]; 54 | value_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 55 | // 2. Obtain attention weight. 56 | attn_weight += index * nhead + head_idx; 57 | // 3. Do dot product. 
58 | output += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; 59 | atomicAdd( 60 | output, 61 | attn_weight[0] * value_features[0]); 62 | } 63 | 64 | 65 | void attention_value_computation_launcher( 66 | int b, int total_query_num, int local_size, 67 | int total_key_num, int nhead, int hdim, 68 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 69 | const int *index_pair, 70 | const float *attn_weight, const float* value_features, 71 | float *output){ 72 | // params query_batch_cnt: [b] 73 | // params key_batch_cnt: [b] 74 | // params index_pair_batch: [total_query_num] 75 | // params index_pair: [total_query_num, local_size] 76 | // params attn_weight: [total_query_num, local_size, nhead] 77 | // params value_features: [total_key_num, nhead, hdim] 78 | // params output: [total_query_num, nhead, hdim] 79 | 80 | dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 81 | dim3 threads(THREADS_PER_BLOCK); 82 | attention_value_computation_forward<<>>( 83 | b, total_query_num, local_size, total_key_num, nhead, hdim, 84 | query_batch_cnt, key_batch_cnt, index_pair_batch, 85 | index_pair, attn_weight, value_features, 86 | output); 87 | } 88 | 89 | 90 | __global__ void attention_value_computation_backward( 91 | int b, int total_query_num, int local_size, 92 | int total_key_num, int nhead, int hdim, 93 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 94 | const int *index_pair, 95 | const float *attn_weight, const float* value_features, 96 | float *grad_out, float * grad_attn_weight, float * grad_value_features) { 97 | // dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 98 | // params query_batch_cnt: [b] 99 | // params key_batch_cnt: [b] 100 | // params index_pair_batch: [total_query_num] 101 | // params index_pair: [total_query_num, local_size] 102 | // params attn_weight: [total_query_num, local_size, nhead] 103 | // params value_features: [total_key_num, nhead, hdim] 104 | // params grad_out: [total_query_num, nhead, hdim] 105 | // params grad_attn_weight: [total_query_num, local_size, nhead] 106 | // params grad_value_features: [total_key_num, nhead, hdim] 107 | 108 | int index = blockIdx.x * blockDim.x + threadIdx.x; 109 | int head_idx = blockIdx.y; 110 | int hdim_idx = blockIdx.z; 111 | if (index >= total_query_num * local_size || 112 | head_idx >= nhead || 113 | hdim_idx >= hdim) return; 114 | 115 | if (index_pair[index] == -1){ 116 | // Ignore index. 117 | return; 118 | } 119 | 120 | int query_idx = index / local_size; 121 | int batch_idx = index_pair_batch[query_idx]; 122 | int key_start_idx = 0; 123 | for (int i = 0; i < batch_idx; i++){ 124 | key_start_idx += key_batch_cnt[i]; 125 | } 126 | 127 | // 1. Obtain value features. 128 | key_start_idx += index_pair[index]; 129 | value_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 130 | grad_value_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 131 | // 2. Obtain attention weight. 132 | attn_weight += index * nhead + head_idx; 133 | grad_attn_weight += index * nhead + head_idx; 134 | 135 | // 3. Obtain grad out. 
136 | grad_out += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; 137 | atomicAdd( 138 | grad_attn_weight, 139 | grad_out[0] * value_features[0]); 140 | atomicAdd( 141 | grad_value_features, 142 | grad_out[0] * attn_weight[0]); 143 | } 144 | 145 | 146 | void attention_value_computation_grad_launcher( 147 | int b, int total_query_num, int local_size, 148 | int total_key_num, int nhead, int hdim, 149 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 150 | const int *index_pair, 151 | const float *attn_weight, const float* value_features, 152 | float *grad_out, float* grad_attn_weight, float* grad_value_features){ 153 | // params query_batch_cnt: [b] 154 | // params key_batch_cnt: [b] 155 | // params index_pair_batch: [total_query_num] 156 | // params index_pair: [total_query_num, local_size] 157 | // params attn_weight: [total_query_num, local_size, nhead] 158 | // params value_features: [total_key_num, nhead, hdim] 159 | // params grad_out: [total_query_num, nhead, hdim] 160 | // params grad_attn_weight: [total_query_num, local_size, nhead] 161 | // params grad_value_features: [total_key_num, nhead, hdim] 162 | 163 | dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 164 | dim3 threads(THREADS_PER_BLOCK); 165 | attention_value_computation_backward<<>>( 166 | b, total_query_num, local_size, total_key_num, nhead, hdim, 167 | query_batch_cnt, key_batch_cnt, index_pair_batch, 168 | index_pair, attn_weight, value_features, 169 | grad_out, grad_attn_weight, grad_value_features); 170 | } -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_weight_computation_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Transformer function helper function. 3 | Written by tomztyang, 4 | 2021/08/23 5 | */ 6 | 7 | #include 8 | #include 9 | 10 | #include "attention_func.h" 11 | 12 | #define THREADS_PER_BLOCK 256 13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 14 | // #define DEBUG 15 | 16 | 17 | __global__ void attention_weight_computation_forward( 18 | int b, int total_query_num, int local_size, 19 | int total_key_num, int nhead, int hdim, 20 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 21 | const int *index_pair, 22 | const float *query_features, const float* key_features, 23 | float *output) { 24 | // dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim);; 25 | // params query_batch_cnt: [b] 26 | // params key_batch_cnt: [b] 27 | // params index_pair_batch: [total_query_num] 28 | // params index_pair: [total_query_num, local_size] 29 | // params query_features: [total_query_num, nhead, hdim] 30 | // params key_features: [total_key_num, nhead, hdim] 31 | // params output: [total_query_num, local_size, nhead] 32 | 33 | int index = blockIdx.x * blockDim.x + threadIdx.x; 34 | int head_idx = blockIdx.y; 35 | int hdim_idx = blockIdx.z; 36 | if (index >= total_query_num * local_size || 37 | head_idx >= nhead || 38 | hdim_idx >= hdim) return; 39 | 40 | if (index_pair[index] == -1){ 41 | // Ignore index. 42 | return; 43 | } 44 | 45 | int query_idx = index / local_size; 46 | int batch_idx = index_pair_batch[query_idx]; 47 | int key_start_idx = 0; 48 | for (int i = 0; i < batch_idx; i++){ 49 | key_start_idx += key_batch_cnt[i]; 50 | } 51 | 52 | // 1. Obtain key features. 
53 | key_start_idx += index_pair[index]; 54 | key_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 55 | // 2. Obtain query features. 56 | query_features += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; 57 | // 3. Obtain output position. 58 | output += index * nhead + head_idx; 59 | atomicAdd( 60 | output, 61 | query_features[0] * key_features[0]); 62 | } 63 | 64 | 65 | void attention_weight_computation_launcher( 66 | int b, int total_query_num, int local_size, 67 | int total_key_num, int nhead, int hdim, 68 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 69 | const int *index_pair, 70 | const float *query_features, const float* key_features, 71 | float *output){ 72 | // params query_batch_cnt: [b] 73 | // params key_batch_cnt: [b] 74 | // params index_pair_batch: [total_query_num] 75 | // params index_pair: [total_query_num, local_size] 76 | // params query_features: [total_query_num, nhead, hdim] 77 | // params key_features: [total_key_num, nhead, hdim] 78 | // params output: [total_query_num, local_size, nhead] 79 | 80 | dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 81 | dim3 threads(THREADS_PER_BLOCK); 82 | attention_weight_computation_forward<<>>( 83 | b, total_query_num, local_size, total_key_num, nhead, hdim, 84 | query_batch_cnt, key_batch_cnt, index_pair_batch, 85 | index_pair, query_features, key_features, 86 | output); 87 | } 88 | 89 | 90 | __global__ void attention_weight_computation_backward( 91 | int b, int total_query_num, int local_size, 92 | int total_key_num, int nhead, int hdim, 93 | const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch, 94 | const int *index_pair, 95 | const float *query_features, const float* key_features, 96 | float *grad_out, float * grad_query_features, float * grad_key_features) { 97 | // dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim); 98 | // params query_batch_cnt: [b] 99 | // params key_batch_cnt: [b] 100 | // params index_pair_batch: [total_query_num] 101 | // params index_pair: [total_query_num, local_size] 102 | // params query_features: [total_query_num, nhead, hdim] 103 | // params key_features: [total_key_num, nhead, hdim] 104 | // params grad_out: [total_query_num, local_size, nhead] 105 | // params grad_query_features: [total_query_num, nhead, hdim] 106 | // params grad_key_features: [total_key_num, nhead, hdim] 107 | 108 | int index = blockIdx.x * blockDim.x + threadIdx.x; 109 | int head_idx = blockIdx.y; 110 | int hdim_idx = blockIdx.z; 111 | if (index >= total_query_num * local_size || 112 | head_idx >= nhead || 113 | hdim_idx >= hdim) return; 114 | 115 | if (index_pair[index] == -1){ 116 | // Ignore index. 117 | return; 118 | } 119 | 120 | int query_idx = index / local_size; 121 | int batch_idx = index_pair_batch[query_idx]; 122 | int key_start_idx = 0; 123 | for (int i = 0; i < batch_idx; i++){ 124 | key_start_idx += key_batch_cnt[i]; 125 | } 126 | 127 | // 1. Obtain key features. 128 | key_start_idx += index_pair[index]; 129 | key_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 130 | grad_key_features += key_start_idx * nhead * hdim + head_idx * hdim + hdim_idx; 131 | // 2. Obtain query features. 132 | query_features += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; 133 | grad_query_features += query_idx * nhead * hdim + head_idx * hdim + hdim_idx; 134 | // 3. Obtain output position. 
135 |     grad_out += index * nhead + head_idx;
136 |     atomicAdd(
137 |         grad_query_features,
138 |         grad_out[0] * key_features[0]);
139 |     atomicAdd(
140 |         grad_key_features,
141 |         grad_out[0] * query_features[0]);
142 | }
143 | 
144 | 
145 | void attention_weight_computation_grad_launcher(
146 |     int b, int total_query_num, int local_size,
147 |     int total_key_num, int nhead, int hdim,
148 |     const int *query_batch_cnt, const int *key_batch_cnt, const int* index_pair_batch,
149 |     const int *index_pair,
150 |     const float *query_features, const float* key_features,
151 |     float *grad_out, float* grad_query_features, float* grad_key_features){
152 |     // params query_batch_cnt: [b]
153 |     // params key_batch_cnt: [b]
154 |     // params index_pair_batch: [total_query_num]
155 |     // params index_pair: [total_query_num, local_size]
156 |     // params query_features: [total_query_num, nhead, hdim]
157 |     // params key_features: [total_key_num, nhead, hdim]
158 |     // params grad_out: [total_query_num, local_size, nhead]
159 |     // params grad_query_features: [total_query_num, nhead, hdim]
160 |     // params grad_key_features: [total_key_num, nhead, hdim]
161 | 
162 |     dim3 blocks(DIVUP(total_query_num * local_size, THREADS_PER_BLOCK), nhead, hdim);
163 |     dim3 threads(THREADS_PER_BLOCK);
164 |     attention_weight_computation_backward<<<blocks, threads>>>(
165 |         b, total_query_num, local_size, total_key_num, nhead, hdim,
166 |         query_batch_cnt, key_batch_cnt, index_pair_batch,
167 |         index_pair, query_features, key_features,
168 |         grad_out, grad_query_features, grad_key_features);
169 | }
--------------------------------------------------------------------------------
/trajflow/models/dmt_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2025-present, Qi Yan.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | 
8 | import os
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | from trajflow.models import MTREncoder, DenoisingDecoder
14 | from trajflow.utils.common_utils import register_module_to_params_dict
15 | 
16 | 
17 | class DenoisingMotionTransformer(nn.Module):
18 |     def __init__(self, config, logger):
19 |         super().__init__()
20 | 
21 |         # init
22 |         self.config = config
23 | 
24 |         self.model_cfg = config.MODEL_DMT
25 |         self.model_cfg.CONTEXT_ENCODER.DEVICE = config.DEVICE
26 |         self.model_cfg.DMT.DEVICE = config.DEVICE
27 | 
28 |         self.model_cfg.DMT.CONTEXT_D_MODEL = self.model_cfg.CONTEXT_ENCODER.D_MODEL
29 | 
30 |         self.ctc_loss = self.model_cfg.DENOISING.CTC_LOSS
31 | 
32 |         self.logger = logger
33 | 
34 |         # Encoder network (reusing the MTR encoder)
35 |         self.context_encoder = MTREncoder(self.model_cfg.CONTEXT_ENCODER)
36 | 
37 |         # Denoising decoder
38 |         self.denoising_decoder = DenoisingDecoder(model_cfg=self.model_cfg.DMT, denoising_cfg=self.model_cfg.DENOISING,
39 |                                                   logger=logger, save_dirs=self.config.SAVE_DIR, data_rescale=config.DATA_CONFIG.DATA_RESCALE)
40 | 
41 |         self.params_dict = {}
42 |         self.register_module_to_params_dict = lambda module, name: register_module_to_params_dict(self.params_dict, module, name)
43 |         self.count_model_params()
44 | 
45 |     def count_model_params(self):
46 |         """
47 |         Count the number of trainable parameters in the model.
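
        The per-module breakdown is written via self.logger, e.g. (numbers purely illustrative):
            #params for total                                   : 65,000,000
            #params for MTR_encoder                             : 15,000,000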
48 | """ 49 | self.logger.info("===== Overall DMT model parameters breakdown =====") 50 | params_total = sum(p.numel() for p in self.parameters() if p.requires_grad) 51 | self.register_module_to_params_dict(self.context_encoder, 'MTR_encoder') 52 | self.register_module_to_params_dict(self.denoising_decoder, 'denoising_decoder') 53 | 54 | params_other = params_total - sum(self.params_dict.values()) 55 | self.params_dict = {'total': params_total, **self.params_dict} 56 | self.params_dict = {**self.params_dict, 'other': params_other} 57 | for nm, p in self.params_dict.items(): 58 | self.logger.info("#params for {:40}: {:,}".format(nm, p)) 59 | 60 | self.logger.info("===== Overall DMT model parameters breakdown =====") 61 | 62 | def forward(self, batch_dict, disp_dict=None, wb_dict=None): 63 | """ 64 | Forward pass of the model. 65 | """ 66 | 67 | if self.training: 68 | # get ctc loss flags 69 | if self.ctc_loss: 70 | flag_ctc_s1 = 'denoiser_dict_ctc_1' in batch_dict and 'denoiser_dict_ctc_2' not in batch_dict 71 | flag_ctc_s2 = 'denoiser_dict_ctc_1' in batch_dict and 'denoiser_dict_ctc_2' in batch_dict 72 | assert flag_ctc_s1 or flag_ctc_s2, 'CTC loss is not properly set' 73 | 74 | """context encoder""" 75 | if self.ctc_loss and flag_ctc_s2: 76 | # use the cached encoder output for CTC loss 77 | assert 'encoder_output' in batch_dict, 'encoder_output is not found in batch_dict' 78 | else: 79 | batch_dict = self.context_encoder(batch_dict) # batch dict is updated with new outputs 80 | 81 | """denoising decoder""" 82 | batch_dict = self.denoising_decoder(batch_dict) # batch dict is updated with new outputs 83 | 84 | """compute loss""" 85 | loss_denoiser_reg, loss_denoiser_cls = self.denoising_decoder.get_loss(batch_dict, disp_dict, wb_dict) 86 | 87 | return loss_denoiser_reg, loss_denoiser_cls, batch_dict 88 | 89 | else: 90 | flag_run_encoder_net = 'encoder_output' not in batch_dict 91 | 92 | """context encoder""" 93 | if flag_run_encoder_net: 94 | batch_dict = self.context_encoder(batch_dict) # batch dict is updated with new outputs 95 | 96 | """denoising decoder""" 97 | # always run this module 98 | batch_dict = self.denoising_decoder(batch_dict) # batch dict is updated with new outputs 99 | 100 | return batch_dict 101 | 102 | def load_params(self, ckpt_path, to_cpu=True, ckpt_state=None, optimizer=None, ema_model_kw=None, ema_helper=None): 103 | """ 104 | Helper to load model parameters, optimizer states, and EMA model states from a checkpoint. 
105 | """ 106 | # init and load checkpoint into memory 107 | if not os.path.isfile(ckpt_path): 108 | raise FileNotFoundError 109 | 110 | if ckpt_state is not None: 111 | self.logger.info('==> Loading parameters from in-memory checkpoint dict...') 112 | checkpoint = ckpt_state 113 | else: 114 | self.logger.info('==> Loading parameters from checkpoint %s to %s' % (ckpt_path, 'CPU' if to_cpu else 'GPU')) 115 | loc_type = torch.device('cpu') if to_cpu else None 116 | checkpoint = torch.load(ckpt_path, map_location=loc_type) 117 | 118 | epoch = checkpoint.get('epoch', -1) 119 | it = checkpoint.get('it', 0.0) 120 | version = checkpoint.get("version", None) 121 | 122 | if version is not None: 123 | self.logger.info('==> Checkpoint trained from version: %s' % version) 124 | 125 | # load EMA model if needed 126 | if ema_model_kw is not None: 127 | assert ema_model_kw in checkpoint, f'key {ema_model_kw} not found in checkpoint' 128 | loaded_model_state = checkpoint[ema_model_kw] 129 | self.logger.info(f'==> Loading EMA model with key {ema_model_kw} from checkpoint') 130 | else: 131 | loaded_model_state = checkpoint['model_state'] 132 | 133 | # check the keys in the checkpoint 134 | self.logger.info(f'The number of in-memory ckpt keys: {len(loaded_model_state)}') 135 | cur_model_state = self.state_dict() 136 | loaded_model_state_filtered = {} 137 | for key, val in loaded_model_state.items(): 138 | if key in cur_model_state and loaded_model_state[key].shape == cur_model_state[key].shape: 139 | loaded_model_state_filtered[key] = val 140 | else: 141 | if key not in cur_model_state: 142 | self.logger.info(f'Ignore key in disk (not found in model): {key}, shape={val.shape}') 143 | else: 144 | self.logger.info(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={cur_model_state[key].shape}') 145 | 146 | # load the filtered checkpoint 147 | missing_keys, unexpected_keys = self.load_state_dict(loaded_model_state_filtered, strict=True) 148 | 149 | self.logger.info(f'Missing keys: {missing_keys}') 150 | self.logger.info(f'The number of missing keys: {len(missing_keys)}') 151 | self.logger.info(f'The number of unexpected keys: {len(unexpected_keys)}') 152 | self.logger.info('==> Done loading model ckpt (total keys %d)' % (len(cur_model_state))) 153 | 154 | # laod optimizer if needed 155 | if optimizer is not None: 156 | self.logger.info('==> Loading optimizer parameters from checkpoint %s to %s' % (ckpt_path, 'CPU' if to_cpu else 'GPU')) 157 | optimizer.load_state_dict(checkpoint['optimizer_state']) 158 | self.logger.info('==> Done loading optimizer state') 159 | 160 | # load EMA helper weights if needed 161 | if ema_helper is not None: 162 | for ema_wrapper in ema_helper: 163 | beta = ema_wrapper.beta 164 | self.logger.info('==> Loading EMA model with beta = %.4f from checkpoint %s to %s'% (beta, ckpt_path, 'CPU' if to_cpu else 'GPU')) 165 | ema_wrapper.ema_model.load_state_dict(checkpoint['model_ema_beta_{:.4f}'.format(beta)], strict=True) 166 | self.logger.info('==> Done loading EMA model') 167 | 168 | return it, epoch 169 | -------------------------------------------------------------------------------- /runner/utils/tester.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import pickle 14 | import time 15 | import re 16 | import copy 17 | import os 18 | import numpy as np 19 | import torch 20 | import tqdm 21 | from multiprocessing import Pool 22 | from scipy.special import softmax 23 | 24 | from trajflow.utils.init_objective import prepare_denoiser_data 25 | from trajflow.utils import common_utils, motion_utils 26 | from .submission import serialize_single_batch, save_submission_file 27 | 28 | 29 | def deep_copy_dict(batch_dict, scores, trajs): 30 | batch_dict_copy = {} 31 | keys_to_del = ['denoiser_dict', 'encoder_output', 'denoiser_output'] 32 | for key, val in batch_dict.items(): 33 | if key not in keys_to_del: 34 | batch_dict_copy[key] = copy.deepcopy(val) 35 | batch_dict_copy['pred_scores'] = scores 36 | batch_dict_copy['pred_trajs'] = trajs 37 | return batch_dict_copy 38 | 39 | 40 | def eval_one_epoch(denoiser, test_loader, cfg, epoch_id, logger, 41 | inter_pred=False, flag_submission=False, submission_info=None, 42 | logger_iter_interval=50): 43 | # Init 44 | dist_test = cfg.OPT.DIST_TRAIN 45 | eval_output_dir = cfg.SAVE_DIR.EVAL_OUTPUT_DIR 46 | 47 | test_set = test_loader.dataset 48 | 49 | pred_dicts = [] # denoiser trajectory + denoiser classifer score 50 | scenario_predictions = [] # submission format 51 | 52 | # Adjust the model for evaluation 53 | logger.info('*************** EPOCH %s EVALUATION *****************' % epoch_id) 54 | if dist_test: 55 | if not isinstance(denoiser, torch.nn.parallel.DistributedDataParallel): 56 | num_gpus = torch.cuda.device_count() 57 | local_rank = cfg.LOCAL_RANK % num_gpus 58 | denoiser = torch.nn.parallel.DistributedDataParallel( 59 | denoiser, 60 | device_ids=[local_rank], 61 | broadcast_buffers=False 62 | ) 63 | denoiser.eval() 64 | 65 | if cfg.LOCAL_RANK == 0: 66 | progress_bar = tqdm.tqdm(total=len(test_loader), leave=True, desc='eval', dynamic_ncols=True) 67 | 68 | # Evaluation loop 69 | start_time = time.time() 70 | 71 | for i, batch_dict in enumerate(test_loader): 72 | disp_dict = {} 73 | 74 | with torch.no_grad(): 75 | """prepare data for denoiser model""" 76 | if 'center_gt_trajs' in batch_dict['input_dict']: 77 | batch_dict = prepare_denoiser_data(batch_dict, cfg.DATA_CONFIG.DATA_RESCALE, cfg.DEVICE) 78 | 79 | batch_dict['denoiser_dict'] = {} 80 | 81 | """create more samples in a for loop""" 82 | pred_trajs, pred_cls_logits, batch_dicts = denoiser(batch_dict, disp_dict=disp_dict, flag_sample=True) 83 | 84 | """use denoiser cls score for NMS""" 85 | pred_scores_cls_nms = batch_dicts['pred_scores'] 86 | pred_trajs_cls_nms = batch_dicts['pred_trajs'] 87 | 88 | batch_cls_score = deep_copy_dict(batch_dicts, pred_scores_cls_nms, pred_trajs_cls_nms) 89 | 90 | final_pred_dicts = test_set.generate_prediction_dicts(batch_cls_score, 91 | inter_pred=inter_pred, 92 | flag_submission=flag_submission) 93 | pred_dicts += final_pred_dicts 94 | 95 | if flag_submission: 96 | scenario_predictions.extend(serialize_single_batch(final_pred_dicts, inter_pred)) 97 | 98 | B, K, T = pred_trajs.size()[:3] 99 | 100 | ### end of torch.no_grad() ### 101 | 102 | # log the evaluation results 103 | if cfg.LOCAL_RANK == 0 and (i % logger_iter_interval 
== 0 or i == 0 or i + 1 == len(test_loader)):
104 |             past_time = progress_bar.format_dict['elapsed']
105 |             second_each_iter = past_time / max(i, 1.0)
106 |             remaining_time = second_each_iter * (len(test_loader) - i)
107 |             disp_str = ', '.join([f'{key}={val:.3f}' for key, val in disp_dict.items() if key != 'lr'])
108 |             batch_size = batch_dict.get('batch_size', None)
109 |             logger.info(f'eval: epoch={epoch_id}, batch_iter={i}/{len(test_loader)}, batch_size={batch_size}, iter_cost={second_each_iter:.2f}s, '
110 |                         f'time_cost: {progress_bar.format_interval(past_time)}/{progress_bar.format_interval(remaining_time)}, '
111 |                         f'{disp_str}')
112 |     ### end of evaluation loop ###
113 | 
114 |     """eval data saving and logging"""
115 |     if cfg.LOCAL_RANK == 0:
116 |         progress_bar.close()
117 | 
118 |     if dist_test:
119 |         logger.info(f'Total number of samples before merging from multiple GPUs: {len(pred_dicts)}')
120 |         pred_dicts = common_utils.merge_results_dist(pred_dicts, len(test_set), tmpdir=os.path.join(eval_output_dir, 'tmpdir'))
121 |         if cfg.LOCAL_RANK == 0:
122 |             logger.info(f'Total number of samples after merging from multiple GPUs (removing duplicate): {len(pred_dicts)}')
123 | 
124 |         if flag_submission:
125 |             scenario_predictions = common_utils.merge_results_dist(scenario_predictions, len(test_set), tmpdir=os.path.join(eval_output_dir, 'tmpdir'))
126 | 
127 |     logger.info('*************** Performance of EPOCH %s *****************' % epoch_id)
128 |     sec_per_example = (time.time() - start_time) / len(test_loader.dataset)
129 |     logger.info('Generating labels finished (sec_per_example: %.4f second).' % sec_per_example)
130 | 
131 |     if cfg.LOCAL_RANK != 0:
132 |         return {}
133 | 
134 |     ret_dict = {}
135 | 
136 |     logger.info("Number of total trajectories to evaluate: {:d}".format(len(pred_dicts)))
137 |     with open(os.path.join(eval_output_dir, 'result_denoiser.pkl'), 'wb') as f:
138 |         pickle.dump(pred_dicts, f)
139 | 
140 |     if flag_submission:
141 |         save_submission_file(scenario_predictions, inter_pred, eval_output_dir, cfg.OUTPUT_DIR_PREFIX, submission_info, logger)
142 | 
143 |     """evaluate trajectory performance"""
144 |     def _get_latex_str(in_str):
145 |         # extract the last line of the evaluation results and reorganize it into a latex-friendly string
146 |         str_latex = in_str.split('\n')[-2].split(',')[:4]
147 |         str_latex = [float(re.findall(r"[-+]?\d*\.\d+|\d+", s)[0]) for s in str_latex]
148 |         str_latex = str_latex[1:] + [str_latex[0]]
149 |         str_latex = ' & '.join(['{:.4f}'.format(float(s)) for s in str_latex])
150 |         return str_latex
151 | 
152 |     def _eval_and_log(pred_dicts, keyword):
153 |         if len(pred_dicts):
154 |             result_str, result_dict = test_set.evaluation(pred_dicts, inter_pred=inter_pred)
155 |             # logger.info('\n{} Diffusion output results {}'.format('*' * 20, '*' * 20) + result_str)
156 |             logger.info('\n{:s} {:s} output results {:s}'.format('*' * 20, keyword, '*' * 20) + '\n'.join(result_str.split('\n')[-7:]))
157 |             result_latex = _get_latex_str(result_str)
158 |             logger.info('{:s} output results in latex-friendly format: '.format(keyword) + result_latex)
159 |             result_dict = {'{:s}_'.format(keyword) + key: val for key, val in result_dict.items()}
160 |             ret_dict.update(result_dict)
161 |         else:
162 |             logger.info("Skip {:s} results evaluation as no relevant results are available.".format(keyword))
163 | 
164 |     if test_set.mode in ['eval', 'inter_eval']:
165 |         _eval_and_log(pred_dicts, 'denoiser')
166 | 
167 |     logger.info('Result is saved to %s' % eval_output_dir)
168 |     logger.info('****************Evaluation 
done.*****************') 169 | 170 | return ret_dict 171 | 172 | 173 | if __name__ == '__main__': 174 | pass 175 | -------------------------------------------------------------------------------- /runner/utils/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import _init_path 14 | import glob 15 | import os 16 | import re 17 | import time 18 | import copy 19 | import torch 20 | import wandb 21 | 22 | from .tester import eval_one_epoch 23 | 24 | 25 | def get_ema_weight_keywords(ckpt_keys, ema_coef, logger): 26 | """Get EMA weight keywords based on checkpoint data and EMA coefficients.""" 27 | model_keys = [key for key in ckpt_keys if key.startswith('model_')] 28 | online_key = 'model_state' 29 | weight_keywords = [online_key] 30 | 31 | if ema_coef is None: 32 | logger.info('Not using EMA weight.') 33 | elif ema_coef == 'all': 34 | weight_keywords = model_keys 35 | logger.info('Use all possible EMA weights.') 36 | else: 37 | logger.info(f'Using EMA weight with coefficients: {ema_coef}') 38 | if 1.0 not in ema_coef: 39 | weight_keywords.remove(online_key) 40 | else: 41 | ema_coef.remove(1.0) 42 | 43 | for coef in ema_coef: 44 | weight_key = f'model_ema_beta_{coef:.4f}' 45 | assert weight_key in model_keys, f"{weight_key} not found in model data!" 
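                # e.g. ema_coef=[0.999] resolves to the checkpoint key 'model_ema_beta_0.9990';
                # the {:.4f} formatting must match how the EMA weights were saved during training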
46 | weight_keywords.append(weight_key) 47 | 48 | logger.info(f'Model weights to load: {weight_keywords}') 49 | return weight_keywords 50 | 51 | 52 | def eval_single_ckpt(denoiser, test_loader, cfg, args, logger, args_ema_coef=None, submission_info=None): 53 | """Evaluate a single checkpoint with optional EMA variants.""" 54 | cfg_ = copy.deepcopy(cfg) 55 | dist_test = cfg.OPT.DIST_TRAIN 56 | eval_output_dir = cfg.SAVE_DIR.EVAL_OUTPUT_DIR 57 | 58 | # Load checkpoint into memory for EMA variants 59 | device = 'CPU' if dist_test else 'GPU' 60 | logger.info(f'==> Loading parameters from checkpoint {args.ckpt} to {device}') 61 | ckpt_state = torch.load(args.ckpt, map_location=torch.device('cpu') if dist_test else None) 62 | 63 | weight_keywords = get_ema_weight_keywords(ckpt_state.keys(), args_ema_coef, logger) 64 | 65 | for weight_kw in weight_keywords: 66 | # Load checkpoint 67 | if args.ckpt is not None: 68 | it, epoch = denoiser.model.load_params( 69 | ckpt_path=args.ckpt, to_cpu=dist_test, 70 | ckpt_state=ckpt_state, optimizer=None, ema_model_kw=weight_kw 71 | ) 72 | epoch += 1 # because the epoch is 0-indexed in the checkpoint 73 | else: 74 | it, epoch, ckpt_state = -1, -1, None 75 | 76 | denoiser.cuda() 77 | logger.info(f'*************** Successfully load model (epoch={epoch}, iter={it}, EMA weight_keyword={weight_kw}) for EVALUATION *****************') 78 | 79 | # Setup result directory 80 | base_dir = os.path.basename(eval_output_dir) 81 | if base_dir == 'default': 82 | result_dir = os.path.join(base_dir, f'weight_{weight_kw}_epoch_{epoch}') 83 | else: 84 | result_dir = os.path.join(eval_output_dir, f'weight_{weight_kw}_epoch_{epoch}') 85 | 86 | os.makedirs(result_dir, exist_ok=True) 87 | logger.info(f'*************** Saving results to {result_dir} *****************') 88 | 89 | # Run evaluation 90 | cfg = copy.deepcopy(cfg_) 91 | cfg.SAVE_DIR.EVAL_OUTPUT_DIR = result_dir 92 | 93 | eval_one_epoch( 94 | denoiser, test_loader, cfg, epoch, logger, 95 | inter_pred=args.interactive, flag_submission=args.submit, 96 | submission_info=submission_info, logger_iter_interval=args.logger_iter_interval 97 | ) 98 | 99 | 100 | def get_unevaluated_ckpt(ckpt_dir, record_file, start_epoch): 101 | """Find the next unevaluated checkpoint.""" 102 | ckpt_files = glob.glob(os.path.join(ckpt_dir, '*checkpoint_epoch_*.pth')) 103 | ckpt_files.sort(key=os.path.getmtime) 104 | 105 | with open(record_file, 'r') as f: 106 | evaluated_epochs = {float(x.strip()) for x in f.readlines()} 107 | 108 | for ckpt_path in ckpt_files: 109 | match = re.search(r'checkpoint_epoch_(.*?)\.pth', ckpt_path) 110 | if not match or 'optim' in match.group(1): 111 | continue 112 | 113 | try: 114 | epoch_id = float(match.group(1)) 115 | except: 116 | epoch_id = float(match.group(1).split('_')[0]) 117 | 118 | if epoch_id not in evaluated_epochs and int(epoch_id) >= start_epoch: 119 | return int(epoch_id), ckpt_path 120 | 121 | return -1, None 122 | 123 | 124 | def repeat_eval_ckpt(denoiser, test_loader, cfg, args, logger, args_ema_coef=None, submission_info=None): 125 | """Repeatedly evaluate checkpoints as they become available.""" 126 | if args_ema_coef is not None: 127 | raise NotImplementedError('EMA checkpoint variants not supported for repeated evaluation') 128 | 129 | cfg_ = copy.deepcopy(cfg) 130 | dist_test = cfg.OPT.DIST_TRAIN 131 | eval_output_dir = cfg.SAVE_DIR.EVAL_OUTPUT_DIR 132 | ckpt_dir = cfg.SAVE_DIR.CKPT_DIR 133 | 134 | # Setup checkpoint record file 135 | record_file = os.path.join(eval_output_dir, 'eval_list_val.txt') 
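    # eval_list_val.txt stores one already-evaluated epoch id per line; get_unevaluated_ckpt()
    # consults it so that repeated polling never re-evaluates the same checkpoint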
136 | open(record_file, 'a').close() # Create file if doesn't exist 137 | 138 | # Setup wandb logging - use the same wandb run as training 139 | wb_log = None 140 | if cfg.LOCAL_RANK == 0: 141 | # Get the current wandb run instead of creating a new one 142 | wb_log = wandb.run 143 | 144 | total_time = 0 145 | wait_seconds = 10 146 | 147 | while True: 148 | # Find next unevaluated checkpoint 149 | epoch_id, ckpt_path = get_unevaluated_ckpt(ckpt_dir, record_file, args.start_epoch) 150 | 151 | if epoch_id == -1: 152 | # No checkpoint found, wait and retry 153 | if cfg.LOCAL_RANK == 0: 154 | progress = total_time / 60 155 | print(f'Wait {wait_seconds}s for next check (progress: {progress:.1f}/{args.max_waiting_mins} min): {ckpt_dir}\r', 156 | end='', flush=True) 157 | 158 | time.sleep(wait_seconds) 159 | total_time += wait_seconds 160 | 161 | if total_time >= args.max_waiting_mins * 60: 162 | break 163 | continue 164 | 165 | # Reset timer and evaluate checkpoint 166 | total_time = 0 167 | 168 | it, epoch = denoiser.model.load_params(ckpt_path=ckpt_path, to_cpu=dist_test) 169 | logger.info(f'*************** LOAD MODEL (epoch={epoch}, iter={it}) for EVALUATION *****************') 170 | denoiser.cuda() 171 | 172 | # Setup result directory and run evaluation 173 | result_dir = os.path.join(eval_output_dir, f'epoch_{epoch_id:02d}') 174 | os.makedirs(result_dir, exist_ok=True) 175 | 176 | cfg = copy.deepcopy(cfg_) 177 | cfg.SAVE_DIR.EVAL_OUTPUT_DIR = result_dir 178 | 179 | wb_dict = eval_one_epoch( 180 | denoiser, test_loader, cfg, epoch_id, logger, 181 | inter_pred=args.interactive, flag_submission=args.submit, 182 | submission_info=submission_info, logger_iter_interval=args.logger_iter_interval 183 | ) 184 | 185 | # Log to wandb 186 | if wb_log: 187 | wb_dict = {k: v for k, v in wb_dict.items() if '-----' not in k} # skip meaningless entries 188 | eval_log_dict = {f'eval/{key}': val for key, val in wb_dict.items()} 189 | eval_log_dict['epoch'] = epoch_id 190 | wb_log.log(eval_log_dict) 191 | 192 | # Record evaluated epoch 193 | with open(record_file, 'a') as f: 194 | f.write(f'{epoch_id}\n') 195 | logger.info(f'Epoch {epoch_id} has been evaluated') 196 | 197 | -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/src/attention_func.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_func.h" 5 | 6 | #define CHECK_CUDA(x) do { \ 7 | if (!x.type().is_cuda()) { \ 8 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 9 | exit(-1); \ 10 | } \ 11 | } while (0) 12 | #define CHECK_CONTIGUOUS(x) do { \ 13 | if (!x.is_contiguous()) { \ 14 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 15 | exit(-1); \ 16 | } \ 17 | } while (0) 18 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 19 | 20 | 21 | int attention_weight_computation_wrapper( 22 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 23 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 24 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 25 | at::Tensor output){ 26 | // params query_batch_cnt: [b] 27 | // params key_batch_cnt: [b] 28 | // params index_pair_batch: [total_query_num] 29 | // params index_pair: [total_query_num, local_size] 30 | // params query_features: [total_query_num, nhead, hdim] 31 | // params key_features: 
[total_key_num, nhead, hdim] 32 | // params output: [total_query_num, local_size, nhead] 33 | CHECK_INPUT(query_batch_cnt); 34 | CHECK_INPUT(key_batch_cnt); 35 | CHECK_INPUT(index_pair_batch); 36 | CHECK_INPUT(index_pair); 37 | CHECK_INPUT(query_features); 38 | CHECK_INPUT(key_features); 39 | 40 | CHECK_INPUT(output); 41 | 42 | const int *query_batch_cnt_data = query_batch_cnt.data(); 43 | const int *key_batch_cnt_data = key_batch_cnt.data(); 44 | const int *index_pair_batch_data = index_pair_batch.data(); 45 | const int *index_pair_data = index_pair.data(); 46 | 47 | const float *query_features_data = query_features.data(); 48 | const float *key_features_data = key_features.data(); 49 | 50 | float *output_data = output.data(); 51 | 52 | attention_weight_computation_launcher( 53 | b, total_query_num, local_size, total_key_num, nhead, hdim, 54 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 55 | index_pair_data, query_features_data, key_features_data, 56 | output_data); 57 | 58 | return 1; 59 | } 60 | 61 | 62 | int attention_weight_computation_grad_wrapper( 63 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 64 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 65 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 66 | at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features){ 67 | // params query_batch_cnt: [b] 68 | // params key_batch_cnt: [b] 69 | // params index_pair_batch: [total_query_num] 70 | // params index_pair: [total_query_num, local_size] 71 | // params query_features: [total_query_num, nhead, hdim] 72 | // params key_features: [total_key_num, nhead, hdim] 73 | // params grad_out: [total_query_num, local_size, nhead] 74 | // params grad_query_features: [total_query_num, nhead, hdim] 75 | // params grad_key_features: [total_key_num, nhead, hdim] 76 | CHECK_INPUT(query_batch_cnt); 77 | CHECK_INPUT(key_batch_cnt); 78 | CHECK_INPUT(index_pair_batch); 79 | CHECK_INPUT(index_pair); 80 | CHECK_INPUT(query_features); 81 | CHECK_INPUT(key_features); 82 | 83 | CHECK_INPUT(grad_out); 84 | CHECK_INPUT(grad_query_features); 85 | CHECK_INPUT(grad_key_features); 86 | 87 | const int *query_batch_cnt_data = query_batch_cnt.data(); 88 | const int *key_batch_cnt_data = key_batch_cnt.data(); 89 | const int *index_pair_batch_data = index_pair_batch.data(); 90 | const int *index_pair_data = index_pair.data(); 91 | 92 | const float *query_features_data = query_features.data(); 93 | const float *key_features_data = key_features.data(); 94 | 95 | float *grad_out_data = grad_out.data(); 96 | float *grad_query_features_data = grad_query_features.data(); 97 | float *grad_key_features_data = grad_key_features.data(); 98 | 99 | attention_weight_computation_grad_launcher( 100 | b, total_query_num, local_size, total_key_num, nhead, hdim, 101 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 102 | index_pair_data, query_features_data, key_features_data, 103 | grad_out_data, grad_query_features_data, grad_key_features_data); 104 | 105 | return 1; 106 | } 107 | 108 | 109 | int attention_value_computation_wrapper( 110 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 111 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 112 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 113 | at::Tensor output){ 114 | // params query_batch_cnt: [b] 115 | // params 
key_batch_cnt: [b] 116 | // params index_pair_batch: [total_query_num] 117 | // params index_pair: [total_query_num, local_size] 118 | // params attn_weight: [total_query_num, local_size, nhead] 119 | // params value_features: [total_key_num, nhead, hdim] 120 | // params output: [total_query_num, nhead, hdim] 121 | CHECK_INPUT(query_batch_cnt); 122 | CHECK_INPUT(key_batch_cnt); 123 | CHECK_INPUT(index_pair_batch); 124 | CHECK_INPUT(index_pair); 125 | CHECK_INPUT(attn_weight); 126 | CHECK_INPUT(value_features); 127 | 128 | CHECK_INPUT(output); 129 | 130 | const int *query_batch_cnt_data = query_batch_cnt.data(); 131 | const int *key_batch_cnt_data = key_batch_cnt.data(); 132 | const int *index_pair_batch_data = index_pair_batch.data(); 133 | const int *index_pair_data = index_pair.data(); 134 | 135 | const float *attn_weight_data = attn_weight.data(); 136 | const float *value_features_data = value_features.data(); 137 | 138 | float *output_data = output.data(); 139 | 140 | attention_value_computation_launcher( 141 | b, total_query_num, local_size, total_key_num, nhead, hdim, 142 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 143 | index_pair_data, attn_weight_data, value_features_data, 144 | output_data); 145 | 146 | return 1; 147 | } 148 | 149 | 150 | int attention_value_computation_grad_wrapper( 151 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 152 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 153 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 154 | at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features){ 155 | // params query_batch_cnt: [b] 156 | // params key_batch_cnt: [b] 157 | // params index_pair_batch: [total_query_num] 158 | // params index_pair: [total_query_num, local_size] 159 | // params attn_weight: [total_query_num, local_size, nhead] 160 | // params value_features: [total_key_num, nhead, hdim] 161 | // params grad_out: [total_query_num, nhead, hdim] 162 | // params grad_attn_weight: [total_query_num, local_size, nhead] 163 | // params grad_value_features: [total_key_num, nhead, hdim] 164 | CHECK_INPUT(query_batch_cnt); 165 | CHECK_INPUT(key_batch_cnt); 166 | CHECK_INPUT(index_pair_batch); 167 | CHECK_INPUT(index_pair); 168 | CHECK_INPUT(attn_weight); 169 | CHECK_INPUT(value_features); 170 | 171 | CHECK_INPUT(grad_out); 172 | CHECK_INPUT(grad_attn_weight); 173 | CHECK_INPUT(grad_value_features); 174 | 175 | const int *query_batch_cnt_data = query_batch_cnt.data(); 176 | const int *key_batch_cnt_data = key_batch_cnt.data(); 177 | const int *index_pair_batch_data = index_pair_batch.data(); 178 | const int *index_pair_data = index_pair.data(); 179 | 180 | const float *attn_weight_data = attn_weight.data(); 181 | const float *value_features_data = value_features.data(); 182 | 183 | float *grad_out_data = grad_out.data(); 184 | float *grad_attn_weight_data = grad_attn_weight.data(); 185 | float *grad_value_features_data = grad_value_features.data(); 186 | 187 | attention_value_computation_grad_launcher( 188 | b, total_query_num, local_size, total_key_num, nhead, hdim, 189 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 190 | index_pair_data, attn_weight_data, value_features_data, 191 | grad_out_data, grad_attn_weight_data, grad_value_features_data); 192 | 193 | return 1; 194 | } -------------------------------------------------------------------------------- 
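The wrappers above are exported to Python through a small extension module (the repository's attention_api.cpp, listed in the tree but not reproduced in this excerpt). A minimal pybind11 binding sketch, assuming only the v1 wrapper names declared in this file, could look like the following; the actual attention_api.cpp may bind additional functions (e.g. the v2 variants) or use different docstrings:

// Sketch only -- not the repository's actual attention_api.cpp.
#include <torch/extension.h>
#include "attention_func.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("attention_weight_computation_wrapper", &attention_weight_computation_wrapper, "attention weight forward (v1)");
    m.def("attention_weight_computation_grad_wrapper", &attention_weight_computation_grad_wrapper, "attention weight backward (v1)");
    m.def("attention_value_computation_wrapper", &attention_value_computation_wrapper, "attention value forward (v1)");
    m.def("attention_value_computation_grad_wrapper", &attention_value_computation_grad_wrapper, "attention value backward (v1)");
}

The Python wrappers in attention_utils.py below call these functions through a compiled module named attention_cuda, so the extension is presumably built under that name (e.g. via a CUDAExtension entry in setup_trajflow.py).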
/trajflow/mtr_ops/attention/src/attention_func_v2.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "attention_func_v2.h" 5 | 6 | #define CHECK_CUDA(x) do { \ 7 | if (!x.type().is_cuda()) { \ 8 | fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 9 | exit(-1); \ 10 | } \ 11 | } while (0) 12 | #define CHECK_CONTIGUOUS(x) do { \ 13 | if (!x.is_contiguous()) { \ 14 | fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \ 15 | exit(-1); \ 16 | } \ 17 | } while (0) 18 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 19 | 20 | 21 | int attention_weight_computation_wrapper_v2( 22 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 23 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 24 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 25 | at::Tensor output){ 26 | // params query_batch_cnt: [b] 27 | // params key_batch_cnt: [b] 28 | // params index_pair_batch: [total_query_num] 29 | // params index_pair: [total_query_num, local_size] 30 | // params query_features: [total_query_num, nhead, hdim] 31 | // params key_features: [total_key_num, nhead, hdim] 32 | // params output: [total_query_num, local_size, nhead] 33 | CHECK_INPUT(query_batch_cnt); 34 | CHECK_INPUT(key_batch_cnt); 35 | CHECK_INPUT(index_pair_batch); 36 | CHECK_INPUT(index_pair); 37 | CHECK_INPUT(query_features); 38 | CHECK_INPUT(key_features); 39 | 40 | CHECK_INPUT(output); 41 | 42 | const int *query_batch_cnt_data = query_batch_cnt.data(); 43 | const int *key_batch_cnt_data = key_batch_cnt.data(); 44 | const int *index_pair_batch_data = index_pair_batch.data(); 45 | const int *index_pair_data = index_pair.data(); 46 | 47 | const float *query_features_data = query_features.data(); 48 | const float *key_features_data = key_features.data(); 49 | 50 | float *output_data = output.data(); 51 | 52 | attention_weight_computation_launcher_v2( 53 | b, total_query_num, local_size, total_key_num, nhead, hdim, 54 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 55 | index_pair_data, query_features_data, key_features_data, 56 | output_data); 57 | 58 | return 1; 59 | } 60 | 61 | 62 | int attention_weight_computation_grad_wrapper_v2( 63 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 64 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 65 | at::Tensor index_pair, at::Tensor query_features, at::Tensor key_features, 66 | at::Tensor grad_out, at::Tensor grad_query_features, at::Tensor grad_key_features){ 67 | // params query_batch_cnt: [b] 68 | // params key_batch_cnt: [b] 69 | // params index_pair_batch: [total_query_num] 70 | // params index_pair: [total_query_num, local_size] 71 | // params query_features: [total_query_num, nhead, hdim] 72 | // params key_features: [total_key_num, nhead, hdim] 73 | // params grad_out: [total_query_num, local_size, nhead] 74 | // params grad_query_features: [total_query_num, nhead, hdim] 75 | // params grad_key_features: [total_key_num, nhead, hdim] 76 | CHECK_INPUT(query_batch_cnt); 77 | CHECK_INPUT(key_batch_cnt); 78 | CHECK_INPUT(index_pair_batch); 79 | CHECK_INPUT(index_pair); 80 | CHECK_INPUT(query_features); 81 | CHECK_INPUT(key_features); 82 | 83 | CHECK_INPUT(grad_out); 84 | CHECK_INPUT(grad_query_features); 85 | CHECK_INPUT(grad_key_features); 86 | 87 | const int 
*query_batch_cnt_data = query_batch_cnt.data(); 88 | const int *key_batch_cnt_data = key_batch_cnt.data(); 89 | const int *index_pair_batch_data = index_pair_batch.data(); 90 | const int *index_pair_data = index_pair.data(); 91 | 92 | const float *query_features_data = query_features.data(); 93 | const float *key_features_data = key_features.data(); 94 | 95 | float *grad_out_data = grad_out.data(); 96 | float *grad_query_features_data = grad_query_features.data(); 97 | float *grad_key_features_data = grad_key_features.data(); 98 | 99 | attention_weight_computation_grad_launcher_v2( 100 | b, total_query_num, local_size, total_key_num, nhead, hdim, 101 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 102 | index_pair_data, query_features_data, key_features_data, 103 | grad_out_data, grad_query_features_data, grad_key_features_data); 104 | 105 | return 1; 106 | } 107 | 108 | 109 | int attention_value_computation_wrapper_v2( 110 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 111 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 112 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 113 | at::Tensor output){ 114 | // params query_batch_cnt: [b] 115 | // params key_batch_cnt: [b] 116 | // params index_pair_batch: [total_query_num] 117 | // params index_pair: [total_query_num, local_size] 118 | // params attn_weight: [total_query_num, local_size, nhead] 119 | // params value_features: [total_key_num, nhead, hdim] 120 | // params output: [total_query_num, nhead, hdim] 121 | CHECK_INPUT(query_batch_cnt); 122 | CHECK_INPUT(key_batch_cnt); 123 | CHECK_INPUT(index_pair_batch); 124 | CHECK_INPUT(index_pair); 125 | CHECK_INPUT(attn_weight); 126 | CHECK_INPUT(value_features); 127 | 128 | CHECK_INPUT(output); 129 | 130 | const int *query_batch_cnt_data = query_batch_cnt.data(); 131 | const int *key_batch_cnt_data = key_batch_cnt.data(); 132 | const int *index_pair_batch_data = index_pair_batch.data(); 133 | const int *index_pair_data = index_pair.data(); 134 | 135 | const float *attn_weight_data = attn_weight.data(); 136 | const float *value_features_data = value_features.data(); 137 | 138 | float *output_data = output.data(); 139 | 140 | attention_value_computation_launcher_v2( 141 | b, total_query_num, local_size, total_key_num, nhead, hdim, 142 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 143 | index_pair_data, attn_weight_data, value_features_data, 144 | output_data); 145 | 146 | return 1; 147 | } 148 | 149 | 150 | int attention_value_computation_grad_wrapper_v2( 151 | int b, int total_query_num, int local_size, int total_key_num, int nhead, int hdim, 152 | at::Tensor query_batch_cnt, at::Tensor key_batch_cnt, at::Tensor index_pair_batch, 153 | at::Tensor index_pair, at::Tensor attn_weight, at::Tensor value_features, 154 | at::Tensor grad_out, at::Tensor grad_attn_weight, at::Tensor grad_value_features){ 155 | // params query_batch_cnt: [b] 156 | // params key_batch_cnt: [b] 157 | // params index_pair_batch: [total_query_num] 158 | // params index_pair: [total_query_num, local_size] 159 | // params attn_weight: [total_query_num, local_size, nhead] 160 | // params value_features: [total_key_num, nhead, hdim] 161 | // params grad_out: [total_query_num, nhead, hdim] 162 | // params grad_attn_weight: [total_query_num, local_size, nhead] 163 | // params grad_value_features: [total_key_num, nhead, hdim] 164 | CHECK_INPUT(query_batch_cnt); 165 | 
CHECK_INPUT(key_batch_cnt); 166 | CHECK_INPUT(index_pair_batch); 167 | CHECK_INPUT(index_pair); 168 | CHECK_INPUT(attn_weight); 169 | CHECK_INPUT(value_features); 170 | 171 | CHECK_INPUT(grad_out); 172 | CHECK_INPUT(grad_attn_weight); 173 | CHECK_INPUT(grad_value_features); 174 | 175 | const int *query_batch_cnt_data = query_batch_cnt.data(); 176 | const int *key_batch_cnt_data = key_batch_cnt.data(); 177 | const int *index_pair_batch_data = index_pair_batch.data(); 178 | const int *index_pair_data = index_pair.data(); 179 | 180 | const float *attn_weight_data = attn_weight.data(); 181 | const float *value_features_data = value_features.data(); 182 | 183 | float *grad_out_data = grad_out.data(); 184 | float *grad_attn_weight_data = grad_attn_weight.data(); 185 | float *grad_value_features_data = grad_value_features.data(); 186 | 187 | attention_value_computation_grad_launcher_v2( 188 | b, total_query_num, local_size, total_key_num, nhead, hdim, 189 | query_batch_cnt_data, key_batch_cnt_data, index_pair_batch_data, 190 | index_pair_data, attn_weight_data, value_features_data, 191 | grad_out_data, grad_attn_weight_data, grad_value_features_data); 192 | 193 | return 1; 194 | } -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/attention_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly copy-paste from https://github.com/dvlab-research/DeepVision3D/blob/master/EQNet/eqnet/ops/attention/attention_utils_v2.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Function, Variable 8 | 9 | from . import attention_cuda 10 | 11 | 12 | """ Attention computation code v1.""" 13 | class AttentionWeightComputation(Function): 14 | """ 15 | Generate the attention weight matrix based on: 16 | * the generated attention pair index (total_query_num, local_size); 17 | * query features (total_query_num, nhead, hdim) 18 | * key features (total_key_num, nhead, hdim) 19 | Generate the attention weight matrix. 20 | * (total_query_num, local_size) 21 | """ 22 | 23 | @staticmethod 24 | def forward(ctx, 25 | query_batch_cnt: torch.Tensor, 26 | key_batch_cnt: torch.Tensor, 27 | index_pair_batch: torch.Tensor, 28 | index_pair: torch.Tensor, 29 | query_features: torch.Tensor, 30 | key_features: torch.Tensor): 31 | """ 32 | :param ctx: 33 | :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch. 34 | :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch. 35 | :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch 36 | index of each query. 37 | :param index_pair: A integer tensor with shape [total_query_num, local_size] 38 | We ignore those index whose value is -1. 
39 | :param query_features: A float tensor with shape [total_query_num, nhead, hdim] 40 | :param key_features: A float tensor with shape [total_key_num, nhead, hdim] 41 | :return: 42 | output: A float tensor with shape [total_query_num, local_size, nhead] 43 | """ 44 | assert query_batch_cnt.is_contiguous() 45 | assert key_batch_cnt.is_contiguous() 46 | assert index_pair_batch.is_contiguous() 47 | assert index_pair.is_contiguous() 48 | assert query_features.is_contiguous() 49 | assert key_features.is_contiguous() 50 | 51 | b = query_batch_cnt.shape[0] 52 | total_query_num, local_size = index_pair.size() 53 | total_key_num, nhead, hdim = key_features.size() 54 | 55 | # Need to ensure that every tensor in query features have an output. 56 | assert total_query_num == query_features.shape[0] 57 | 58 | output = torch.cuda.FloatTensor(total_query_num, local_size, nhead).zero_() 59 | 60 | attention_cuda.attention_weight_computation_wrapper( 61 | b, total_query_num, local_size, total_key_num, nhead, hdim, 62 | query_batch_cnt, key_batch_cnt, index_pair_batch, 63 | index_pair, query_features, key_features, 64 | output) 65 | ctx.for_backwards = ( 66 | b, total_query_num, local_size, total_key_num, nhead, hdim, 67 | query_batch_cnt, key_batch_cnt, index_pair_batch, 68 | index_pair, query_features, key_features 69 | ) 70 | return output 71 | 72 | @staticmethod 73 | def backward(ctx, grad_out: torch.Tensor): 74 | """ 75 | Args: 76 | ctx: 77 | grad_out: [total_query_num, local_size, nhead] 78 | Returns: 79 | grad_query_features: [total_query_num, nhead, hdim] 80 | grad_key_features: [total_key_num, nhead, hdim] 81 | """ 82 | (b, total_query_num, local_size, total_key_num, nhead, hdim, 83 | query_batch_cnt, key_batch_cnt, index_pair_batch, 84 | index_pair, query_features, key_features) = ctx.for_backwards 85 | 86 | grad_query_features = Variable(torch.cuda.FloatTensor( 87 | total_query_num, nhead, hdim).zero_()) 88 | grad_key_features = Variable(torch.cuda.FloatTensor( 89 | total_key_num, nhead, hdim).zero_()) 90 | 91 | grad_out_data = grad_out.data.contiguous() 92 | attention_cuda.attention_weight_computation_grad_wrapper( 93 | b, total_query_num, local_size, total_key_num, nhead, hdim, 94 | query_batch_cnt, key_batch_cnt, index_pair_batch, 95 | index_pair, query_features, key_features, 96 | grad_out_data, grad_query_features.data, grad_key_features.data) 97 | return None, None, None, None, grad_query_features, grad_key_features 98 | 99 | 100 | attention_weight_computation = AttentionWeightComputation.apply 101 | 102 | 103 | class AttentionValueComputation(Function): 104 | """ 105 | Generate the attention result based on: 106 | * the generated attention pair index (total_query_num, local_size); 107 | * value features (total_key_num, nhead, hdim) 108 | * attn_weight (total_query_num, local_size, nhead) 109 | Generate the attention result. 110 | * (total_query_num, nhead, hdim) 111 | """ 112 | 113 | @staticmethod 114 | def forward(ctx, 115 | query_batch_cnt: torch.Tensor, 116 | key_batch_cnt: torch.Tensor, 117 | index_pair_batch: torch.Tensor, 118 | index_pair: torch.Tensor, 119 | attn_weight: torch.Tensor, 120 | value_features: torch.Tensor): 121 | """ 122 | :param ctx: 123 | :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch. 124 | :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch. 125 | :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch 126 | index of each query. 
127 | :param index_pair: A integer tensor with shape [total_query_num, local_size] 128 | We ignore those index whose value is -1. 129 | :param attn_weight: A float tensor with shape [total_query_num, local_size, nhead] 130 | :param value_features: A float tensor with shape [total_key_num, nhead, hdim] 131 | :return: 132 | output: A float tensor with shape [total_query_num, nhead, hdim] 133 | """ 134 | assert query_batch_cnt.is_contiguous() 135 | assert key_batch_cnt.is_contiguous() 136 | assert index_pair_batch.is_contiguous() 137 | assert index_pair.is_contiguous() 138 | assert attn_weight.is_contiguous() 139 | assert value_features.is_contiguous() 140 | 141 | b = query_batch_cnt.shape[0] 142 | total_query_num, local_size = index_pair.size() 143 | total_key_num, nhead, hdim = value_features.size() 144 | 145 | # Need to ensure that every tensor in query features have an output. 146 | assert total_query_num == attn_weight.shape[0] 147 | 148 | output = torch.cuda.FloatTensor(total_query_num, nhead, hdim).zero_() 149 | 150 | attention_cuda.attention_value_computation_wrapper( 151 | b, total_query_num, local_size, total_key_num, nhead, hdim, 152 | query_batch_cnt, key_batch_cnt, index_pair_batch, 153 | index_pair, attn_weight, value_features, 154 | output) 155 | ctx.for_backwards = ( 156 | b, total_query_num, local_size, total_key_num, nhead, hdim, 157 | query_batch_cnt, key_batch_cnt, index_pair_batch, 158 | index_pair, attn_weight, value_features 159 | ) 160 | return output 161 | 162 | @staticmethod 163 | def backward(ctx, grad_out: torch.Tensor): 164 | """ 165 | Args: 166 | ctx: 167 | grad_out: [total_query_num, nhead, hdim] 168 | Returns: 169 | grad_attn_weight: [total_query_num, local_size, nhead] 170 | grad_value_features: [total_key_num, nhead, hdim] 171 | """ 172 | (b, total_query_num, local_size, total_key_num, nhead, hdim, 173 | query_batch_cnt, key_batch_cnt, index_pair_batch, 174 | index_pair, attn_weight, value_features) = ctx.for_backwards 175 | 176 | grad_attn_weight = Variable(torch.cuda.FloatTensor( 177 | total_query_num, local_size, nhead).zero_()) 178 | grad_value_features = Variable(torch.cuda.FloatTensor( 179 | total_key_num, nhead, hdim).zero_()) 180 | 181 | grad_out_data = grad_out.data.contiguous() 182 | attention_cuda.attention_value_computation_grad_wrapper( 183 | b, total_query_num, local_size, total_key_num, nhead, hdim, 184 | query_batch_cnt, key_batch_cnt, index_pair_batch, 185 | index_pair, attn_weight, value_features, 186 | grad_out_data, grad_attn_weight.data, grad_value_features.data) 187 | return None, None, None, None, grad_attn_weight, grad_value_features 188 | 189 | 190 | attention_value_computation = AttentionValueComputation.apply -------------------------------------------------------------------------------- /trajflow/mtr_ops/attention/attention_utils_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly copy-paste from https://github.com/dvlab-research/DeepVision3D/blob/master/EQNet/eqnet/ops/attention/attention_utils_v2.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Function, Variable 8 | 9 | from . 
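A minimal end-to-end usage sketch for the two v1 ops above (all shapes, values and the import path are illustrative assumptions; it requires a CUDA device and the compiled attention_cuda extension):

import torch
from trajflow.mtr_ops.attention.attention_utils import (
    attention_weight_computation, attention_value_computation)

# two scenes with 3 and 2 queries, 4 and 3 keys; each query attends to at most 2 neighbours
query_batch_cnt = torch.tensor([3, 2], dtype=torch.int32, device='cuda')
key_batch_cnt = torch.tensor([4, 3], dtype=torch.int32, device='cuda')
index_pair_batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.int32, device='cuda')
index_pair = torch.tensor([[0, 1], [1, 2], [3, -1], [0, 2], [1, -1]],
                          dtype=torch.int32, device='cuda')  # -1 slots are skipped by the kernels
nhead, hdim = 2, 8
query_features = torch.randn(5, nhead, hdim, device='cuda')  # [total_query_num, nhead, hdim]
key_features = torch.randn(7, nhead, hdim, device='cuda')     # [total_key_num, nhead, hdim]

attn = attention_weight_computation(query_batch_cnt, key_batch_cnt, index_pair_batch,
                                    index_pair, query_features, key_features)  # [5, 2, nhead]
attn = torch.softmax(attn, dim=1)  # a real pipeline would mask the -1 slots before the softmax
out = attention_value_computation(query_batch_cnt, key_batch_cnt, index_pair_batch,
                                  index_pair, attn.contiguous(), key_features)  # [5, nhead, hdim]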
import attention_cuda 10 | 11 | 12 | """ Attention computation code v2.""" 13 | class AttentionWeightComputation(Function): 14 | """ 15 | Generate the attention weight matrix based on: 16 | * the generated attention pair index (total_query_num, local_size); 17 | * query features (total_query_num, nhead, hdim) 18 | * key features (total_key_num, nhead, hdim) 19 | Generate the attention weight matrix. 20 | * (total_query_num, local_size) 21 | """ 22 | 23 | @staticmethod 24 | def forward(ctx, 25 | query_batch_cnt: torch.Tensor, 26 | key_batch_cnt: torch.Tensor, 27 | index_pair_batch: torch.Tensor, 28 | index_pair: torch.Tensor, 29 | query_features: torch.Tensor, 30 | key_features: torch.Tensor): 31 | """ 32 | :param ctx: 33 | :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch. 34 | :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch. 35 | :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch 36 | index of each query. 37 | :param index_pair: A integer tensor with shape [total_query_num, local_size] 38 | We ignore those index whose value is -1. 39 | :param query_features: A float tensor with shape [total_query_num, nhead, hdim] 40 | :param key_features: A float tensor with shape [total_key_num, nhead, hdim] 41 | :return: 42 | output: A float tensor with shape [total_query_num, local_size, nhead] 43 | """ 44 | assert query_batch_cnt.is_contiguous() 45 | assert key_batch_cnt.is_contiguous() 46 | assert index_pair_batch.is_contiguous() 47 | assert index_pair.is_contiguous() 48 | assert query_features.is_contiguous() 49 | assert key_features.is_contiguous() 50 | 51 | b = query_batch_cnt.shape[0] 52 | total_query_num, local_size = index_pair.size() 53 | total_key_num, nhead, hdim = key_features.size() 54 | 55 | # Need to ensure that every tensor in query features have an output. 
56 | assert total_query_num == query_features.shape[0] 57 | 58 | output = torch.cuda.FloatTensor(total_query_num, local_size, nhead).zero_() 59 | 60 | attention_cuda.attention_weight_computation_wrapper_v2( 61 | b, total_query_num, local_size, total_key_num, nhead, hdim, 62 | query_batch_cnt, key_batch_cnt, index_pair_batch, 63 | index_pair, query_features, key_features, 64 | output) 65 | ctx.for_backwards = ( 66 | b, total_query_num, local_size, total_key_num, nhead, hdim, 67 | query_batch_cnt, key_batch_cnt, index_pair_batch, 68 | index_pair, query_features, key_features 69 | ) 70 | return output 71 | 72 | @staticmethod 73 | def backward(ctx, grad_out: torch.Tensor): 74 | """ 75 | Args: 76 | ctx: 77 | grad_out: [total_query_num, local_size, nhead] 78 | Returns: 79 | grad_query_features: [total_query_num, nhead, hdim] 80 | grad_key_features: [total_key_num, nhead, hdim] 81 | """ 82 | (b, total_query_num, local_size, total_key_num, nhead, hdim, 83 | query_batch_cnt, key_batch_cnt, index_pair_batch, 84 | index_pair, query_features, key_features) = ctx.for_backwards 85 | 86 | grad_query_features = Variable(torch.cuda.FloatTensor( 87 | total_query_num, nhead, hdim).zero_()) 88 | grad_key_features = Variable(torch.cuda.FloatTensor( 89 | total_key_num, nhead, hdim).zero_()) 90 | 91 | grad_out_data = grad_out.data.contiguous() 92 | attention_cuda.attention_weight_computation_grad_wrapper_v2( 93 | b, total_query_num, local_size, total_key_num, nhead, hdim, 94 | query_batch_cnt, key_batch_cnt, index_pair_batch, 95 | index_pair, query_features, key_features, 96 | grad_out_data, grad_query_features.data, grad_key_features.data) 97 | return None, None, None, None, grad_query_features, grad_key_features 98 | 99 | 100 | attention_weight_computation = AttentionWeightComputation.apply 101 | 102 | 103 | class AttentionValueComputation(Function): 104 | """ 105 | Generate the attention result based on: 106 | * the generated attention pair index (total_query_num, local_size); 107 | * value features (total_key_num, nhead, hdim) 108 | * attn_weight (total_query_num, local_size, nhead) 109 | Generate the attention result. 110 | * (total_query_num, nhead, hdim) 111 | """ 112 | 113 | @staticmethod 114 | def forward(ctx, 115 | query_batch_cnt: torch.Tensor, 116 | key_batch_cnt: torch.Tensor, 117 | index_pair_batch: torch.Tensor, 118 | index_pair: torch.Tensor, 119 | attn_weight: torch.Tensor, 120 | value_features: torch.Tensor): 121 | """ 122 | :param ctx: 123 | :param query_batch_cnt: A integer tensor with shape [bs], indicating the query amount for each batch. 124 | :param key_batch_cnt: A integer tensor with shape [bs], indicating the key amount of each batch. 125 | :param index_pair_batch: A integer tensor with shape [total_query_num], indicating the batch 126 | index of each query. 127 | :param index_pair: A integer tensor with shape [total_query_num, local_size] 128 | We ignore those index whose value is -1. 
129 | :param attn_weight: A float tensor with shape [total_query_num, local_size, nhead] 130 | :param value_features: A float tensor with shape [total_key_num, nhead, hdim] 131 | :return: 132 | output: A float tensor with shape [total_query_num, nhead, hdim] 133 | """ 134 | assert query_batch_cnt.is_contiguous() 135 | assert key_batch_cnt.is_contiguous() 136 | assert index_pair_batch.is_contiguous() 137 | assert index_pair.is_contiguous() 138 | assert attn_weight.is_contiguous() 139 | assert value_features.is_contiguous() 140 | 141 | b = query_batch_cnt.shape[0] 142 | total_query_num, local_size = index_pair.size() 143 | total_key_num, nhead, hdim = value_features.size() 144 | 145 | # Need to ensure that every tensor in query features have an output. 146 | assert total_query_num == attn_weight.shape[0] 147 | 148 | output = torch.cuda.FloatTensor(total_query_num, nhead, hdim).zero_() 149 | 150 | attention_cuda.attention_value_computation_wrapper_v2( 151 | b, total_query_num, local_size, total_key_num, nhead, hdim, 152 | query_batch_cnt, key_batch_cnt, index_pair_batch, 153 | index_pair, attn_weight, value_features, 154 | output) 155 | ctx.for_backwards = ( 156 | b, total_query_num, local_size, total_key_num, nhead, hdim, 157 | query_batch_cnt, key_batch_cnt, index_pair_batch, 158 | index_pair, attn_weight, value_features 159 | ) 160 | return output 161 | 162 | @staticmethod 163 | def backward(ctx, grad_out: torch.Tensor): 164 | """ 165 | Args: 166 | ctx: 167 | grad_out: [total_query_num, nhead, hdim] 168 | Returns: 169 | grad_attn_weight: [total_query_num, local_size, nhead] 170 | grad_value_features: [total_key_num, nhead, hdim] 171 | """ 172 | (b, total_query_num, local_size, total_key_num, nhead, hdim, 173 | query_batch_cnt, key_batch_cnt, index_pair_batch, 174 | index_pair, attn_weight, value_features) = ctx.for_backwards 175 | 176 | grad_attn_weight = Variable(torch.cuda.FloatTensor( 177 | total_query_num, local_size, nhead).zero_()) 178 | grad_value_features = Variable(torch.cuda.FloatTensor( 179 | total_key_num, nhead, hdim).zero_()) 180 | 181 | grad_out_data = grad_out.data.contiguous() 182 | attention_cuda.attention_value_computation_grad_wrapper_v2( 183 | b, total_query_num, local_size, total_key_num, nhead, hdim, 184 | query_batch_cnt, key_batch_cnt, index_pair_batch, 185 | index_pair, attn_weight, value_features, 186 | grad_out_data, grad_attn_weight.data, grad_value_features.data) 187 | return None, None, None, None, grad_attn_weight, grad_value_features 188 | 189 | 190 | attention_value_computation = AttentionValueComputation.apply -------------------------------------------------------------------------------- /trajflow/models/denoising_decoder/decoder_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import torch 14 | 15 | from trajflow.models.layers.common_layers import gen_sineembed_for_position 16 | from trajflow.utils import motion_utils 17 | 18 | 19 | def apply_dense_future_prediction(obj_feature, obj_mask, obj_pos, forward_ret_dict, 20 | obj_pos_encoding_layer, dense_future_head, future_traj_mlps, traj_fusion_mlps, 21 | num_future_frames): 22 | num_center_objects, num_objects, _ = obj_feature.shape 23 | 24 | # dense future prediction 25 | obj_pos_valid = obj_pos[obj_mask][..., 0:2] 26 | obj_feature_valid = obj_feature[obj_mask] 27 | obj_pos_feature_valid = obj_pos_encoding_layer(obj_pos_valid) 28 | obj_fused_feature_valid = torch.cat((obj_pos_feature_valid, obj_feature_valid), dim=-1) 29 | 30 | pred_dense_trajs_valid = dense_future_head(obj_fused_feature_valid) 31 | pred_dense_trajs_valid = pred_dense_trajs_valid.view(pred_dense_trajs_valid.shape[0], 32 | num_future_frames, 7) 33 | 34 | temp_center = pred_dense_trajs_valid[:, :, 0:2] + obj_pos_valid[:, None, 0:2] 35 | pred_dense_trajs_valid = torch.cat((temp_center, pred_dense_trajs_valid[:, :, 2:]), dim=-1) 36 | 37 | # future feature encoding and fuse to past obj_feature 38 | obj_future_input_valid = pred_dense_trajs_valid[:, :, [0, 1, -2, -1]].flatten(start_dim=1, end_dim=2) 39 | obj_future_feature_valid = future_traj_mlps(obj_future_input_valid) 40 | 41 | obj_full_trajs_feature = torch.cat((obj_feature_valid, obj_future_feature_valid), dim=-1) 42 | obj_feature_valid = traj_fusion_mlps(obj_full_trajs_feature) 43 | 44 | ret_obj_feature = torch.zeros_like(obj_feature) 45 | ret_obj_feature[obj_mask] = obj_feature_valid 46 | 47 | ret_pred_dense_future_trajs = obj_feature.new_zeros(num_center_objects, num_objects, num_future_frames, 7) 48 | ret_pred_dense_future_trajs[obj_mask] = pred_dense_trajs_valid 49 | forward_ret_dict['pred_dense_trajs'] = ret_pred_dense_future_trajs 50 | 51 | return ret_obj_feature, ret_pred_dense_future_trajs 52 | 53 | 54 | def get_motion_query(intention_points_dict, intention_query_mlps, center_objects_type): 55 | num_center_objects = len(center_objects_type) 56 | intention_points = torch.stack([intention_points_dict[center_objects_type[obj_idx]] for obj_idx in range(num_center_objects)], dim=0) # [B, K, 2] 57 | d_model = intention_query_mlps[0].in_features 58 | intention_query = gen_sineembed_for_position(intention_points, hidden_dim=d_model) # [B, K, D] 59 | intention_query = intention_query_mlps(intention_query.view(-1, d_model)).view(num_center_objects, -1, d_model) # [B, K, D] 60 | return intention_query, intention_points 61 | 62 | 63 | def get_center_gt_idx(layer_idx, num_inter_layers, num_decoder_layers, flag_training, forward_ret_dict, 64 | pred_scores=None, pred_trajs=None, pred_list=None, prev_trajs=None, prev_dist=None): 65 | if flag_training: 66 | center_gt_trajs = forward_ret_dict['center_gt_trajs'].cuda() 67 | center_gt_trajs_mask = forward_ret_dict['center_gt_trajs_mask'].cuda() 68 | center_gt_final_valid_idx = forward_ret_dict['center_gt_final_valid_idx'].long() 69 | intention_points = forward_ret_dict['intention_points'] 70 | num_center_objects = center_gt_trajs.shape[0] 71 | 72 | center_gt_goals = 
center_gt_trajs[torch.arange(num_center_objects), center_gt_final_valid_idx, 0:2] 73 | if (layer_idx // num_inter_layers) * num_inter_layers - 1 < 0: 74 | dist = (center_gt_goals[:, None, :] - intention_points).norm(dim=-1) # (num_center_objects, num_query) 75 | anchor_trajs = intention_points.unsqueeze(-2) 76 | select_mask = None 77 | select_idx = None 78 | if pred_list is None: 79 | center_gt_positive_idx = dist.argmin(dim=-1) # (num_center_objects) 80 | return center_gt_positive_idx, anchor_trajs, dist, select_mask, select_idx 81 | 82 | center_gt_positive_idx, select_mask, select_idx = motion_utils.select_distinct_anchors( 83 | dist, pred_scores, pred_trajs, anchor_trajs 84 | ) 85 | return center_gt_positive_idx, anchor_trajs, dist, select_mask, select_idx 86 | 87 | # Evolving & Distinct Anchors 88 | if pred_list is None: 89 | unique_layers = set( 90 | [(i//num_inter_layers)* num_inter_layers 91 | for i in range(num_decoder_layers)] 92 | ) 93 | if layer_idx in unique_layers: 94 | anchor_trajs = pred_trajs 95 | dist = ((center_gt_trajs[:, None, :, 0:2] - anchor_trajs[..., 0:2]).norm(dim=-1) * \ 96 | center_gt_trajs_mask[:, None]).sum(dim=-1) 97 | else: 98 | anchor_trajs, dist = prev_trajs, prev_dist 99 | else: 100 | anchor_trajs, dist = motion_utils.get_evolving_anchors( 101 | layer_idx, num_inter_layers, pred_list, 102 | center_gt_goals, intention_points, 103 | center_gt_trajs, center_gt_trajs_mask, 104 | ) 105 | 106 | center_gt_positive_idx, select_mask, select_idx = motion_utils.select_distinct_anchors( 107 | dist, pred_scores, pred_trajs, anchor_trajs 108 | ) 109 | else: 110 | center_gt_positive_idx = None 111 | anchor_trajs, dist = None, None 112 | select_mask=None 113 | select_idx=None 114 | 115 | return center_gt_positive_idx, anchor_trajs, dist, select_mask, select_idx 116 | 117 | 118 | def apply_cross_attention(query_feat, kv_feat, kv_mask, 119 | query_pos_feat, kv_pos_feat, 120 | pred_query_center, attn_indexing, 121 | attention_layer, 122 | query_feat_pre_mlp=None, query_embed_mlp=None, 123 | query_feat_pos_mlp=None, is_first=False 124 | ): 125 | """ 126 | Args: 127 | query_feat, query_pos_feat, query_searching_feat [M, B, D] 128 | kv_feat, kv_pos_feat [B, N, D] 129 | kv_mask [B, N] 130 | attn_indexing [B, N, M] 131 | attention_layer (func): LocalTransformer Layer (as in EQNet and MTR) 132 | query_feat_pre_mlp, query_embed_mlp, query_feat_pos_mlp (nn.Linear): 133 | projections to align decoder dimension 134 | is_first (bool): whether to concat query pos feature (as in MTR) 135 | Returns: 136 | query_feat: (B, M, D) 137 | """ 138 | 139 | if query_feat_pre_mlp is not None: 140 | query_feat = query_feat_pre_mlp(query_feat) 141 | if query_embed_mlp is not None: 142 | query_pos_feat = query_embed_mlp(query_pos_feat) 143 | 144 | d_model = query_feat.shape[-1] 145 | query_searching_feat = gen_sineembed_for_position(pred_query_center, hidden_dim=d_model) 146 | 147 | # fast attention 148 | if attn_indexing is not None: 149 | B, K = attn_indexing.shape[:2] 150 | M = kv_mask.shape[-1] 151 | context_valid_mask_ = torch.zeros([B, K, M+1], dtype=torch.bool, device=kv_mask.device) 152 | context_valid_mask_.scatter_(2, (attn_indexing + 1).long(), torch.ones_like(attn_indexing).bool()) 153 | context_valid_mask = torch.logical_and(context_valid_mask_[:, :, 1:], kv_mask[:, None, :]) # [B, K, M] 154 | else: 155 | context_valid_mask = kv_mask 156 | 157 | # batch-major tensor shape 158 | query_feat = attention_layer( 159 | query=query_feat, 160 | context=kv_feat, 161 | 
context_valid_mask=context_valid_mask, 162 | query_sa_pos_embeddings=query_pos_feat, 163 | query_ca_pos_embeddings=query_searching_feat, 164 | context_ca_pos_embeddings=kv_pos_feat, 165 | is_first=is_first, 166 | context_indexing=attn_indexing 167 | ) # [B, M, D] 168 | 169 | if query_feat_pos_mlp is not None: 170 | query_feat = query_feat_pos_mlp(query_feat) 171 | 172 | return query_feat 173 | 174 | 175 | def generate_final_prediction(pred_list, num_motion_modes): 176 | pred_scores, pred_trajs = pred_list[-1][:2] 177 | pred_scores = torch.sigmoid(pred_scores) 178 | 179 | num_query = pred_trajs.shape[1] 180 | 181 | if num_motion_modes != num_query: 182 | assert num_query > num_motion_modes 183 | pred_trajs_final, pred_scores_final, selected_idxs = motion_utils.inference_distance_nms( 184 | pred_scores, pred_trajs, num_motion_modes) 185 | else: 186 | pred_trajs_final = pred_trajs 187 | pred_scores_final = pred_scores 188 | selected_idxs = None 189 | 190 | return pred_scores_final, pred_trajs_final, selected_idxs 191 | -------------------------------------------------------------------------------- /runner/utils/starter/config_parser.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import datetime 10 | import os 11 | from pathlib import Path, PosixPath 12 | import copy 13 | import git 14 | import shutil 15 | import yaml 16 | from easydict import EasyDict 17 | 18 | import wandb 19 | 20 | import torch 21 | 22 | from trajflow.config import init_cfg, cfg_from_yaml_file, log_config_to_file 23 | from trajflow.utils import common_utils 24 | 25 | 26 | def parse_config(): 27 | parser = argparse.ArgumentParser(description='arg parser') 28 | 29 | """basic configs""" 30 | parser.add_argument('--cfg_file', type=str, default=None, help='specify the config for training') 31 | parser.add_argument('--extra_tag', type=str, default='default', help='extra tag for this experiment') 32 | parser.add_argument('--logger_iter_interval', type=int, default=10, help='logger info interval') 33 | 34 | """optimization parameters""" 35 | parser.add_argument('--batch_size', type=int, default=None, help='batch size for training') 36 | parser.add_argument('--epochs', type=int, default=None, help='number of epochs to train for') 37 | parser.add_argument('--learning_rate', default=None, type=float, help='Overwrite the learning rate.') 38 | parser.add_argument('--lr_scheduler', default=None, type=str, choices=['cosine', 'lambdaLR', 'linearLR', 'constant'], help='Overwrite the LR scheduler.') 39 | parser.add_argument('--weight_decay', default=None, type=float, help='Overwrite the weight decay.') 40 | parser.add_argument('--ema_coef', default=None, type=float, help='Overwrite the EMA coefficient.') 41 | parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader') 42 | 43 | """random seed control""" 44 | parser.add_argument('--fix_random_seed', action='store_true', default=False, help='fix random seed for reproducibility') 45 | 46 | """checkpoint loading, saving and evaluation""" 47 | parser.add_argument('--ckpt', type=str, default=None, help='checkpoint to start from') 48 | parser.add_argument('--ckpt_save_interval', type=int, default=2, help='number of training epochs between checkpoint saves') 49 |
parser.add_argument('--ckpt_save_time_interval', type=int, default=600, help='number of seconds between checkpoint saves') 50 | parser.add_argument('--max_ckpt_save_num', type=int, default=5, help='max number of saved checkpoints') 51 | parser.add_argument('--max_waiting_mins', type=int, default=0, help='max waiting minutes for ckpt evaluation') 52 | 53 | """DDP configs""" 54 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none') 55 | 56 | args = parser.parse_args() 57 | 58 | """load config""" 59 | cfg = init_cfg() 60 | cfg_from_yaml_file(args.cfg_file, cfg) 61 | 62 | cfg.TAG = Path(args.cfg_file).stem 63 | cfg.EXP_GROUP_PATH = '/'.join(args.cfg_file.split('/')[1:-1]) # remove 'cfgs' and 'xxxx.yaml' 64 | 65 | if args.launcher == 'none': 66 | cfg.OPT.DIST_TRAIN, cfg.OPT.TOTAL_GPUS, cfg.OPT.WITHOUT_SYNC_BN = False, 1, True 67 | elif args.launcher == 'pytorch': 68 | local_rank = int(os.environ.get('LOCAL_RANK', '0')) 69 | cfg.OPT.TOTAL_GPUS, cfg.LOCAL_RANK = common_utils.init_dist_pytorch(local_rank, backend='nccl') 70 | cfg.OPT.DIST_TRAIN, cfg.OPT.WITHOUT_SYNC_BN = True, False 71 | else: 72 | raise ValueError('Invalid launcher: %s' % args.launcher) 73 | 74 | if args.batch_size is not None: 75 | assert args.batch_size % cfg.OPT.TOTAL_GPUS == 0, 'Batch size should be divisible by the number of GPUs' 76 | cfg.OPT.BATCH_SIZE_PER_GPU = args.batch_size // cfg.OPT.TOTAL_GPUS 77 | for param, attr in [('epochs', 'NUM_EPOCHS'), ('learning_rate', 'LR'), ('lr_scheduler', 'SCHEDULER'), 78 | ('weight_decay', 'WEIGHT_DECAY'), ('ema_coef', 'EMA_COEF')]: 79 | if getattr(args, param) is not None: 80 | setattr(cfg.OPT, attr, getattr(args, param)) 81 | cfg.OPT.WORKERS = args.workers 82 | 83 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 84 | cfg.DEVICE = cfg.MODEL_DMT.DEVICE = device 85 | 86 | return args, cfg 87 | 88 | 89 | def init_basics(): 90 | ###### Start of Init ###### 91 | 92 | """Parse arguments and config""" 93 | args, cfg = parse_config() 94 | 95 | """Set random seed""" 96 | if args.fix_random_seed: 97 | common_utils.set_random_seed(42) 98 | 99 | """Set up saving folder""" 100 | # note important configs in the folder name 101 | tag_parts = [] 102 | if cfg.MODEL_DMT.DMT.DROPOUT: 103 | tag_parts.append(f'_DO{cfg.MODEL_DMT.DMT.DROPOUT:.2f}') 104 | if cfg.OPT.WEIGHT_DECAY: 105 | tag_parts.append(f'_WD{cfg.OPT.WEIGHT_DECAY:.0e}') 106 | tag_parts.append(f'_BS{cfg.OPT.BATCH_SIZE_PER_GPU * cfg.OPT.TOTAL_GPUS}_EP{cfg.OPT.NUM_EPOCHS}') 107 | default_tag = ''.join(tag_parts).replace('__', '_') 108 | args.extra_tag = '_'.join([args.extra_tag, default_tag]).replace('__', '_') 109 | 110 | """Initialize placeholder saving folders and logger""" 111 | output_dir = os.path.join(cfg.ROOT_DIR, 'output', cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag) 112 | ckpt_dir = os.path.join(output_dir, 'ckpt') 113 | eval_output_dir = os.path.join(output_dir, 'eval', 'eval_with_train') 114 | 115 | os.makedirs(output_dir, exist_ok=True) 116 | os.makedirs(ckpt_dir, exist_ok=True) 117 | os.makedirs(eval_output_dir, exist_ok=True) 118 | 119 | cfg.SAVE_DIR = EasyDict({ 120 | 'OUTPUT_DIR': output_dir, 121 | 'CKPT_DIR': ckpt_dir, 122 | 'EVAL_OUTPUT_DIR': eval_output_dir 123 | }) 124 | 125 | log_file = os.path.join(output_dir, 'log_train_%s.txt' % datetime.datetime.now().strftime('%Y%m%d-%H%M%S')) 126 | logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK) 127 | 128 | # log to file 129 | logger.info('**********************Start logging**********************') 130 | gpu_list =
os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL') 131 | logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list) 132 | 133 | if cfg.OPT.DIST_TRAIN: 134 | logger.info('total_batch_size: %d' % (cfg.OPT.TOTAL_GPUS * cfg.OPT.BATCH_SIZE_PER_GPU)) 135 | 136 | logger.info('**********************Argparser**********************') 137 | for key, val in vars(args).items(): 138 | logger.info('{:32} {}'.format(key, val)) 139 | 140 | logger.info('**********************Configurations**********************') 141 | log_config_to_file(copy.deepcopy(cfg), logger=logger) 142 | 143 | if cfg.LOCAL_RANK == 0: 144 | os.system('cp %s %s' % (args.cfg_file, output_dir)) # copy original config 145 | # dump the updated config from easydict [not perfect as there are special items in the original config like &object_type] 146 | 147 | def easydict_to_dict(easydict_obj): 148 | # Function to convert EasyDict to a dictionary recursively 149 | result = {} 150 | for key, value in easydict_obj.items(): 151 | if isinstance(value, EasyDict): 152 | result[key] = easydict_to_dict(value) 153 | else: 154 | if isinstance(value, PosixPath): 155 | result[key] = os.path.abspath(value) # convert PosixPath to string 156 | else: 157 | result[key] = value 158 | return result 159 | 160 | nested_dict = easydict_to_dict(cfg) 161 | with open(os.path.join(output_dir, '{:s}_updated.yaml'.format(os.path.basename(args.cfg_file)[:-5])), 'w') as f: 162 | yaml.dump(nested_dict, f) 163 | 164 | # wandb log 165 | wb_log = None 166 | if cfg.LOCAL_RANK == 0: 167 | # Initialize wandb run for training 168 | wb_log = wandb.init( 169 | project="trajflow", 170 | name=f"{cfg.TAG}_{args.extra_tag}", 171 | config=cfg, 172 | dir=output_dir 173 | ) 174 | 175 | # save version control information 176 | repo = git.Repo(search_parent_directories=True) 177 | sha = repo.head.object.hexsha 178 | logger.info("git hash: {}".format(sha)) 179 | 180 | # backup code 181 | code_backup_dir = os.path.join(output_dir, 'code_backup') 182 | shutil.rmtree(code_backup_dir, ignore_errors=True) 183 | os.makedirs(code_backup_dir, exist_ok=True) 184 | if cfg.LOCAL_RANK == 0: 185 | dirs_to_save = ['trajflow', 'runner'] 186 | for this_dir in dirs_to_save: 187 | src_dir = os.path.join(cfg.ROOT_DIR, this_dir) 188 | dest_dir = os.path.join(code_backup_dir, this_dir) 189 | 190 | if os.path.exists(src_dir): 191 | try: 192 | shutil.copytree(src_dir, dest_dir) 193 | logger.info(f"Successfully copied {src_dir} to {dest_dir}") 194 | except (shutil.Error, OSError) as e: 195 | logger.error(f"Error copying {src_dir} to {dest_dir}: {e}") 196 | else: 197 | logger.warning(f"Source directory {src_dir} does not exist. Skipping.") 198 | 199 | # [shutil.copytree(os.path.join(cfg_diff.ROOT_DIR, this_dir), os.path.join(code_backup_dir, this_dir)) for this_dir in dirs_to_save] 200 | logger.info("Code is backed up to {}".format(code_backup_dir)) 201 | 202 | ###### End of Init ###### 203 | 204 | return args, cfg, logger, wb_log 205 | -------------------------------------------------------------------------------- /trajflow/models/context_encoder/mtr_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | from trajflow.models.layers.transformer import transformer_encoder_layer 18 | from trajflow.models.layers.common_layers import gen_sineembed_for_position 19 | from trajflow.models.layers import polyline_encoder 20 | from trajflow.utils import common_utils 21 | try: 22 | from trajflow.mtr_ops.knn import knn_utils 23 | except: 24 | import os 25 | print("{:s} Fail to import knn_utils module at {:s}. CUDA availability: {} {:s}".format('-' * 20, os.path.basename(__file__), torch.cuda.is_available(), '-' * 20)) 26 | 27 | 28 | class MTREncoder(nn.Module): 29 | def __init__(self, config): 30 | super().__init__() 31 | self.model_cfg = config 32 | self.device = config.DEVICE 33 | 34 | # build polyline encoders 35 | self.agent_polyline_encoder = self.build_polyline_encoder( 36 | in_channels=self.model_cfg.NUM_INPUT_ATTR_AGENT + 1, 37 | hidden_dim=self.model_cfg.NUM_CHANNEL_IN_MLP_AGENT, 38 | num_layers=self.model_cfg.NUM_LAYER_IN_MLP_AGENT, 39 | out_channels=self.model_cfg.D_MODEL 40 | ) 41 | self.map_polyline_encoder = self.build_polyline_encoder( 42 | in_channels=self.model_cfg.NUM_INPUT_ATTR_MAP, 43 | hidden_dim=self.model_cfg.NUM_CHANNEL_IN_MLP_MAP, 44 | num_layers=self.model_cfg.NUM_LAYER_IN_MLP_MAP, 45 | num_pre_layers=self.model_cfg.NUM_LAYER_IN_PRE_MLP_MAP, 46 | out_channels=self.model_cfg.D_MODEL 47 | ) 48 | 49 | # build transformer encoder layers 50 | self.use_local_attn = self.model_cfg.get('USE_LOCAL_ATTN', False) 51 | self_attn_layers = [] 52 | for _ in range(self.model_cfg.NUM_ATTN_LAYERS): 53 | self_attn_layers.append(self.build_transformer_encoder_layer( 54 | d_model=self.model_cfg.D_MODEL, 55 | nhead=self.model_cfg.NUM_ATTN_HEAD, 56 | dropout=self.model_cfg.get('DROPOUT_OF_ATTN', 0.1), 57 | normalize_before=False, 58 | use_local_attn=self.use_local_attn 59 | )) 60 | 61 | self.self_attn_layers = nn.ModuleList(self_attn_layers) 62 | self.num_out_channels = self.model_cfg.D_MODEL 63 | 64 | ### DEBUG: break down the number of parameters ### 65 | # param_agent_polyline_encoder = sum(p.numel() for p in self.agent_polyline_encoder.parameters()) 66 | # param_map_polyline_encoder = sum(p.numel() for p in self.map_polyline_encoder.parameters()) 67 | # param_self_attn_layers = sum(p.numel() for p in self.self_attn_layers.parameters()) 68 | # print("Params of agent_polyline_encoder: {:,d}, map_polyline_encoder: {:,d}, self_attn_layers: {:,d}".format(param_agent_polyline_encoder, param_map_polyline_encoder, param_self_attn_layers)) 69 | # pdb.set_trace() 70 | ### DEBUG: break down the number of parameters ### 71 | 72 | def build_polyline_encoder(self, in_channels, hidden_dim, num_layers, num_pre_layers=1, out_channels=None): 73 | ret_polyline_encoder = polyline_encoder.PointNetPolylineEncoder( 74 | in_channels=in_channels, 75 | hidden_dim=hidden_dim, 76 | num_layers=num_layers, 77 | num_pre_layers=num_pre_layers, 78 | out_channels=out_channels 79 | ) 80 | return ret_polyline_encoder 81 | 82 | def build_transformer_encoder_layer(self, d_model, nhead, dropout=0.1, normalize_before=False, use_local_attn=False): 83 | single_encoder_layer = 
transformer_encoder_layer.TransformerEncoderLayer( 84 | d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4, dropout=dropout, 85 | normalize_before=normalize_before, use_local_attn=use_local_attn 86 | ) 87 | return single_encoder_layer 88 | 89 | def apply_global_attn(self, x, x_mask, x_pos): 90 | """ 91 | 92 | Args: 93 | x (batch_size, N, d_model): 94 | x_mask (batch_size, N): 95 | x_pos (batch_size, N, 3): 96 | """ 97 | assert torch.all(x_mask.sum(dim=-1) > 0) 98 | 99 | batch_size, N, d_model = x.shape 100 | x_t = x.permute(1, 0, 2) 101 | x_mask_t = x_mask 102 | x_pos_t = x_pos.permute(1, 0, 2) 103 | 104 | pos_embedding = gen_sineembed_for_position(x_pos_t[..., :2], hidden_dim=d_model) 105 | 106 | for k in range(len(self.self_attn_layers)): 107 | x_t = self.self_attn_layers[k]( 108 | src=x_t, 109 | src_key_padding_mask=~x_mask_t, 110 | pos=pos_embedding 111 | ) 112 | x_out = x_t.permute(1, 0, 2) # (batch_size, N, d_model) 113 | return x_out 114 | 115 | def apply_local_attn(self, x, x_mask, x_pos, num_of_neighbors): 116 | """ 117 | 118 | Args: 119 | x (batch_size, N, d_model): 120 | x_mask (batch_size, N): 121 | x_pos (batch_size, N, 3): 122 | """ 123 | assert torch.all(x_mask.sum(dim=-1) > 0) 124 | batch_size, N, d_model = x.shape 125 | 126 | x_stack_full = x.view(-1, d_model) # (batch_size * N, d_model) 127 | x_mask_stack = x_mask.view(-1) # (batch_size * N) 128 | x_pos_stack_full = x_pos.view(-1, 3) # (batch_size * N, 3) 129 | batch_idxs_full = torch.arange(batch_size).type_as(x)[:, None].repeat(1, N).view(-1).int() # (batch_size * N) 130 | 131 | # filter invalid elements 132 | x_stack = x_stack_full[x_mask_stack] 133 | x_pos_stack = x_pos_stack_full[x_mask_stack] 134 | batch_idxs = batch_idxs_full[x_mask_stack] 135 | 136 | # knn 137 | batch_offsets = common_utils.get_batch_offsets(batch_idxs=batch_idxs, bs=batch_size).int() # (batch_size + 1) 138 | batch_cnt = batch_offsets[1:] - batch_offsets[:-1] 139 | 140 | index_pair = knn_utils.knn_batch_mlogk( 141 | x_pos_stack, x_pos_stack, batch_idxs, batch_offsets, num_of_neighbors 142 | ) # (num_valid_elems, K), in range of [0, 1, ..., N-1] 143 | 144 | # positional encoding 145 | pos_embedding = gen_sineembed_for_position(x_pos_stack[None, :, 0:2], hidden_dim=d_model)[0] 146 | 147 | # local attn 148 | output = x_stack 149 | for k in range(len(self.self_attn_layers)): 150 | output = self.self_attn_layers[k]( 151 | src=output, 152 | pos=pos_embedding, 153 | index_pair=index_pair, 154 | query_batch_cnt=batch_cnt, 155 | key_batch_cnt=batch_cnt, 156 | index_pair_batch=batch_idxs 157 | ) 158 | 159 | ret_full_feature = torch.zeros_like(x_stack_full) # (batch_size * N, d_model) 160 | ret_full_feature[x_mask_stack] = output 161 | 162 | ret_full_feature = ret_full_feature.view(batch_size, N, d_model) 163 | return ret_full_feature 164 | 165 | def forward(self, batch_dict): 166 | """ 167 | Args: 168 | batch_dict: 169 | input_dict: 170 | """ 171 | input_dict = batch_dict['input_dict'] 172 | obj_trajs, obj_trajs_mask = input_dict['obj_trajs'].to(self.device), input_dict['obj_trajs_mask'].to(self.device) 173 | map_polylines, map_polylines_mask = input_dict['map_polylines'].to(self.device), input_dict['map_polylines_mask'].to(self.device) 174 | 175 | obj_trajs_last_pos = input_dict['obj_trajs_last_pos'].to(self.device) 176 | map_polylines_center = input_dict['map_polylines_center'].to(self.device) 177 | track_index_to_predict = input_dict['track_index_to_predict'] 178 | 179 | assert obj_trajs_mask.dtype == torch.bool and map_polylines_mask.dtype == 
torch.bool 180 | 181 | num_center_objects, num_objects, num_timestamps, _ = obj_trajs.shape 182 | num_polylines = map_polylines.shape[1] 183 | 184 | # apply polyline encoder 185 | obj_trajs_in = torch.cat((obj_trajs, obj_trajs_mask[:, :, :, None].type_as(obj_trajs)), dim=-1) 186 | obj_polylines_feature = self.agent_polyline_encoder(obj_trajs_in, obj_trajs_mask) # (num_center_objects, num_objects, C) 187 | map_polylines_feature = self.map_polyline_encoder(map_polylines, map_polylines_mask) # (num_center_objects, num_polylines, C) 188 | 189 | # apply self-attn 190 | obj_valid_mask = (obj_trajs_mask.sum(dim=-1) > 0) # (num_center_objects, num_objects) 191 | map_valid_mask = (map_polylines_mask.sum(dim=-1) > 0) # (num_center_objects, num_polylines) 192 | 193 | global_token_feature = torch.cat((obj_polylines_feature, map_polylines_feature), dim=1) 194 | global_token_mask = torch.cat((obj_valid_mask, map_valid_mask), dim=1) 195 | global_token_pos = torch.cat((obj_trajs_last_pos, map_polylines_center), dim=1) 196 | 197 | if self.use_local_attn: 198 | global_token_feature = self.apply_local_attn( 199 | x=global_token_feature, x_mask=global_token_mask, x_pos=global_token_pos, 200 | num_of_neighbors=self.model_cfg.NUM_OF_ATTN_NEIGHBORS 201 | ) 202 | else: 203 | global_token_feature = self.apply_global_attn( 204 | x=global_token_feature, x_mask=global_token_mask, x_pos=global_token_pos 205 | ) 206 | 207 | obj_polylines_feature = global_token_feature[:, :num_objects] 208 | map_polylines_feature = global_token_feature[:, num_objects:] 209 | assert map_polylines_feature.shape[1] == num_polylines 210 | 211 | # organize return features 212 | center_objects_feature = obj_polylines_feature[torch.arange(num_center_objects), track_index_to_predict] 213 | 214 | encoder_output = { 215 | 'center_objects_feature': center_objects_feature, 216 | 'obj_feature': obj_polylines_feature, 217 | 'map_feature': map_polylines_feature, 218 | 'obj_mask': obj_valid_mask, 219 | 'map_mask': map_valid_mask, 220 | 'obj_pos': obj_trajs_last_pos, 221 | 'map_pos': map_polylines_center 222 | } 223 | 224 | batch_dict['encoder_output'] = encoder_output 225 | 226 | return batch_dict 227 | -------------------------------------------------------------------------------- /trajflow/utils/motion_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025-present, Qi Yan. 2 | # Copyright (c) Shaoshuai Shi. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | ##################################################################################### 8 | # Code is based on the Motion Transformer (https://arxiv.org/abs/2209.13508) implementation 9 | # from https://github.com/sshaoshuai/MTR by Shaoshuai Shi, Li Jiang, Dengxin Dai, Bernt Schiele 10 | #################################################################################### 11 | 12 | 13 | import torch 14 | 15 | 16 | def batch_nms(pred_trajs, pred_scores, dist_thresh, num_ret_modes=6, return_mask=False): 17 | """ 18 | 19 | Args: 20 | pred_trajs (batch_size, num_modes, num_timestamps, 7) 21 | pred_scores (batch_size, num_modes) 22 | dist_thresh (batch_size) 23 | num_ret_modes (int, optional): Defaults to 6. 
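return_mask (bool, optional): if True, return a boolean keep mask over the original modes together with the selected indices, instead of the gathered trajectories and scores. Defaults to False.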
24 | 25 | Returns: 26 | ret_trajs (batch_size, num_ret_modes, num_timestamps, 7) 27 | ret_scores (batch_size, num_ret_modes) 28 | ret_idxs (batch_size, num_ret_modes) 29 | """ 30 | batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape 31 | 32 | sorted_idxs = pred_scores.argsort(dim=-1, descending=True) 33 | bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes) 34 | sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs] 35 | sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7) 36 | sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7) 37 | 38 | dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1) # [B, K, K] 39 | if isinstance(dist_thresh, float): 40 | point_cover_mask = (dist < dist_thresh) 41 | else: 42 | point_cover_mask = (dist < dist_thresh[:, None, None]) 43 | 44 | point_val = sorted_pred_scores.clone() # (batch_size, N) 45 | point_val_selected = torch.zeros_like(point_val) # (batch_size, N) 46 | ret_mask_sorted = torch.ones_like(point_val).bool() # (batch_size, N) 47 | 48 | ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long() 49 | bs_idxs = torch.arange(batch_size).type_as(ret_idxs) 50 | if not return_mask: 51 | ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim) 52 | ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes) 53 | 54 | for k in range(num_ret_modes): 55 | cur_idx = point_val.argmax(dim=-1) # (batch_size) 56 | ret_idxs[:, k] = cur_idx 57 | 58 | new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N) 59 | filter_mask = new_cover_mask.clone() 60 | filter_mask[bs_idxs, cur_idx] = False 61 | filter_mask *= (point_val.max(dim=-1, keepdim=True).values > 0) 62 | ret_mask_sorted[filter_mask] = False 63 | 64 | point_val = point_val * (~new_cover_mask).float() # (batch_size, N) 65 | point_val_selected[bs_idxs, cur_idx] = -1 66 | point_val += point_val_selected 67 | 68 | if return_mask and (point_val.max(dim=-1, keepdim=True).values <= 0).all(): 69 | # early stop to avoid unnecessary computation 70 | break 71 | 72 | if not return_mask: 73 | ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx] 74 | ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx] 75 | 76 | bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None] 77 | ret_idxs = sorted_idxs[bs_idxs, ret_idxs] 78 | 79 | if return_mask: 80 | ret_mask = torch.zeros_like(ret_mask_sorted) 81 | ret_mask_sorted[torch.cumsum(ret_mask_sorted, dim=-1) > num_ret_modes] = False 82 | ret_mask[bs_idxs, sorted_idxs] = ret_mask_sorted 83 | return ret_mask, ret_idxs 84 | else: 85 | return ret_trajs, ret_scores, ret_idxs 86 | 87 | 88 | def get_evolving_anchors( 89 | layer_idx, num_inter_layers, pred_list, 90 | center_gt_goals, intention_points, 91 | center_gt_trajs, center_gt_trajs_mask, 92 | ): 93 | """ 94 | Anchor evolution at selected decoder layers. 95 | Based on "EDA: Evolving and Distinct Anchors for Multimodal Motion Prediction", Proceedings of the AAAI Conference on Artificial Intelligence. 96 | Args: 97 | layer_idx (int): current decoder layer index 98 | num_inter_layers (int): number of decoder layers between anchor updates 99 | center_gt_goals, center_gt_trajs (Tensor): GT goal points and GT trajectories.
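intention_points (Tensor): static intention points used as anchors before the first anchor-update layer
center_gt_trajs_mask (Tensor): validity mask over the GT trajectory timesteps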
100 | pred_list (List): predictions (scores, trajectories) from all preceding decoder layers 101 | Returns: 102 | anchor_trajs (Tensor): anchor trajectories to be fed to NMS 103 | dist (Tensor): distance between the GT trajectory and each anchor 104 | """ 105 | positive_layer_idx = (layer_idx//num_inter_layers) * num_inter_layers - 1 106 | if positive_layer_idx < 0: 107 | anchor_trajs = intention_points.unsqueeze(-2) 108 | # (num_center_objects, num_query) 109 | dist = (center_gt_goals[:, None, :] - intention_points).norm(dim=-1) 110 | else: 111 | anchor_trajs = pred_list[positive_layer_idx][1] 112 | # (num_center_objects, num_query) 113 | dist = ((center_gt_trajs[:, None, :, 0:2] - anchor_trajs[..., 0:2]).norm(dim=-1) * \ 114 | center_gt_trajs_mask[:, None]).sum(dim=-1) 115 | return anchor_trajs, dist 116 | 117 | 118 | def select_distinct_anchors( 119 | dist, pred_scores, pred_trajs, anchor_trajs, 120 | lower_dist=2.5, upper_dist=3.5, 121 | lower_length=10, upper_length=50, scalar=1.5): 122 | """ 123 | Selects distinct anchors based on trajectory length and configurable distance thresholds. 124 | Based on "EDA: Evolving and Distinct Anchors for Multimodal Motion Prediction", Proceedings of the AAAI Conference on Artificial Intelligence. 125 | Args: 126 | dist (Tensor): distance between the GT trajectory and each anchor 127 | pred_scores (Tensor): Prediction scores for each trajectory. 128 | pred_trajs (Tensor): Predicted trajectories. 129 | anchor_trajs (Tensor): Anchor trajectories (e.g., from layers 2, 4, 6) used for NMS processing. 130 | Returns: 131 | center_gt_positive_idx (Tensor), select_mask (Tensor), select_idx (Tensor): the positive anchor index per object, the NMS keep mask, and the NMS-selected indices. 132 | """ 133 | # Initialize the selection mask 134 | select_mask = torch.ones_like(pred_scores).bool() 135 | 136 | # Calculate the length of the top trajectory 137 | num_center_objects = pred_scores.shape[0] 138 | top_traj = pred_trajs[torch.arange(num_center_objects), pred_scores.argsort(dim=-1)[:, -1]][..., :2] 139 | top_traj_length = torch.norm(torch.diff(top_traj, dim=1), dim=-1).sum(dim=-1) 140 | 141 | # Set distance thresholds 142 | dist_thresh = torch.minimum( 143 | torch.tensor(upper_dist, device=pred_trajs.device), 144 | torch.maximum( 145 | torch.tensor(lower_dist, device=pred_trajs.device), 146 | lower_dist + scalar * (top_traj_length - lower_length) / (upper_length - lower_length) 147 | ) 148 | ) 149 | 150 | # Apply non-maximum suppression based on distance threshold 151 | select_mask, select_idx = batch_nms( 152 | anchor_trajs, pred_scores.sigmoid(), 153 | dist_thresh=dist_thresh, 154 | num_ret_modes=anchor_trajs.shape[1], 155 | return_mask=True 156 | ) 157 | 158 | dist = dist.masked_fill(~select_mask, 1e10) 159 | center_gt_positive_idx = dist.argmin(dim=-1) 160 | return center_gt_positive_idx, select_mask, select_idx 161 | 162 | 163 | def inference_distance_nms( 164 | pred_scores, pred_trajs, num_motion_modes=6, 165 | lower_dist=2.5, upper_dist=3.5, 166 | lower_length=10, upper_length=50, scalar=1.5 167 | ): 168 | """ 169 | Perform NMS post-processing during inference, 170 | following MTRA. 171 | """ 172 | num_center_objects, num_query, num_future_timestamps, num_feat = pred_trajs.shape 173 | top_traj = pred_trajs[torch.arange(num_center_objects), pred_scores.argsort(dim=-1)[:, -1]][..., :2] 174 | top_traj_length = torch.norm(torch.diff(top_traj, dim=1), dim=-1).sum(dim=-1) 175 | 176 | dist_thresh = torch.minimum( 177 | torch.tensor(upper_dist, device=pred_trajs.device), 178 | torch.maximum( 179 | torch.tensor(lower_dist, device=pred_trajs.device), 180 | lower_dist+scalar*(top_traj_length-lower_length)/(upper_length-lower_length) 181 | ) 182 | ) 183 | 184 | pred_trajs_final, pred_scores_final, selected_idxs = batch_nms( 185 |
pred_trajs=pred_trajs, pred_scores=pred_scores, 186 | dist_thresh=dist_thresh, 187 | num_ret_modes=num_motion_modes 188 | ) 189 | return pred_trajs_final, pred_scores_final, selected_idxs 190 | 191 | 192 | def get_ade_of_waymo(pred_trajs, gt_trajs, gt_valid_mask, calculate_steps=(5, 9, 15)) -> float: 193 | """Compute Average Displacement Error. 194 | 195 | Args: 196 | pred_trajs: (batch_size, num_modes, pred_len, 2) 197 | gt_trajs: (batch_size, pred_len, 2) 198 | gt_valid_mask: (batch_size, pred_len) 199 | Returns: 200 | ade: Average Displacement Error 201 | 202 | """ 203 | # assert pred_trajs.shape[2] in [1, 16, 80] 204 | if pred_trajs.shape[2] == 80: 205 | pred_trajs = pred_trajs[:, :, 4::5] 206 | gt_trajs = gt_trajs[:, 4::5] 207 | gt_valid_mask = gt_valid_mask[:, 4::5] 208 | 209 | ade = 0 210 | for cur_step in calculate_steps: 211 | dist_error = (pred_trajs[:, :, :cur_step+1, :] - gt_trajs[:, None, :cur_step+1, :]).norm(dim=-1) # (batch_size, num_modes, pred_len) 212 | dist_error = (dist_error * gt_valid_mask[:, None, :cur_step+1].float()).sum(dim=-1) / torch.clamp_min(gt_valid_mask[:, :cur_step+1].sum(dim=-1)[:, None], min=1.0) # (batch_size, num_modes) 213 | cur_ade = dist_error.min(dim=-1)[0].mean().item() 214 | 215 | ade += cur_ade 216 | 217 | ade = ade / len(calculate_steps) 218 | return ade 219 | 220 | 221 | def get_ade_of_each_category(pred_trajs, gt_trajs, gt_trajs_mask, object_types, valid_type_list, post_tag='', pre_tag=''): 222 | """ 223 | Args: 224 | pred_trajs (num_center_objects, num_modes, num_timestamps, 2): 225 | gt_trajs (num_center_objects, num_timestamps, 2): 226 | gt_trajs_mask (num_center_objects, num_timestamps): 227 | object_types (num_center_objects): 228 | 229 | Returns: 230 | ret_dict (dict): minADE per object category, keyed by '{pre_tag}ade_{object_type}{post_tag}' 231 | """ 232 | ret_dict = {} 233 | 234 | for cur_type in valid_type_list: 235 | type_mask = (object_types == cur_type) 236 | ret_dict[f'{pre_tag}ade_{cur_type}{post_tag}'] = -0.0 237 | if type_mask.sum() == 0: 238 | continue 239 | 240 | # calculate evaluation metric 241 | ade = get_ade_of_waymo( 242 | pred_trajs=pred_trajs[type_mask, :, :, 0:2].detach(), 243 | gt_trajs=gt_trajs[type_mask], gt_valid_mask=gt_trajs_mask[type_mask] 244 | ) 245 | ret_dict[f'{pre_tag}ade_{cur_type}{post_tag}'] = ade 246 | return ret_dict --------------------------------------------------------------------------------
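The distance-based NMS in motion_utils.py is plain PyTorch, so it can be sanity-checked without the CUDA extensions or Waymo data. A minimal sketch (not part of the repository) on random, hypothetical inputs, assuming the trajflow package is importable:

import torch

from trajflow.utils.motion_utils import inference_distance_nms

# hypothetical shapes: 4 center objects, 64 predicted modes, 80 future steps, 7 state dims
pred_trajs = torch.randn(4, 64, 80, 7)
pred_scores = torch.rand(4, 64)  # e.g. sigmoid-activated mode scores in [0, 1]

# keep 6 mutually distinct modes per object; the distance threshold adapts to trajectory length
trajs_nms, scores_nms, selected_idxs = inference_distance_nms(
    pred_scores=pred_scores, pred_trajs=pred_trajs, num_motion_modes=6
)
assert trajs_nms.shape == (4, 6, 80, 7)
assert scores_nms.shape == (4, 6) and selected_idxs.shape == (4, 6)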