├── run_main.sh ├── qualitative_results ├── 0_ground_truth.gif ├── 0_model_output.gif ├── 1_ground_truth.gif └── 1_model_output.gif ├── .gitignore ├── net_works ├── __init__.py ├── traj_decoder.py ├── transformer.py ├── back_bone.py ├── attention.py ├── scene_encoder.py └── diffusion.py ├── common ├── data_preprocess_config.py ├── obs_type.py ├── config_result.py ├── data_train_model_config.py ├── __init__.py ├── data_config.py ├── data.py └── waymo_dataset.py ├── utils ├── __init__.py ├── visualize_utils.py ├── math_utils.py ├── map_utils.py └── data_utils.py ├── tasks ├── __init__.py ├── base_task.py ├── data_split_task.py ├── data_count_task.py ├── data_preprocess_task.py ├── train_model_task.py ├── load_config_task.py └── show_result_task.py ├── config.yaml ├── main.py ├── environment.yml ├── README.md ├── LICENSE └── gene_submission.py /run_main.sh: -------------------------------------------------------------------------------- 1 | nohup python3 -u main.py >> main.log & -------------------------------------------------------------------------------- /qualitative_results/0_ground_truth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/0_ground_truth.gif -------------------------------------------------------------------------------- /qualitative_results/0_model_output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/0_model_output.gif -------------------------------------------------------------------------------- /qualitative_results/1_ground_truth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/1_ground_truth.gif -------------------------------------------------------------------------------- /qualitative_results/1_model_output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/1_model_output.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | result_data/* 2 | */__pycache__ 3 | .idea/* 4 | *.png 5 | *.pkl 6 | waymo_open_dataset_/* 7 | test_set/* 8 | valid_set/* 9 | *.pth 10 | output/* 11 | data_set/* 12 | data_output/* -------------------------------------------------------------------------------- /net_works/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: __init__.py.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | from net_works.back_bone import BackBone 10 | 11 | BackBone = BackBone 12 | -------------------------------------------------------------------------------- /common/data_preprocess_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_preprocess_config.py 6 | @Author: YangChen 7 | @Date: 2023/12/24 8 | """ 9 | from common.data import BaseConfig 10 | 11 | 12 | class DataPreprocessConfig(BaseConfig): 13 | data_size: int = 0 14 | max_data_size: int = 0 15 | num_works: int = 1 16 | 17 | 
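The fields above (`data_size`, `max_data_size`, `num_works`) mirror the `data_preprocess_config` block in `config.yaml`. As a rough, minimal sketch of how that block could be mapped onto this class — the real loading happens in `tasks/load_config_task.py`, which is not shown here, so the helper name and the key-by-key copy below are illustrative assumptions rather than the project's actual loader:

    import yaml  # pyyaml==6.0.1 is pinned in environment.yml

    from common.data_preprocess_config import DataPreprocessConfig


    def load_data_preprocess_config(config_path: str = "config.yaml") -> DataPreprocessConfig:
        # Read the raw YAML file and pull out the data_preprocess_config block.
        with open(config_path, "r", encoding="utf-8") as f:
            raw_config = yaml.safe_load(f)
        block = raw_config.get("data_preprocess_config", {}) or {}
        config = DataPreprocessConfig()
        # Copy only keys that already exist on the config object,
        # so unexpected YAML keys are silently ignored.
        for key, value in block.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config

With the shipped `config.yaml`, this sketch would yield `data_size=100`, `max_data_size=2000`, `num_works=20`.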
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: __init__.py.py 6 | @Author: YangChen 7 | @Date: 2023/12/21 8 | """ 9 | from utils.map_utils import MapUtil 10 | from utils.data_utils import DataUtil 11 | from utils.math_utils import MathUtil 12 | from utils.visualize_utils import VisualizeUtil 13 | 14 | MapUtil = MapUtil 15 | DataUtil = DataUtil 16 | MathUtil = MathUtil 17 | VisualizeUtil = VisualizeUtil 18 | -------------------------------------------------------------------------------- /common/obs_type.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: obs_type.py 6 | @Author: YangChen 7 | @Date: 2023/12/24 8 | """ 9 | 10 | ObjectType = { 11 | 0: "TYPE_UNSET", 12 | 1: "TYPE_VEHICLE", 13 | 2: "TYPE_PEDESTRIAN", 14 | 3: "TYPE_CYCLIST", 15 | 4: "TYPE_OTHER" 16 | } 17 | MapState = { 18 | 0: "LANE_STATE_UNKNOWN", 19 | 1: "LANE_STATE_ARROW_STOP", 20 | 2: "LANE_STATE_ARROW_CAUTION", 21 | 3: "LANE_STATE_ARROW_GO", 22 | 4: "LANE_STATE_STOP", 23 | 5: "LANE_STATE_CAUTION", 24 | 6: "LANE_STATE_GO", 25 | 7: "LANE_STATE_FLASHING_STOP", 26 | 8: "LANE_STATE_FLASHING_CAUTION" 27 | } 28 | -------------------------------------------------------------------------------- /common/config_result.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_load_config_result.py 6 | @Author: YangChen 7 | @Date: 2023/12/21 8 | """ 9 | from dataclasses import dataclass 10 | 11 | from common.data import TaskLogger 12 | from common.data_config import TaskConfig 13 | from common.data_preprocess_config import DataPreprocessConfig 14 | from common.data_train_model_config import TrainModelConfig 15 | 16 | 17 | @dataclass 18 | class LoadConfigResultDate: 19 | task_config: TaskConfig = None 20 | data_preprocess_config: DataPreprocessConfig = None 21 | train_model_config: TrainModelConfig = None 22 | task_id: str = "" 23 | task_logger: TaskLogger = None 24 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: __init__.py.py 6 | @Author: YangChen 7 | @Date: 2023/12/20a 8 | """ 9 | 10 | from tasks.base_task import BaseTask 11 | from tasks.load_config_task import LoadConfigTask 12 | from tasks.data_preprocess_task import DataPreprocessTask 13 | from tasks.data_split_task import DataSplitTask 14 | from tasks.data_count_task import DataCountTask 15 | from tasks.train_model_task import TrainModelTask 16 | from tasks.show_result_task import ShowResultsTask 17 | 18 | BaseTask = BaseTask 19 | LoadConfigTask = LoadConfigTask 20 | DataPreprocessTask = DataPreprocessTask 21 | DataSplitTask = DataSplitTask 22 | TrainModelTask = TrainModelTask 23 | ShowResultsTask = ShowResultsTask 24 | -------------------------------------------------------------------------------- /common/data_train_model_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 
@Project: WcDT 5 | @Name: train_model_config.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | from typing import List 10 | 11 | from common.data import BaseConfig 12 | 13 | 14 | class TrainModelConfig(BaseConfig): 15 | use_gpu: bool = False 16 | gpu_ids: List[int] = None 17 | batch_size: int = 0 18 | num_works: int = 20 19 | his_step: int = 11 20 | max_pred_num: int = 8 21 | max_other_num: int = 6 22 | max_traffic_light: int = 8 23 | max_lane_num: int = 32 24 | max_point_num: int = 128 25 | num_head: int = 8 26 | attention_dim: int = 128 27 | multimodal: int = 10 28 | time_steps: int = 100 29 | schedule: str = "cosine" 30 | num_epoch: int = 0 31 | init_lr: float = 0.00001 32 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: __init__.py.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | from common.data import TaskType, TaskLogger 10 | from common.data_config import TaskConfig 11 | from common.config_result import LoadConfigResultDate 12 | from common.data_train_model_config import TrainModelConfig 13 | from common.obs_type import ObjectType, MapState 14 | from common.data_preprocess_config import DataPreprocessConfig 15 | from common.waymo_dataset import WaymoDataset 16 | 17 | 18 | TaskType = TaskType 19 | TaskLogger = TaskLogger 20 | TaskConfig = TaskConfig 21 | LoadConfigResultDate = LoadConfigResultDate 22 | ObjectType = ObjectType 23 | MapState = MapState 24 | DataPreprocessConfig = DataPreprocessConfig 25 | TrainModelConfig = TrainModelConfig 26 | WaymoDataset = WaymoDataset -------------------------------------------------------------------------------- /tasks/base_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: base_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | import os 10 | import shutil 11 | from abc import abstractmethod, ABC 12 | 13 | from common import TaskType, LoadConfigResultDate 14 | 15 | 16 | class BaseTask(ABC): 17 | def __init__(self): 18 | super(BaseTask, self).__init__() 19 | self.task_type = TaskType.UNKNOWN 20 | 21 | @abstractmethod 22 | def execute(self, result_info: LoadConfigResultDate): 23 | pass 24 | 25 | @staticmethod 26 | def check_dir_exist(input_dir: str): 27 | if not os.path.exists(input_dir) or \ 28 | not os.path.isdir(input_dir) or \ 29 | len(os.listdir(input_dir)) <= 0: 30 | raise FileNotFoundError(f"data_preprocess_dir error: {input_dir}, dir is None") 31 | 32 | @staticmethod 33 | def rebuild_dir(input_dir: str): 34 | if os.path.exists(input_dir): 35 | shutil.rmtree(input_dir) 36 | os.makedirs(input_dir, exist_ok=True) 37 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | task_config: 2 | task_list: 3 | # - "DATA_PREPROCESS" 4 | # - "DATA_SPLIT" 5 | # - "DATA_COUNT" 6 | - "TRAIN_MODEL" 7 | # - "SHOW_RESULTS" 8 | # - "EVAL_MODEL" 9 | # - "GENE_SUBMISSION" 10 | # 程序执行过程中的输出总目录 11 | output_dir: "output" 12 | log_dir: "log" 13 | image_dir: "result_image" 14 | model_dir: "model" 15 | result_dir: "result" 16 | pre_train_model: "" 17 | waymo_train_dir: "" 18 | waymo_val_dir: "" 19 | waymo_test_dir: "" 20 | # 数据预处理和打包的输出总目录 21 | 
data_output: "data_output" 22 | data_preprocess_dir: "data_preprocess_dir" 23 | train_dir: "train_dir" 24 | val_dir: "val_dir" 25 | test_dir: "test_dir" 26 | 27 | data_preprocess_config: 28 | data_size: 100 29 | max_data_size: 2000 30 | num_works: 20 31 | 32 | 33 | train_model_config: 34 | use_gpu: False 35 | gpu_ids: 36 | - 6 37 | - 7 38 | batch_size: 4 39 | num_works: 0 40 | his_step: 11 41 | max_pred_num: 8 42 | max_other_num: 6 43 | max_traffic_light: 8 44 | max_lane_num: 32 45 | max_point_num: 128 46 | num_head: 8 47 | attention_dim: 128 48 | multimodal: 10 49 | time_steps: 50 50 | # cosine or linear 51 | schedule: "linear" 52 | num_epoch: 200 53 | init_lr: 0.0001 54 | 55 | 56 | -------------------------------------------------------------------------------- /common/data_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_config.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | from typing import List 10 | 11 | from common.data import TaskType, BaseConfig 12 | 13 | 14 | class TaskConfig(BaseConfig): 15 | __task_list: List[TaskType] = None 16 | # 所有输出的目录 17 | output_dir: str = "" 18 | # log保存路径 19 | log_dir: str = "" 20 | model_dir: str = "" 21 | result_dir: str = "" 22 | pre_train_model: str = "" 23 | # waymo数据的目录 24 | waymo_train_dir: str = "" 25 | waymo_val_dir: str = "" 26 | waymo_test_dir: str = "" 27 | # 训练产生的图片保存路径 28 | image_dir: str = "" 29 | # 数据输出 30 | data_output: str = "" 31 | # 数据预处理保存路径 32 | data_preprocess_dir: str = "" 33 | # 数据集构建 34 | train_dir: str = "" 35 | val_dir: str = "" 36 | test_dir: str = "" 37 | 38 | @property 39 | def task_list(self) -> List[TaskType]: 40 | return self.__task_list 41 | 42 | @task_list.setter 43 | def task_list(self, task_list: List[str]): 44 | self.__task_list = [TaskType(task_name) for task_name in task_list] 45 | 46 | def check_config(self): 47 | """ 48 | 检查配置文件的输入 49 | @return: 50 | """ 51 | if len(self.task_list) < 0: 52 | raise Warning("task_list is None") 53 | -------------------------------------------------------------------------------- /net_works/traj_decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: traj_decoder.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class TrajDecoder(nn.Module): 14 | def __init__(self, multimodal: int = 10, dim: int = 256, future_step: int = 80): 15 | super(TrajDecoder, self).__init__() 16 | self.multimodal = multimodal 17 | self.future_step = future_step 18 | self.one_modal = (future_step * 3 + 1) 19 | output_dim = multimodal * self.one_modal 20 | self.decoder = nn.Sequential( 21 | nn.Linear(dim, dim * 2), 22 | nn.LayerNorm(dim * 2), 23 | nn.ReLU(inplace=True), 24 | nn.Linear(dim * 2, dim), 25 | nn.LayerNorm(dim), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(dim, dim), 28 | nn.LayerNorm(dim), 29 | nn.ReLU(inplace=True), 30 | nn.Linear(dim, output_dim) 31 | ) 32 | 33 | def forward(self, input_x): 34 | batch_size, obs_num = input_x.shape[0], input_x.shape[1] 35 | decoder_output = self.decoder(input_x) 36 | decoder_output = decoder_output.view(batch_size, obs_num, self.multimodal, self.one_modal) 37 | traj = decoder_output[:, :, :, :-1].view(batch_size, obs_num, self.multimodal, self.future_step, 3) 38 | confidence = decoder_output[:, :, :, -1] 39 | return traj, 
confidence 40 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: main.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | from typing import List 10 | 11 | from common import TaskType, LoadConfigResultDate, TaskLogger 12 | from tasks import BaseTask 13 | from tasks import (LoadConfigTask, DataPreprocessTask, DataSplitTask, 14 | DataCountTask, TrainModelTask, ShowResultsTask) 15 | 16 | 17 | class TaskFactory: 18 | 19 | @staticmethod 20 | def init_config() -> LoadConfigResultDate: 21 | result_info = LoadConfigResultDate() 22 | load_config_task = LoadConfigTask() 23 | load_config_task.execute(result_info) 24 | return result_info 25 | 26 | @staticmethod 27 | def init_tasks(task_type_list: List[TaskType]) -> List[BaseTask]: 28 | task_list = list() 29 | for task_type in task_type_list: 30 | if task_type == TaskType.DATA_PREPROCESS: 31 | task_list.append(DataPreprocessTask()) 32 | elif task_type == TaskType.DATA_SPLIT: 33 | task_list.append(DataSplitTask()) 34 | elif task_type == TaskType.DATA_COUNT: 35 | task_list.append(DataCountTask()) 36 | elif task_type == TaskType.TRAIN_MODEL: 37 | task_list.append(TrainModelTask()) 38 | elif task_type == TaskType.SHOW_RESULTS: 39 | task_list.append(ShowResultsTask()) 40 | return task_list 41 | 42 | 43 | def execute_tasks(): 44 | load_config_result = TaskFactory.init_config() 45 | task_list = TaskFactory.init_tasks(load_config_result.task_config.task_list) 46 | task_logger: TaskLogger = load_config_result.task_logger 47 | for task in task_list: 48 | task_logger.logger.info(f"task type {task.task_type.value} start") 49 | task.execute(load_config_result) 50 | task_logger.logger.info(f"task type {task.task_type.value} success") 51 | 52 | 53 | if __name__ == "__main__": 54 | execute_tasks() 55 | -------------------------------------------------------------------------------- /utils/visualize_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: visualize_utils.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | from typing import Any, Dict 10 | 11 | import torch 12 | from matplotlib import pyplot as plt 13 | 14 | 15 | class VisualizeUtil: 16 | @staticmethod 17 | def show_result(image_path: str, min_loss_traj: torch.Tensor, data: Dict): 18 | min_loss_traj = min_loss_traj[0] 19 | predicted_traj_mask = data['predicted_traj_mask'][0] 20 | predicted_future_traj = data['predicted_future_traj'][0] 21 | predicted_his_traj = data['predicted_his_traj'][0] 22 | predicted_num = 0 23 | for i in range(predicted_traj_mask.shape[0]): 24 | if int(predicted_traj_mask[i]) == 1: 25 | predicted_num += 1 26 | generate_traj = min_loss_traj[:predicted_num] 27 | predicted_future_traj = predicted_future_traj[:predicted_num] 28 | predicted_his_traj = predicted_his_traj[:predicted_num] 29 | map_feature_list = eval(data['map_json'][0]) 30 | real_traj = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, :2].detach().numpy() 31 | model_output = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, :2].detach().numpy() 32 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 33 | # 画地图 34 | for map_feature in map_feature_list: 35 | x_list = [float(point[0]) for point in map_feature] 36 | y_list = [float(point[1]) for point 
in map_feature] 37 | ax[0].plot(x_list, y_list, color="grey") 38 | ax[1].plot(x_list, y_list, color="grey") 39 | # 画原图,画模型输出 40 | for i in range(predicted_num): 41 | ax[0].plot(real_traj[i, :, 0], real_traj[i, :, 1]) 42 | ax[1].plot(model_output[i, :, 0], model_output[i, :, 1]) 43 | 44 | # label = 'Epoch {0}'.format(num_epoch) 45 | # plt.show() 46 | # fig.text(0.5, 0.04, label, ha='center') 47 | plt.savefig(image_path) 48 | plt.close('all') # 避免内存泄漏 49 | print("save_image_success") 50 | -------------------------------------------------------------------------------- /utils/math_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: math_utils.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | import numpy as np 10 | import torch 11 | 12 | 13 | class MathUtil: 14 | 15 | @staticmethod 16 | def generate_linear_schedule(time_steps: int, low: float = 1e-5, high: float = 2e-5): 17 | return np.linspace(low, high, time_steps) 18 | 19 | @staticmethod 20 | def step_cos(time_step: int, time_steps: int, s: float): 21 | return (np.cos((time_step / time_steps + s) / (1 + s) * np.pi / 2)) ** 2 22 | 23 | # @classmethod 24 | # def generate_cosine_schedule(cls, time_steps: int, s: float = 0.008): 25 | # alphas = [] 26 | # f0 = cls.step_cos(0, time_steps, s) 27 | # for t in range(time_steps + 1): 28 | # alphas.append(cls.step_cos(t, time_steps, s) / f0) 29 | # betas = [] 30 | # for t in range(1, time_steps + 1): 31 | # betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999)) 32 | # return np.array(betas) 33 | 34 | @classmethod 35 | def generate_cosine_schedule(cls, time_steps: int, s: float = 0.008): 36 | steps = time_steps + 1 37 | x = torch.linspace(0, time_steps, steps) 38 | alphas_cum_prod = torch.cos(((x / time_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2 39 | alphas_cum_prod = alphas_cum_prod / alphas_cum_prod[0] 40 | betas = 1 - (alphas_cum_prod[1:] / alphas_cum_prod[:-1]) 41 | return torch.clip(betas, 0.0001, 0.9999).detach().numpy() 42 | 43 | @staticmethod 44 | def post_process_output(generate_traj: torch.Tensor, predicted_his_traj: torch.Tensor) -> torch.Tensor: 45 | delt_t = 0.1 46 | batch_size = generate_traj.shape[0] 47 | num_obs = generate_traj.shape[1] 48 | vx = generate_traj[:, :, :, :, 0] / delt_t 49 | vy = generate_traj[:, :, :, :, 1] / delt_t 50 | start_x = predicted_his_traj[:, :, -1, 0].view(batch_size, num_obs, 1, 1) 51 | start_y = predicted_his_traj[:, :, -1, 1].view(batch_size, num_obs, 1, 1) 52 | start_heading = predicted_his_traj[:, :, -1, 2].view(batch_size, num_obs, 1, 1) 53 | x = torch.cumsum(generate_traj[:, :, :, :, 0], dim=-1) + start_x 54 | y = torch.cumsum(generate_traj[:, :, :, :, 1], dim=-1) + start_y 55 | heading = torch.cumsum(generate_traj[:, :, :, :, 2], dim=-1) + start_heading 56 | output = torch.stack((x, y, heading, vx, vy), dim=-1) 57 | return output 58 | -------------------------------------------------------------------------------- /common/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | import logging 10 | from dataclasses import dataclass 11 | from enum import Enum, unique 12 | from typing import Any 13 | 14 | 15 | def object_dict_print(obj: Any) -> str: 16 | """ 17 | 打印基础类里面所有的属性 18 | @param obj: 19 | @return: 20 | """ 21 | 
result_str = "" 22 | for key, value in obj.__dict__.items(): 23 | if isinstance(value, list) and len(value) > 0: 24 | list_str = f"{key}:\n\t[" 25 | for row in value: 26 | list_str += f"{str(row)}, " 27 | list_str = list_str[:-2] 28 | list_str += "]\n" 29 | result_str += list_str 30 | else: 31 | result_str += f"{key}: {value} \n" 32 | return result_str 33 | 34 | 35 | class TaskType(Enum): 36 | LOAD_CONFIG = "LOAD_CONFIG" 37 | DATA_PREPROCESS = "DATA_PREPROCESS" 38 | DATA_SPLIT = "DATA_SPLIT" 39 | DATA_COUNT = "DATA_COUNT" 40 | TRAIN_MODEL = "TRAIN_MODEL" 41 | SHOW_RESULTS = "SHOW_RESULTS" 42 | EVAL_MODEL = "EVAL_MODEL" 43 | GENE_SUBMISSION = "GENE_SUBMISSION" 44 | UNKNOWN = "UNKNOWN" 45 | 46 | def __str__(self): 47 | return self.value 48 | 49 | 50 | @dataclass 51 | class BaseConfig(object): 52 | 53 | def __str__(self) -> str: 54 | return object_dict_print(self) 55 | 56 | 57 | class TaskLogger(object): 58 | """ 59 | 输出日志 60 | Args: 61 | log_path: 日志的路径 62 | """ 63 | 64 | def __init__(self, log_path: str): 65 | super(TaskLogger, self).__init__() 66 | # 创建一个日志器 67 | self.logger = logging.getLogger("logger") 68 | 69 | # 设置日志输出的最低等级,低于当前等级则会被忽略 70 | self.logger.setLevel(logging.INFO) 71 | 72 | # 创建处理器:sh为控制台处理器,fh为文件处理器 73 | sh = logging.StreamHandler() 74 | fh = logging.FileHandler(log_path, encoding="UTF-8", mode='w') 75 | 76 | # 创建格式器,并将sh,fh设置对应的格式 77 | format_str = "%(asctime)s -%(name)s -%(levelname)-8s -%(filename)s(line: %(lineno)s): %(message)s" 78 | formatter = logging.Formatter(fmt=format_str, datefmt="%Y/%m/%d %X") 79 | sh.setFormatter(formatter) 80 | fh.setFormatter(formatter) 81 | 82 | # 将处理器,添加至日志器中 83 | self.logger.addHandler(sh) 84 | self.logger.addHandler(fh) 85 | 86 | def get_logger(self) -> logging.Logger: 87 | return self.logger 88 | 89 | 90 | if __name__ == "__main__": 91 | pass 92 | -------------------------------------------------------------------------------- /tasks/data_split_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_split_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/24 8 | """ 9 | import os 10 | import pickle 11 | from multiprocessing import Pool 12 | from typing import Any, Dict 13 | 14 | from tqdm import tqdm 15 | 16 | from common import TaskType, LoadConfigResultDate 17 | from tasks import BaseTask 18 | from utils import DataUtil 19 | 20 | 21 | class DataSplitTask(BaseTask): 22 | def __init__(self): 23 | super(DataSplitTask, self).__init__() 24 | self.task_type = TaskType.DATA_SPLIT 25 | 26 | def execute(self, result_info: LoadConfigResultDate): 27 | logger = result_info.task_logger.get_logger() 28 | train_dir = result_info.task_config.train_dir 29 | self.rebuild_dir(train_dir) 30 | data_preprocess_dir = result_info.task_config.data_preprocess_dir 31 | self.check_dir_exist(data_preprocess_dir) 32 | pkl_list = sorted(os.listdir(data_preprocess_dir), 33 | key=lambda x: int(x[:-4].split('_')[-1])) 34 | data_set_num = 0 35 | his_step = result_info.train_model_config.his_step 36 | for pkl_path in pkl_list: 37 | pool = Pool(result_info.data_preprocess_config.num_works) 38 | with open(os.path.join(data_preprocess_dir, pkl_path), "rb") as f: 39 | pkl_obj = pickle.load(f) 40 | process_list = list() 41 | for one_pkl_dict in pkl_obj: 42 | data_set_num += 1 43 | one_pkl_path = os.path.join(train_dir, f"dataset_{data_set_num}.pkl") 44 | process_list.append( 45 | pool.apply_async( 46 | self.save_split_data, 47 | kwds=dict( 48 | 
one_pkl_dict=one_pkl_dict, 49 | his_step=his_step, 50 | one_pkl_path=one_pkl_path 51 | ) 52 | ) 53 | ) 54 | for process in tqdm(process_list): 55 | try: 56 | process.get() 57 | except Exception as e: 58 | logger.error(e) 59 | finally: 60 | continue 61 | pool.close() 62 | 63 | @staticmethod 64 | def save_split_data(one_pkl_dict: Dict[str, Any], his_step: int, one_pkl_path: str): 65 | pkl_data = DataUtil.split_pkl_data(one_pkl_dict, his_step) 66 | if pkl_data and len(pkl_data) > 0: 67 | with open(one_pkl_path, "wb") as f: 68 | pickle.dump(pkl_data, f) 69 | -------------------------------------------------------------------------------- /utils/map_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: map_utils.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | import math 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | 14 | 15 | class MapUtil: 16 | @staticmethod 17 | def local_to_global(ego_heading: float, position_x: np.ndarray, 18 | position_y: np.ndarray, ego_local_x: float, 19 | ego_local_y: float) -> Tuple[np.ndarray, np.ndarray]: 20 | """ 21 | 将坐标系从自车系转为世界系 22 | @param ego_heading: 自车在世界系的转角 23 | @param position_x: 要转换的x坐标 24 | @param position_y: 要转换的y坐标 25 | @param ego_local_x: 自车在世界系的位置 26 | @param ego_local_y: 自车在世界系的位置 27 | @return: 世界坐标系的x坐标, 世界坐标系的y坐标 28 | """ 29 | yaw = ego_heading 30 | global_x = ego_local_x + position_x * math.cos(yaw) - position_y * math.sin(yaw) 31 | global_y = ego_local_y + position_x * math.sin(yaw) + position_y * math.cos(yaw) 32 | return global_x, global_y 33 | 34 | @classmethod 35 | def theta_local_to_global(cls, ego_heading: float, heading: np.ndarray) -> np.ndarray: 36 | """ 37 | 将自车坐标系下的角度转成世界坐标系下的角度 38 | @param ego_heading: 自车在世界系的转角 39 | @param heading: 要转换的heading 40 | @return: 世界标系下的heading 41 | """ 42 | heading = heading.tolist() 43 | heading_list = [] 44 | for one_heading in heading: 45 | heading_list.append(cls.normalize_angle(ego_heading + one_heading)) 46 | return np.array(heading_list) 47 | 48 | @staticmethod 49 | def global_to_local(curr_x: float, curr_y: float, curr_heading: float, 50 | point_x: float, point_y: float) -> Tuple[float, float]: 51 | """ 52 | 将世界系的坐标转成自车系 53 | @param curr_x: 54 | @param curr_y: 55 | @param curr_heading: 56 | @param point_x: 57 | @param point_y: 58 | @return: 59 | """ 60 | delta_x = point_x - curr_x 61 | delta_y = point_y - curr_y 62 | return delta_x * math.cos(curr_heading) + delta_y * math.sin(curr_heading),\ 63 | delta_y * math.cos(curr_heading) - delta_x * math.sin(curr_heading) 64 | 65 | @classmethod 66 | def theta_global_to_local(cls, curr_heading: float, heading: float) -> float: 67 | """ 68 | 将世界坐标系下的角度转成自车坐标系下的角度 69 | @param curr_heading: 自车在世界系的转角 70 | @param heading: 要转换的heading 71 | @return: 自车坐标系下的heading 72 | """ 73 | return cls.normalize_angle(-curr_heading + heading) 74 | 75 | @staticmethod 76 | def normalize_angle(angle: float) -> float: 77 | """ 78 | 归一化弧度值, 使其范围在[-pi, pi]之间 79 | @param angle: 输入的弧度 80 | @return: 归一化之后的弧度 81 | """ 82 | angle = (angle + math.pi) % (2 * math.pi) 83 | if angle < .0: 84 | angle += 2 * math.pi 85 | return angle - math.pi 86 | -------------------------------------------------------------------------------- /tasks/data_count_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: 
data_count_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/26 8 | """ 9 | import glob 10 | import os 11 | import pickle 12 | import shutil 13 | 14 | from matplotlib import pyplot as plt 15 | from sklearn.decomposition import PCA 16 | from sklearn.cluster import KMeans 17 | from sklearn.preprocessing import MinMaxScaler 18 | 19 | import numpy as np 20 | 21 | from common import TaskType, LoadConfigResultDate 22 | from tasks import BaseTask 23 | 24 | 25 | class DataCountTask(BaseTask): 26 | 27 | def __init__(self): 28 | super(DataCountTask, self).__init__() 29 | self.task_type = TaskType.DATA_COUNT 30 | 31 | def execute(self, result_info: LoadConfigResultDate): 32 | logger = result_info.task_logger.get_logger() 33 | result_dir = result_info.task_config.result_dir 34 | self.rebuild_dir(result_dir) 35 | pkl_dir = result_info.task_config.train_dir 36 | self.check_dir_exist(pkl_dir) 37 | pkl_path_list = glob.glob(os.path.join(pkl_dir, "*.pkl")) 38 | sum_obs = list() 39 | num_predicted_obs = list() 40 | num_other_obs = list() 41 | num_lanes = list() 42 | num_traffic_light = list() 43 | for pkl_path in pkl_path_list: 44 | with open(pkl_path, "rb") as f: 45 | pkl_obj = pickle.load(f) 46 | sum_obs.append(len(pkl_obj[1])) 47 | num_predicted_obs.append(len(pkl_obj[0])) 48 | num_other_obs.append(len(pkl_obj[1]) - len(pkl_obj[0])) 49 | num_lanes.append(len(pkl_obj[4])) 50 | num_traffic_light.append(len(pkl_obj[5])) 51 | logger.info("data count success") 52 | result = { 53 | "sum_obs": sum_obs, 54 | "num_predicted_obs": num_predicted_obs, 55 | "num_other_obs": num_other_obs, 56 | "num_lanes": num_lanes, 57 | "num_traffic_light": num_traffic_light 58 | } 59 | with open(os.path.join(result_dir, f"{result_info.task_id}.pkl"), "wb") as f: 60 | pickle.dump(result, f) 61 | sum_obs = np.array(sum_obs) 62 | num_predicted_obs = np.array(num_predicted_obs) 63 | num_other_obs = np.array(num_other_obs) 64 | num_lanes = np.array(num_lanes) 65 | num_traffic_light = np.array(num_traffic_light) 66 | # 67 | feature_array = np.stack( 68 | (sum_obs, num_predicted_obs, num_other_obs, num_lanes, num_traffic_light), 69 | axis=-1 70 | ) 71 | scaler = MinMaxScaler() 72 | # 拟合并转换数据 73 | feature_array = scaler.fit_transform(feature_array) 74 | # 创建PCA模型 75 | pca = PCA(n_components=2) 76 | feature_array = pca.fit_transform(feature_array) 77 | # 聚类 78 | kmeans = KMeans(n_clusters=4, random_state=5) 79 | kmeans.fit(feature_array) 80 | y_pred = kmeans.predict(feature_array) 81 | plt.scatter(feature_array[:, 0], feature_array[:, 1], c=y_pred, s=5, cmap='viridis') 82 | plt.show() 83 | 84 | -------------------------------------------------------------------------------- /net_works/transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: transformer.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | from torch import nn 10 | 11 | from net_works.attention import MultiHeadSelfAttention, MultiHeadCrossAttention 12 | 13 | 14 | class FeedForward(nn.Module): 15 | def __init__(self, dim, hidden_dim=256): 16 | super().__init__() 17 | self.net = nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, hidden_dim), 20 | nn.GELU(), 21 | nn.Linear(hidden_dim, dim), 22 | ) 23 | 24 | def forward(self, x): 25 | return self.net(x) 26 | 27 | 28 | class TransformerCrossAttention(nn.Module): 29 | 30 | def __init__( 31 | self, input_dim: int, conditional_dim: int, 32 | head_dim: int = 64, num_heads: int = 8, drop: 
float = 0.1, 33 | ): 34 | super(TransformerCrossAttention, self).__init__() 35 | self.self_attention = MultiHeadSelfAttention(input_dim, head_dim, num_heads) 36 | self.cross_attention = MultiHeadCrossAttention(input_dim, conditional_dim, head_dim, num_heads) 37 | self.feed_forward = FeedForward(input_dim) 38 | self.drop_self_attention = nn.Dropout(drop) 39 | self.drop_cross_attention = nn.Dropout(drop) 40 | self.norm_self_attention = nn.LayerNorm(input_dim) 41 | self.norm_cross_attention = nn.LayerNorm(input_dim) 42 | self.norm_feed_forward = nn.LayerNorm(input_dim) 43 | 44 | def forward(self, input_x, input_conditional): 45 | # self attention 46 | norm_input_x = self.norm_self_attention(input_x) 47 | self_attention_output = self.self_attention(norm_input_x) 48 | res_output = input_x + self_attention_output 49 | res_output = self.drop_self_attention(res_output) 50 | # cross attention 51 | norm_input_x = self.norm_cross_attention(res_output) 52 | cross_attention_output = self.cross_attention(norm_input_x, input_conditional) 53 | res_output = res_output + cross_attention_output 54 | res_output = self.drop_cross_attention(res_output) 55 | # feed_forward 56 | norm_input_x = self.norm_feed_forward(res_output) 57 | feed_forward_output = self.feed_forward(norm_input_x) 58 | norm_input_x = norm_input_x + feed_forward_output 59 | return norm_input_x 60 | 61 | 62 | class TransformerSelfAttention(nn.Module): 63 | 64 | def __init__(self, input_dim: int, head_dim: int = 64, num_heads: int = 8, drop: float = 0.1): 65 | super(TransformerSelfAttention, self).__init__() 66 | self.self_attention = MultiHeadSelfAttention(input_dim, head_dim, num_heads) 67 | self.feed_forward = FeedForward(input_dim) 68 | self.drop_self_attention = nn.Dropout(drop) 69 | self.norm_self_attention = nn.LayerNorm(input_dim) 70 | self.norm_feed_forward = nn.LayerNorm(input_dim) 71 | 72 | def forward(self, input_x): 73 | # self attention 74 | norm_input_x = self.norm_self_attention(input_x) 75 | self_attention_output = self.self_attention(norm_input_x) 76 | res_output = input_x + self_attention_output 77 | res_output = self.drop_self_attention(res_output) 78 | # feed_forward 79 | norm_input_x = self.norm_feed_forward(res_output) 80 | feed_forward_output = self.feed_forward(norm_input_x) 81 | res_output = res_output + feed_forward_output 82 | return res_output 83 | -------------------------------------------------------------------------------- /tasks/data_preprocess_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_preprocess_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/23 8 | """ 9 | import os 10 | import pickle 11 | import shutil 12 | 13 | import tensorflow as tf 14 | import tqdm 15 | from waymo_open_dataset.protos import scenario_pb2 16 | 17 | from common import LoadConfigResultDate, TaskType 18 | from tasks.base_task import BaseTask 19 | from utils import DataUtil 20 | 21 | 22 | class DataPreprocessTask(BaseTask): 23 | def __init__(self): 24 | super(DataPreprocessTask, self).__init__() 25 | self.task_type = TaskType.DATA_PREPROCESS 26 | 27 | def execute(self, result_info: LoadConfigResultDate): 28 | # 清空原始路径 29 | data_preprocess_dir = result_info.task_config.data_preprocess_dir 30 | if os.path.exists(data_preprocess_dir): 31 | shutil.rmtree(data_preprocess_dir) 32 | os.makedirs(data_preprocess_dir, exist_ok=True) 33 | self.load_waymo_train_data(result_info) 34 | 35 | @staticmethod 36 | 
def check_waymo_dir(result_info: LoadConfigResultDate): 37 | # 检查waymo数据路径 38 | task_config = result_info.task_config 39 | dataset_names = ["train_set", "val_set", "test_set"] 40 | dataset_dirs = [task_config.waymo_train_dir, 41 | task_config.waymo_val_dir, 42 | task_config.waymo_test_dir] 43 | for dataset_name, dataset_dir in zip(dataset_names, dataset_dirs): 44 | if not os.path.isdir(dataset_dir) or not os.path.isdir(dataset_dir): 45 | error_info = f"waymo {dataset_name} error, {dataset_dir} is not a dir" 46 | result_info.task_logger.logger.error(error_info) 47 | raise ValueError(error_info) 48 | if len(os.listdir(dataset_dir)) == 0 and dataset_names == "train_set": 49 | warn_info = f"waymo {dataset_name} warn, {dataset_dir} size = 0" 50 | result_info.task_logger.logger.warn(warn_info) 51 | 52 | @classmethod 53 | def load_waymo_train_data(cls, result_info: LoadConfigResultDate): 54 | # 读取参数 55 | data_size = result_info.data_preprocess_config.data_size 56 | max_data_size = result_info.data_preprocess_config.max_data_size 57 | waymo_train_dir = result_info.task_config.waymo_train_dir 58 | preprocess_dir = result_info.task_config.data_preprocess_dir 59 | # 加载数据 60 | file_names = os.path.join(waymo_train_dir, "*training.*") 61 | match_filenames = tf.io.matching_files(file_names) 62 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data") 63 | dataset_iterator = dataset.as_numpy_iterator() 64 | bar = tqdm.tqdm(dataset_iterator, desc="load waymo train data: ") 65 | all_data = list() 66 | result_info.task_logger.logger.info("load_waymo_train_data start") 67 | number = 0 68 | for index, data in enumerate(bar): 69 | scenario = scenario_pb2.Scenario.FromString(data) 70 | data_dict = DataUtil.load_scenario_data(scenario) 71 | if len(data_dict) == 0: 72 | result_info.task_logger.logger.warn(f"scenario: {index} obs track is none") 73 | continue 74 | all_data.append(data_dict) 75 | number += 1 76 | if number % data_size == 0: 77 | file_name = f"result_{number}.pkl" 78 | with open(os.path.join(preprocess_dir, file_name), 'wb') as file: 79 | pickle.dump(all_data, file) 80 | all_data.clear() 81 | result_info.task_logger.logger.warn(f"file: {file_name} save success") 82 | if number > max_data_size: 83 | break 84 | result_info.task_logger.logger.info("load_waymo_train_data success") 85 | -------------------------------------------------------------------------------- /net_works/back_bone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: back_bone.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | from typing import Dict 10 | 11 | import numpy as np 12 | import torch 13 | from torch import nn 14 | 15 | from net_works.diffusion import GaussianDiffusion 16 | from net_works.scene_encoder import SceneEncoder 17 | from net_works.traj_decoder import TrajDecoder 18 | from utils import MathUtil 19 | 20 | 21 | class MultiModalLoss(nn.Module): 22 | def __init__(self): 23 | super(MultiModalLoss, self).__init__() 24 | self.huber_loss = nn.HuberLoss(reduction="none") 25 | self.confidence_loss = nn.CrossEntropyLoss(reduction="none") 26 | 27 | def forward(self, traj, confidence, predicted_future_traj, predicted_traj_mask): 28 | batch = traj.shape[0] 29 | obs_num = traj.shape[1] 30 | multimodal = traj.shape[2] 31 | future_step = traj.shape[3] 32 | output_dim = traj.shape[4] 33 | predicted_future_extend = predicted_future_traj.unsqueeze(dim=-3) 34 | loss = 
self.huber_loss(traj, predicted_future_extend) 35 | loss = loss.view(batch, obs_num, multimodal, -1) 36 | loss = torch.mean(loss, dim=-1) 37 | min_loss, _ = torch.min(loss, dim=-1) 38 | min_loss_modal = torch.argmin(loss, dim=-1) 39 | min_loss_index = min_loss_modal.view(batch, obs_num, 1, 1, 1).repeat( 40 | (1, 1, 1, future_step, output_dim) 41 | ) 42 | min_loss_traj = torch.gather(traj, -3, min_loss_index).squeeze() 43 | confidence = confidence.view(-1, multimodal) 44 | min_loss_modal = min_loss_modal.view(-1) 45 | confidence_loss = self.confidence_loss(confidence, min_loss_modal).view(batch, obs_num) 46 | traj_loss = torch.sum(min_loss * predicted_traj_mask) / (torch.sum(predicted_traj_mask) + 0.00001) 47 | confidence_loss = torch.sum(confidence_loss * predicted_traj_mask) / (torch.sum(predicted_traj_mask) + 0.00001) 48 | return traj_loss, confidence_loss, min_loss_traj 49 | 50 | 51 | class BackBone(nn.Module): 52 | def __init__(self, betas: np.ndarray): 53 | super(BackBone, self).__init__() 54 | self.diffusion = GaussianDiffusion(betas=betas) 55 | self.scene_encoder = SceneEncoder() 56 | self.traj_decoder = TrajDecoder() 57 | self.multi_modal_loss = MultiModalLoss() 58 | 59 | def forward(self, data: Dict): 60 | # batch, other_obs(10), 40, 7 61 | predicted_feature = data['predicted_feature'] 62 | # batch, other_obs(10), 40, 5 63 | other_his_pos = data['other_his_pos'] 64 | other_his_traj_delt = data['other_his_traj_delt'] 65 | other_feature = data['other_feature'] 66 | other_traj_mask = data['other_traj_mask'] 67 | # batch, pred_obs(15), 40, 5 68 | predicted_his_pos = data['predicted_his_pos'] 69 | predicted_his_traj_delt = data['predicted_his_traj_delt'] 70 | predicted_his_traj = data['predicted_his_traj'] 71 | # batch, pred_obs(15), 50, 5 72 | predicted_future_traj = data['predicted_future_traj'] 73 | predicted_traj_mask = data['predicted_traj_mask'] 74 | # batch, tl_num(10), 2 75 | traffic_light = data['traffic_light'] 76 | traffic_light_pos = data['traffic_light_pos'] 77 | # batch, num_lane(32), num_point(128), 2 78 | lane_list = data['lane_list'] 79 | # diffusion训练 80 | noise = torch.randn_like(predicted_his_traj_delt) 81 | diffusion_loss = self.diffusion(data) 82 | noise = self.diffusion.sample(noise, predicted_his_traj) 83 | # scene encoder 84 | scene_feature = self.scene_encoder( 85 | noise, lane_list, 86 | other_his_traj_delt, other_his_pos, other_feature, 87 | predicted_his_traj_delt, predicted_his_pos, predicted_feature, 88 | traffic_light, traffic_light_pos 89 | ) 90 | # traj_decoder 91 | traj, confidence = self.traj_decoder(scene_feature) 92 | traj = MathUtil.post_process_output(traj, predicted_his_traj) 93 | traj_loss, confidence_loss, min_loss_traj = self.multi_modal_loss(traj, confidence, predicted_future_traj, 94 | predicted_traj_mask) 95 | return diffusion_loss, traj_loss, confidence_loss, min_loss_traj 96 | -------------------------------------------------------------------------------- /net_works/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: attention.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class MultiHeadSelfAttention(nn.Module): 16 | def __init__(self, input_dim: int, head_dim: int = 64, num_heads: int = 8): 17 | super(MultiHeadSelfAttention, self).__init__() 18 | self.embed_dim = head_dim * num_heads 19 | self.num_heads = 
num_heads 20 | # We assume d_v is the same as d_k here aka head_dim == embed_dim // num_heads 21 | self.head_dim = head_dim 22 | # Learnable parameters 23 | self.query_proj = nn.Linear(input_dim, self.embed_dim) 24 | self.key_proj = nn.Linear(input_dim, self.embed_dim) 25 | self.value_proj = nn.Linear(input_dim, self.embed_dim) 26 | self.out_proj = nn.Linear(self.embed_dim, input_dim) 27 | 28 | def forward(self, input_feature, attn_mask=None): 29 | batch_size, obs_num = input_feature.shape[0], input_feature.shape[1] 30 | # Linear transformations 31 | q = self.query_proj(input_feature) 32 | k = self.key_proj(input_feature) 33 | v = self.value_proj(input_feature) 34 | 35 | # Split the embedding dimension into number of heads 36 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim) 37 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim) 38 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim) 39 | 40 | # Transpose to (batch_size, heads, seq_len, head_dim) 41 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 42 | 43 | # Calculate attention scores 44 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) 45 | if attn_mask is not None: 46 | scores = scores.masked_fill(attn_mask, -float('inf')) 47 | 48 | # Apply attention score scaling and softmax 49 | attn = nn.functional.softmax(scores, dim=-1) 50 | 51 | # Multiply attention scores with value and sum to get the final output 52 | out = torch.matmul(attn, v) 53 | 54 | # Transpose back to (batch_size, seq_len, heads, head_dim) 55 | out = out.transpose(1, 2).contiguous().reshape(batch_size, obs_num, self.embed_dim) 56 | 57 | # Final linear transformation 58 | return self.out_proj(out) 59 | 60 | 61 | class MultiHeadCrossAttention(nn.Module): 62 | def __init__(self, input_q_dim: int, input_kv_dim: int, head_dim: int = 64, num_heads: int = 8): 63 | super(MultiHeadCrossAttention, self).__init__() 64 | self.embed_dim = head_dim * num_heads 65 | self.num_heads = num_heads 66 | # We assume d_v is the same as d_k here aka head_dim == embed_dim // num_heads 67 | self.head_dim = head_dim 68 | # Learnable parameters 69 | self.query_proj = nn.Linear(input_q_dim, self.embed_dim) 70 | self.key_proj = nn.Linear(input_kv_dim, self.embed_dim) 71 | self.value_proj = nn.Linear(input_kv_dim, self.embed_dim) 72 | self.out_proj = nn.Linear(self.embed_dim, input_q_dim) 73 | 74 | def forward(self, input_q, input_kv, attn_mask=None): 75 | batch_size, obs_num = input_q.shape[0], input_q.shape[1] 76 | # Linear transformations 77 | q = self.query_proj(input_q) 78 | k = self.key_proj(input_kv) 79 | v = self.value_proj(input_kv) 80 | 81 | # Split the embedding dimension into number of heads 82 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim) 83 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim) 84 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim) 85 | 86 | # Transpose to (batch_size, heads, seq_len, head_dim) 87 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 88 | 89 | # Calculate attention scores 90 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) 91 | if attn_mask is not None: 92 | scores = scores.masked_fill(attn_mask, -float('inf')) 93 | 94 | # Apply attention score scaling and softmax 95 | attn = nn.functional.softmax(scores, dim=-1) 96 | 97 | # Multiply attention scores with value and sum to get the final output 98 | out = torch.matmul(attn, v) 99 | 100 | # Transpose back to (batch_size, seq_len, heads, 
head_dim) 101 | out = out.transpose(1, 2).contiguous().reshape(batch_size, obs_num, self.embed_dim) 102 | 103 | # Final linear transformation 104 | return self.out_proj(out) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: WcDT 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 10 | - defaults 11 | dependencies: 12 | - _libgcc_mutex=0.1=conda_forge 13 | - _openmp_mutex=4.5=2_gnu 14 | - bzip2=1.0.8=hd590300_5 15 | - ca-certificates=2023.11.17=hbcca054_0 16 | - ld_impl_linux-64=2.40=h41732ed_0 17 | - libffi=3.4.2=h7f98852_5 18 | - libgcc-ng=13.2.0=h807b86a_3 19 | - libgomp=13.2.0=h807b86a_3 20 | - libnsl=2.0.1=hd590300_0 21 | - libsqlite=3.44.2=h2797004_0 22 | - libuuid=2.38.1=h0b41bf4_0 23 | - libzlib=1.2.13=hd590300_5 24 | - ncurses=6.4=h59595ed_2 25 | - openssl=3.2.0=hd590300_1 26 | - pip=23.3.1=pyhd8ed1ab_0 27 | - python=3.10.1=h543edf9_2_cpython 28 | - readline=8.2=h8228510_1 29 | - sqlite=3.44.2=h2c6b66d_0 30 | - tk=8.6.13=noxft_h4845f30_101 31 | - tzdata=2023c=h71feb2d_0 32 | - wheel=0.42.0=pyhd8ed1ab_0 33 | - xz=5.2.6=h166bdaf_0 34 | - pip: 35 | - absl-py==1.4.0 36 | - array-record==0.5.0 37 | - astunparse==1.6.3 38 | - cachetools==5.3.2 39 | - certifi==2023.11.17 40 | - charset-normalizer==3.3.2 41 | - click==8.1.7 42 | - cloudpickle==3.0.0 43 | - cmake==3.27.9 44 | - contourpy==1.2.0 45 | - cycler==0.12.1 46 | - dask==2023.3.1 47 | - decorator==5.1.1 48 | - dm-tree==0.1.8 49 | - einops==0.7.0 50 | - einsum==0.3.0 51 | - etils==1.5.2 52 | - filelock==3.13.1 53 | - flatbuffers==23.5.26 54 | - fonttools==4.46.0 55 | - fsspec==2023.12.1 56 | - gast==0.4.0 57 | - google-auth==2.16.2 58 | - google-auth-oauthlib==0.4.6 59 | - google-pasta==0.2.0 60 | - googleapis-common-protos==1.62.0 61 | - grpcio==1.60.0 62 | - h5py==3.10.0 63 | - idna==3.6 64 | - imageio==2.33.0 65 | - immutabledict==2.2.0 66 | - importlib-resources==6.1.1 67 | - jinja2==3.1.2 68 | - joblib==1.3.2 69 | - keras==2.11.0 70 | - kiwisolver==1.4.5 71 | - lazy-loader==0.3 72 | - libclang==16.0.6 73 | - lit==17.0.6 74 | - locket==1.0.0 75 | - markdown==3.5.1 76 | - markupsafe==2.1.3 77 | - matplotlib==3.6.1 78 | - mpmath==1.3.0 79 | - networkx==3.2.1 80 | - numpy==1.23.0 81 | - nvidia-cublas-cu11==11.10.3.66 82 | - nvidia-cuda-cupti-cu11==11.7.101 83 | - nvidia-cuda-nvrtc-cu11==11.7.99 84 | - nvidia-cuda-runtime-cu11==11.7.99 85 | - nvidia-cudnn-cu11==8.5.0.96 86 | - nvidia-cufft-cu11==10.9.0.58 87 | - nvidia-curand-cu11==10.2.10.91 88 | - nvidia-cusolver-cu11==11.4.0.1 89 | - nvidia-cusparse-cu11==11.7.4.91 90 | - nvidia-nccl-cu11==2.14.3 91 | - nvidia-nvtx-cu11==11.7.91 92 | - oauthlib==3.2.2 93 | - opencv-python==4.8.0.76 94 | - openexr==1.3.9 95 | - opt-einsum==3.3.0 96 | - packaging==23.2 97 | - pandas==1.5.3 98 | - partd==1.4.1 99 | - pillow==9.2.0 100 | - plotly==5.13.1 101 | - promise==2.3 102 | - protobuf==3.19.6 103 | - psutil==5.9.6 104 | - pyarrow==10.0.0 105 | - pyasn1==0.5.1 106 | - pyasn1-modules==0.3.0 107 | - pyparsing==3.1.1 108 | - python-dateutil==2.8.2 109 | - 
pytz==2023.3.post1 110 | - pywavelets==1.4.1 111 | - pyyaml==6.0.1 112 | - requests==2.31.0 113 | - requests-oauthlib==1.3.1 114 | - rsa==4.9 115 | - scikit-image==0.20.0 116 | - scikit-learn==1.2.2 117 | - scipy==1.10.1 118 | - setuptools==67.6.0 119 | - six==1.16.0 120 | - sympy==1.12 121 | - tenacity==8.2.3 122 | - tensorboard==2.11.2 123 | - tensorboard-data-server==0.6.1 124 | - tensorboard-plugin-wit==1.8.1 125 | - tensorflow==2.11.0 126 | - tensorflow-addons==0.23.0 127 | - tensorflow-datasets==4.9.0 128 | - tensorflow-estimator==2.11.0 129 | - tensorflow-graphics==2021.12.3 130 | - tensorflow-io-gcs-filesystem==0.34.0 131 | - tensorflow-metadata==1.13.0 132 | - tensorflow-probability==0.19.0 133 | - termcolor==2.4.0 134 | - threadpoolctl==3.2.0 135 | - tifffile==2023.12.9 136 | - toml==0.10.2 137 | - toolz==0.12.0 138 | - torch==2.0.0 139 | - torchaudio==2.0.1 140 | - torchvision==0.15.1 141 | - tqdm==4.66.1 142 | - trimesh==4.0.5 143 | - triton==2.0.0 144 | - typeguard==2.13.3 145 | - typing-extensions==4.9.0 146 | - urllib3==2.1.0 147 | - waymo-open-dataset-tf-2-11-0==1.5.2 148 | - werkzeug==3.0.1 149 | - wrapt==1.16.0 150 | - zipp==3.17.0 151 | prefix: /home/haomo/anaconda3/envs/TSDiT 152 | -------------------------------------------------------------------------------- /net_works/scene_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: scene_encoder.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | import torch 10 | from torch import nn 11 | 12 | from net_works.transformer import TransformerCrossAttention, TransformerSelfAttention 13 | 14 | 15 | class OtherFeatureFormer(nn.Module): 16 | def __init__( 17 | self, block_num: int, input_dim: int, conditional_dim: int, 18 | head_dim: int = 64, num_heads: int = 8 19 | ): 20 | super(OtherFeatureFormer, self).__init__() 21 | self.blocks = nn.ModuleList([]) 22 | for _ in range(block_num): 23 | self.blocks.append(TransformerCrossAttention(input_dim, conditional_dim, head_dim, num_heads)) 24 | self.norm = nn.LayerNorm(input_dim) 25 | 26 | def forward(self, input_x, other_feature): 27 | for block in self.blocks: 28 | input_x = block(input_x, other_feature) 29 | return self.norm(input_x) 30 | 31 | 32 | class SelfFeatureFormer(nn.Module): 33 | def __init__(self, block_num: int, input_dim: int, head_dim: int = 64, num_heads: int = 8): 34 | super(SelfFeatureFormer, self).__init__() 35 | self.blocks = nn.ModuleList([]) 36 | for _ in range(block_num): 37 | self.blocks.append(TransformerSelfAttention(input_dim, head_dim, num_heads)) 38 | self.norm = nn.LayerNorm(input_dim) 39 | 40 | def forward(self, input_x): 41 | for block in self.blocks: 42 | input_x = block(input_x) 43 | return self.norm(input_x) 44 | 45 | 46 | class SceneEncoder(nn.Module): 47 | def __init__( 48 | self, dim: int = 256, embedding_dim: int = 32, 49 | his_step: int = 11, other_agent_depth: int = 4, 50 | map_feature_depth: int = 4, traffic_light_depth: int = 2, 51 | self_attention_depth: int = 4 52 | ): 53 | super(SceneEncoder, self).__init__() 54 | self.embedding_dim = embedding_dim 55 | self.pos_embedding = nn.Sequential( 56 | nn.Linear(2, dim), 57 | nn.LayerNorm(dim), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(dim, dim), 60 | nn.LayerNorm(dim), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(dim, embedding_dim) 63 | ) 64 | self.feature_embedding = nn.Sequential( 65 | nn.Linear(7, dim), 66 | nn.LayerNorm(dim), 67 | nn.ReLU(inplace=True), 68 
| nn.Linear(dim, dim), 69 | nn.LayerNorm(dim), 70 | nn.ReLU(inplace=True), 71 | nn.Linear(dim, embedding_dim) 72 | ) 73 | linear_input_dim = (his_step - 1) * 5 + embedding_dim + embedding_dim 74 | self.linear_input = nn.Sequential( 75 | nn.Linear(linear_input_dim, dim), 76 | nn.LayerNorm(dim), 77 | nn.ReLU(inplace=True), 78 | nn.Linear(dim, dim) 79 | ) 80 | self.other_agent_former = OtherFeatureFormer(block_num=other_agent_depth, input_dim=dim, conditional_dim=114) 81 | self.map_former = OtherFeatureFormer(block_num=map_feature_depth, input_dim=dim, conditional_dim=embedding_dim) 82 | self.traffic_light_former = OtherFeatureFormer(block_num=traffic_light_depth, input_dim=dim, 83 | conditional_dim=43) 84 | self.fusion_block = SelfFeatureFormer(block_num=self_attention_depth, input_dim=dim, num_heads=16) 85 | 86 | def forward( 87 | self, noise, lane_list, 88 | other_his_traj_delt, other_his_pos, other_feature, 89 | predicted_his_traj_delt, predicted_his_pos, predicted_feature, 90 | traffic_light, traffic_light_pos 91 | ): 92 | batch_size, obs_num = noise.shape[0], noise.shape[1] 93 | # batch, obs_num(8), his_step, 3 94 | x = predicted_his_traj_delt + (noise * 0.001) 95 | x = torch.flatten(x, start_dim=2) 96 | other_his_traj_delt = torch.flatten(other_his_traj_delt, start_dim=2) 97 | # 对各个位置进行位置编码 98 | lane_list = self.pos_embedding(lane_list) 99 | lane_list = lane_list.view(batch_size, -1, self.embedding_dim) 100 | traffic_light_pos = self.pos_embedding(traffic_light_pos) 101 | other_his_pos = self.pos_embedding(other_his_pos) 102 | predicted_his_pos = self.pos_embedding(predicted_his_pos) 103 | # 对属性进行编码 104 | other_feature = self.feature_embedding(other_feature) 105 | predicted_feature = self.feature_embedding(predicted_feature) 106 | # 组合输入信息 107 | x = torch.cat((x, predicted_his_pos, predicted_feature), dim=-1) 108 | # batch, obs_num(15), 256 109 | x = self.linear_input(x) 110 | # other agent former 111 | other_obs_feature = torch.cat((other_his_traj_delt, other_his_pos, other_feature), dim=-1) 112 | x = self.other_agent_former(x, other_obs_feature) 113 | # map_point_transformer 114 | x = self.map_former(x, lane_list) 115 | # traffic_light_transformer 116 | traffic_light = torch.cat((traffic_light, traffic_light_pos), dim=-1) 117 | x = self.traffic_light_former(x, traffic_light) 118 | x = self.fusion_block(x) 119 | return x 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WcDT: World-centric Diffusion Transformer Traffic Scene Generation 2 | 3 | This repository contains the official implementation of WcDT: World-centric Diffusion Transformer for Traffic Scene Generation [[Paper Link](https://arxiv.org/abs/2404.02082)] 4 | 5 | ## Gettting Started 6 | 7 | First of all, we recommend that you read the description of the [Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/) and the [Motion Prediction](https://waymo.com/open/data/motion/) dataset on the waymo website. 8 | 9 | 1. Clone this repository: 10 | 11 | git clone https://github.com/yangchen1997/WcDT.git 12 | 13 | 2. Install the dependencies: 14 | 15 | conda env create -f environment.yml 16 | 17 | 3. 
Download [Waymo Open Dataset](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_2_0?pli=1)(Please note: If you are downloading waymo datasets for the first time, you need to click "Download" on the waymo website and register your account).After downloading the dataset directory should be organized as follows: 18 | 19 | /path/to/dataset_root/ 20 | ├── train_set/ 21 | ├── training.tfrecord-00000-of-01000 22 | ├── training.tfrecord-00001-of-01000 23 | ├── ... 24 | └── val_set/ 25 | ├── validation.tfrecord-00000-of-00150 26 | ├── validation.tfrecord-00001-of-00150 27 | ├── ... 28 | └── test_set/ 29 | ├── testing.tfrecord-00000-of-00150 30 | ├── testing.tfrecord-00001-of-00150 31 | ├── ... 32 | 33 | 34 | ## Tasks 35 | 36 | This project includes the following tasks: 37 | 38 | 1. Data Preprocess: This task is divided into two subtasks, data compression (which removes redundancy from the waymo dataset) and data splitting. 39 | 40 | 2. Training 41 | 42 | 3. Evaluating Models and Visualising Results 43 | 44 | 45 | Before running the project, you need to configure the tasks to be performed in config.yaml. 46 | 47 | tasks config: 48 | 49 | task_config: 50 | task_list: 51 | - "DATA_PREPROCESS" 52 | - "DATA_SPLIT" 53 | - "DATA_COUNT" 54 | - "TRAIN_MODEL" 55 | - "SHOW_RESULTS" 56 | - "EVAL_MODEL" 57 | - "GENE_SUBMISSION" 58 | output_dir: "output" 59 | log_dir: "log" 60 | image_dir: "result_image" 61 | model_dir: "model" 62 | result_dir: "result" 63 | pre_train_model: "" 64 | waymo_train_dir: "path to waymo train_set" 65 | waymo_val_dir: "path to waymo valid_set" 66 | waymo_test_dir: "path to waymo test_set" 67 | data_output: "data_output" 68 | data_preprocess_dir: "data_preprocess_dir" 69 | train_dir: "train_dir" 70 | val_dir: "val_dir" 71 | test_dir: "test_dir" 72 | 73 | start tasks: 74 | 75 | bash run_main.sh 76 | 77 | ### Data Preprocess 78 | 79 | task config: 80 | 81 | data_preprocess_config: 82 | data_size: 100 83 | max_data_size: 2000 84 | num_works: 20 85 | 86 | ### Training 87 | 88 | task config: 89 | 90 | train_model_config: 91 | use_gpu: False 92 | gpu_ids: 93 | - 6 94 | - 7 95 | batch_size: 4 96 | num_works: 0 97 | his_step: 11 98 | max_pred_num: 8 99 | max_other_num: 6 100 | max_traffic_light: 8 101 | max_lane_num: 32 102 | max_point_num: 128 103 | num_head: 8 104 | attention_dim: 128 105 | multimodal: 10 106 | time_steps: 50 107 | # cosine or linear 108 | schedule: "linear" 109 | num_epoch: 200 110 | init_lr: 0.0001 111 | 112 | ### Evaluate 113 | 114 | | Model | ADE↓ | MinADE↓ | 115 | | --- | --- | --- | 116 | | WcDT-64 | 4.872  | 1.962 | 117 | | WcDT-128 | 4.563 | 1.669 | 118 | 119 | ### Qualitative Results 120 | 121 | Demos for lane-changing scenarios: 122 | 123 | 124 | 131 | 138 | 139 |
125 |

126 | First Image 127 |
128 | Ground truth 129 |

130 |
132 |

133 | Second Image 134 |
135 | WcDT-128 result 136 |

137 |
140 | 141 | 142 | Demos for more complex turning scenarios: 143 | 144 | 145 | 152 | 159 | 160 |
146 |

147 | First Image 148 |
149 | Ground truth 150 |

151 |
153 |

154 | Second Image 155 |
156 | WcDT-128 result 157 |

158 |
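
For reference, the ADE / MinADE numbers reported in the Evaluate table above follow the usual trajectory-forecasting definitions: ADE is the average L2 distance between a generated rollout and the ground-truth future positions, and MinADE keeps the best rollout among the `multimodal` samples drawn for each agent. The snippet below is only an illustrative sketch of these definitions (it is not the evaluation code used to produce the table):

    import numpy as np

    def ade(pred_traj: np.ndarray, gt_traj: np.ndarray) -> float:
        # pred_traj, gt_traj: [num_steps, 2] arrays of (x, y) positions
        return float(np.linalg.norm(pred_traj - gt_traj, axis=-1).mean())

    def min_ade(rollouts: np.ndarray, gt_traj: np.ndarray) -> float:
        # rollouts: [num_modes, num_steps, 2], one sampled rollout per mode
        return min(ade(rollout, gt_traj) for rollout in rollouts)
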
161 | 162 | ## Todo List 163 | 164 | - [x] Data Statistics 165 | - [x] Generate Submission 166 | - [ ] Factorized Attention for Temporal Features 167 | - [ ] Graph Attention Mechanisms for Transformer 168 | - [ ] Lane Loss(FDE Loss + Timestep Weighted Loss) 169 | - [ ] Scene Label 170 | - [ ] Upgrade Decoder(Prposed + Refined Trajectory) 171 | 172 | ## Citation 173 | 174 | If you found this repository useful, please consider citing our paper: 175 | 176 | ```bibtex 177 | @article{yang2024wcdt, 178 | title={Wcdt: World-centric diffusion transformer for traffic scene generation}, 179 | author={Yang, Chen and He, Yangfan and Tian, Aaron Xuxiang and Chen, Dong and Wang, Jianhui and Shi, Tianyu and Heydarian, Arsalan and Liu, Pei}, 180 | journal={arXiv preprint arXiv:2404.02082}, 181 | year={2024} 182 | } 183 | ``` 184 | 185 | 186 | -------------------------------------------------------------------------------- /tasks/train_model_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: train_model_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/26 8 | """ 9 | import os 10 | import shutil 11 | from typing import Union 12 | 13 | import torch 14 | from torch import optim, nn 15 | from torch.optim import Optimizer 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | 19 | from common import TaskType, LoadConfigResultDate 20 | from common import WaymoDataset 21 | from net_works import BackBone 22 | from tasks import BaseTask 23 | from utils import MathUtil, VisualizeUtil 24 | 25 | 26 | class TrainModelTask(BaseTask): 27 | 28 | def __init__(self): 29 | super(TrainModelTask, self).__init__() 30 | self.task_type = TaskType.TRAIN_MODEL 31 | self.device = torch.device("cpu") 32 | self.multi_gpus = False 33 | self.gpu_ids = list() 34 | 35 | def execute(self, result_info: LoadConfigResultDate): 36 | train_dir = result_info.task_config.train_dir 37 | train_model_config = result_info.train_model_config 38 | self.init_dirs(result_info) 39 | # 初始化device 40 | if train_model_config.use_gpu: 41 | if train_model_config.gpu_ids: 42 | self.device = torch.device(f"cuda:{train_model_config.gpu_ids[0]}") 43 | self.multi_gpus = True 44 | self.gpu_ids = train_model_config.gpu_ids 45 | else: 46 | self.device = torch.device('cuda') 47 | # 初始化dataloader 48 | waymo_dataset = WaymoDataset( 49 | train_dir, train_model_config.his_step, train_model_config.max_pred_num, 50 | train_model_config.max_other_num, train_model_config.max_traffic_light, 51 | train_model_config.max_lane_num, train_model_config.max_point_num 52 | ) 53 | data_loader = DataLoader( 54 | waymo_dataset, 55 | shuffle=False, 56 | batch_size=train_model_config.batch_size, 57 | num_workers=train_model_config.num_works, 58 | pin_memory=True, 59 | drop_last=False 60 | ) 61 | model = self.init_model(result_info) 62 | model_train = model.train() 63 | optimizer = optim.Adam(model_train.parameters(), lr=train_model_config.init_lr, 64 | betas=(0.9, 0.999), weight_decay=0) 65 | epoch_step = len(waymo_dataset) // train_model_config.batch_size 66 | if epoch_step == 0: 67 | raise ValueError("dataset is too small, epoch_step = 0") 68 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, train_model_config.num_epoch, eta_min=1e-6) 69 | for epoch in range(train_model_config.num_epoch): 70 | self.fit_one_epoch(epoch, epoch_step, optimizer, model, data_loader, result_info) 71 | scheduler.step() 72 | 73 | def fit_one_epoch( 74 
| self, epoch_num: int, epoch_step: int, 75 | optimizer: Optimizer, model: Union[BackBone, nn.DataParallel], 76 | data_loader: DataLoader, result_info: LoadConfigResultDate 77 | ): 78 | diffusion_losses = 0 79 | traj_losses = 0 80 | confidence_losses = 0 81 | total_losses = 0 82 | train_model_config = result_info.train_model_config 83 | pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch_num + 1}/{train_model_config.num_epoch}', mininterval=0.3) 84 | torch.autograd.set_detect_anomaly(True) 85 | for iteration, data in enumerate(data_loader): 86 | for key, value in data.items(): 87 | if isinstance(value, torch.Tensor): 88 | data[key] = value.to(self.device).to(torch.float32) 89 | optimizer.zero_grad() 90 | diffusion_loss, traj_loss, confidence_loss, min_loss_traj = model(data) 91 | diffusion_loss = diffusion_loss.mean() 92 | traj_loss = traj_loss.mean() 93 | confidence_loss = confidence_loss.mean() 94 | total_loss = diffusion_loss + traj_loss + confidence_loss 95 | total_loss.backward() 96 | optimizer.step() 97 | total_losses += total_loss.item() 98 | diffusion_losses += diffusion_loss.item() 99 | traj_losses += traj_loss.item() 100 | confidence_losses += confidence_loss.item() 101 | pbar.set_postfix(**{'total_loss': total_losses / (iteration + 1), 102 | 'diffusion_loss': diffusion_losses / (iteration + 1), 103 | 'traj_losses': traj_losses / (iteration + 1), 104 | 'confidence_losses': confidence_losses / (iteration + 1)}) 105 | pbar.update() 106 | if iteration % 10 == 0: 107 | image_path = os.path.join(result_info.task_config.image_dir, 108 | f"epoch_{epoch_num}_batch_num_{iteration}_image.png") 109 | VisualizeUtil.show_result(image_path, min_loss_traj, data) 110 | 111 | @staticmethod 112 | def init_dirs(result_info: LoadConfigResultDate): 113 | task_config = result_info.task_config 114 | if os.path.exists(task_config.image_dir): 115 | shutil.rmtree(task_config.image_dir) 116 | os.makedirs(task_config.image_dir, exist_ok=True) 117 | os.makedirs(task_config.model_dir, exist_ok=True) 118 | 119 | def init_model(self, result_info: LoadConfigResultDate) -> Union[BackBone, nn.DataParallel]: 120 | train_model_config = result_info.train_model_config 121 | task_config = result_info.task_config 122 | # 初始化diffusion的betas 123 | if train_model_config.schedule == "cosine": 124 | betas = MathUtil.generate_cosine_schedule(train_model_config.time_steps) 125 | else: 126 | schedule_low = 1e-4 127 | schedule_high = 0.008 128 | betas = MathUtil.generate_linear_schedule( 129 | train_model_config.time_steps, 130 | schedule_low * 1000 / train_model_config.time_steps, 131 | schedule_high * 1000 / train_model_config.time_steps, 132 | ) 133 | model = BackBone(betas) 134 | if task_config.pre_train_model: 135 | pre_train_model_path = task_config.pre_train_model 136 | model_dict = model.state_dict() 137 | pretrained_dict = torch.load(pre_train_model_path) 138 | # 模型参数赋值 139 | new_model_dict = dict() 140 | for key in model_dict.keys(): 141 | if ("module." + key) in pretrained_dict: 142 | new_model_dict[key] = pretrained_dict["module." 
+ key] 143 | elif key in pretrained_dict: 144 | new_model_dict[key] = pretrained_dict[key] 145 | else: 146 | print("key: ", key, ", not in pretrained") 147 | model.load_state_dict(new_model_dict) 148 | result_info.task_logger.logger.info("load pre_train_model success") 149 | model = model.to(self.device) 150 | if self.multi_gpus: 151 | model = nn.DataParallel(model, device_ids=self.gpu_ids, output_device=self.gpu_ids[0]) 152 | return model 153 | -------------------------------------------------------------------------------- /tasks/load_config_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: load_config_task.py 6 | @Author: YangChen 7 | @Date: 2023/12/20 8 | """ 9 | import os 10 | import random 11 | from datetime import datetime 12 | 13 | import torch 14 | import yaml 15 | 16 | from common import TaskType, LoadConfigResultDate, TaskLogger 17 | from common import TaskConfig, DataPreprocessConfig, TrainModelConfig 18 | from tasks.base_task import BaseTask 19 | 20 | # 可能在不同的地方调用ConfigFactory,因此优先初始化config_file 21 | path_list = os.path.realpath(__file__).split(os.path.sep) 22 | path_list = path_list[:-2] 23 | path_list.append("config.yaml") 24 | CONFIG_PATH = os.path.sep.join(path_list) 25 | 26 | 27 | class LoadConfigTask(BaseTask): 28 | """ 29 | 主要功能: 30 | 加载配置文件 31 | 检查config是否合理 32 | 根据config初始化output 33 | 根据config初始化log 34 | """ 35 | 36 | def __init__(self): 37 | super(LoadConfigTask, self).__init__() 38 | self.task_type = TaskType.LOAD_CONFIG 39 | with open(CONFIG_PATH, "rb") as file: 40 | self.__yaml_loader = yaml.safe_load(file) 41 | 42 | def execute(self, result_info: LoadConfigResultDate): 43 | # 初始化结果信息 44 | task_config = self.get_task_config() 45 | data_preprocess_config = self.get_preprocess_config() 46 | train_model_config = self.get_train_model_config() 47 | self.init_dirs_and_log(task_config, result_info) 48 | self.check_preprocess_config(data_preprocess_config) 49 | self.check_train_model_config(train_model_config) 50 | # 结果信息赋值 51 | result_info.task_config = task_config 52 | result_info.data_preprocess_config = data_preprocess_config 53 | result_info.train_model_config = train_model_config 54 | 55 | def get_task_config(self) -> TaskConfig: 56 | task_config = TaskConfig() 57 | self.__init_config_object_attr(task_config, self.__yaml_loader['task_config']) 58 | return task_config 59 | 60 | def get_preprocess_config(self) -> DataPreprocessConfig: 61 | data_preprocess_config = DataPreprocessConfig() 62 | self.__init_config_object_attr(data_preprocess_config, self.__yaml_loader['data_preprocess_config']) 63 | return data_preprocess_config 64 | 65 | def get_train_model_config(self) -> TrainModelConfig: 66 | train_model_config = TrainModelConfig() 67 | self.__init_config_object_attr(train_model_config, self.__yaml_loader['train_model_config']) 68 | return train_model_config 69 | 70 | @staticmethod 71 | def __init_config_object_attr(instance: object, attrs: dict): 72 | """ 73 | 根据字典给配置类对象赋值 74 | @param instance: 配置类对象 75 | @param attrs: 从配置文件中读取的属性 76 | @return: None 77 | """ 78 | if not instance or not attrs: 79 | return 80 | for name, value in attrs.items(): 81 | if hasattr(instance, name): 82 | setattr(instance, name, value) 83 | else: 84 | raise ValueError(f"unknown config, config name is {name}") 85 | 86 | @staticmethod 87 | def init_dirs_and_log(task_config: TaskConfig, result_data: LoadConfigResultDate): 88 | """ 89 | 检查task_config的参数并初始化路径 90 | 
@param task_config: 91 | @param result_data 92 | @return: 93 | """ 94 | # output路径检查 95 | os.makedirs(task_config.output_dir, exist_ok=True) 96 | task_config.log_dir = os.path.join(task_config.output_dir, task_config.log_dir) 97 | os.makedirs(task_config.log_dir, exist_ok=True) 98 | # 创建log 99 | now = datetime.now() 100 | formatted_now = now.strftime('%Y%m%d-%H-%M-%S') 101 | random_num = random.randint(10000, 99999) 102 | result_data.task_id = f"{formatted_now}_{random_num}" 103 | log_file_name = os.path.join(task_config.log_dir, f"{result_data.task_id}.log") 104 | result_data.task_logger = TaskLogger(log_file_name) 105 | result_data.task_logger.logger.info(f"task id {result_data.task_id} start") 106 | # task list里的任务必须唯一 107 | if len(task_config.task_list) != len(set(task_config.task_list)): 108 | raise ValueError("task_config must be unique") 109 | # 创建model save dir 110 | task_config.model_dir = os.path.join(task_config.output_dir, task_config.model_dir) 111 | # os.makedirs(task_config.model_dir, exist_ok=True) 112 | # 创建result dir 113 | task_config.result_dir = os.path.join(task_config.output_dir, task_config.result_dir) 114 | # 检查模型路径 115 | if task_config.pre_train_model: 116 | path_type = ".pth" 117 | if not os.path.isfile(task_config.pre_train_model) or \ 118 | not os.path.exists(task_config.pre_train_model) or \ 119 | task_config.pre_train_model[-len(path_type):] != path_type: 120 | raise ValueError("task_config.pre_train_model error") 121 | result_data.task_logger.logger.info(f"{task_config.pre_train_model} check success") 122 | else: 123 | result_data.task_logger.logger.warn("pre_train_model path is None") 124 | 125 | # 初始化图片保存路径 126 | task_config.image_dir = os.path.join( 127 | task_config.output_dir, 128 | task_config.image_dir, 129 | result_data.task_id 130 | ) 131 | # os.makedirs(task_config.image_dir, exist_ok=True) 132 | # 数据预处理输出 133 | os.makedirs(task_config.data_output, exist_ok=True) 134 | task_config.data_preprocess_dir = os.path.join(task_config.data_output, 135 | task_config.data_preprocess_dir) 136 | # os.makedirs(task_config.data_preprocess_dir, exist_ok=True) 137 | # 初始化训练集验证集测试集 138 | task_config.train_dir = os.path.join(task_config.data_output, 139 | task_config.train_dir) 140 | # os.makedirs(task_config.train_dir, exist_ok=True) 141 | task_config.val_dir = os.path.join(task_config.data_output, 142 | task_config.val_dir) 143 | # os.makedirs(task_config.val_dir, exist_ok=True) 144 | task_config.test_dir = os.path.join(task_config.data_output, 145 | task_config.test_dir) 146 | # os.makedirs(task_config.test_dir, exist_ok=True) 147 | result_data.task_logger.logger.info(str(task_config)) 148 | result_data.task_logger.logger.info("task config init success") 149 | 150 | @staticmethod 151 | def check_preprocess_config(data_preprocess_config: DataPreprocessConfig): 152 | if data_preprocess_config.num_works <= 0: 153 | raise ValueError(f"num_works = {data_preprocess_config.num_works}, cannot <= 0") 154 | 155 | @staticmethod 156 | def check_train_model_config(train_model_config: TrainModelConfig): 157 | if train_model_config.his_step <= 0 or train_model_config.his_step >= 91: 158 | raise ValueError(f"his_step {train_model_config.his_step} is out of range") 159 | if train_model_config.use_gpu and not torch.cuda.is_available(): 160 | raise ValueError(f"cuda is unavailable") 161 | if train_model_config.use_gpu and len(train_model_config.gpu_ids) > 0 and torch.cuda.device_count() <= 1: 162 | raise ValueError("only one gpu can used") 163 | if train_model_config.schedule not in 
("cosine", "linear"): 164 | raise ValueError(f"schedule: {train_model_config.schedule}, is not in (cosine, linear)") 165 | -------------------------------------------------------------------------------- /common/waymo_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: waymo_dataset.py 6 | @Author: YangChen 7 | @Date: 2023/12/25 8 | """ 9 | 10 | import json 11 | import os 12 | import pickle 13 | from typing import Dict, Tuple, List 14 | 15 | import torch 16 | from torch.utils.data import Dataset 17 | 18 | 19 | class WaymoDataset(Dataset): 20 | 21 | def __init__( 22 | self, dataset_dir: str, his_step: int, max_pred_num: int, 23 | max_other_num: int, max_traffic_light: int, max_lane_num: int, 24 | max_point_num: int, max_data_size: int = -1 25 | ): 26 | super().__init__() 27 | self.__dataset_dir = dataset_dir 28 | self.__pkl_list = sorted(os.listdir(dataset_dir), 29 | key=lambda x: int(x[:-4].split('_')[-1])) 30 | if max_data_size != -1: 31 | self.__pkl_list = self.__pkl_list[:max_data_size] 32 | # 加载参数 33 | self.__max_pred_num = max_pred_num 34 | self.__max_other_num = max_other_num 35 | self.__max_traffic_light = max_traffic_light 36 | self.__max_lane_num = max_lane_num 37 | self.__max_point_num = max_point_num 38 | self.__his_step = his_step 39 | self.__future_step = 91 - his_step 40 | 41 | def __len__(self) -> int: 42 | return len(self.__pkl_list) 43 | 44 | def get_obs_feature( 45 | self, other_obs_index: List[int], 46 | all_obs_his_traj: torch.Tensor, 47 | all_obs_feature: torch.Tensor 48 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 49 | # other obs历史轨迹 50 | if len(other_obs_index) > 0: 51 | other_obs_traj_index = (torch.Tensor(other_obs_index) 52 | .to(torch.long).view(-1, 1, 1) 53 | .repeat((1, self.__his_step, 5))) 54 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index) 55 | other_obs_feature_index = (torch.Tensor(other_obs_index) 56 | .to(torch.long).view(-1, 1) 57 | .repeat((1, 7))) 58 | # other feature 59 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index) 60 | other_his_traj = other_his_traj[:self.__max_other_num] 61 | other_feature = other_feature[:self.__max_other_num] 62 | num_gap = self.__max_other_num - other_his_traj.shape[0] 63 | if num_gap > 0: 64 | other_his_traj = torch.cat((other_his_traj, 65 | torch.zeros(size=(num_gap, self.__his_step, 5), 66 | dtype=torch.float32)), dim=0) 67 | other_feature = torch.cat((other_feature, 68 | torch.zeros(size=(num_gap, 7), 69 | dtype=torch.float32)), dim=0) 70 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] * 71 | num_gap) 72 | else: 73 | other_his_traj = torch.zeros( 74 | size=(self.__max_other_num, self.__his_step, 5), 75 | dtype=torch.float32 76 | ) 77 | other_feature = torch.zeros(size=(self.__max_other_num, 7), dtype=torch.float32) 78 | other_traj_mask = torch.zeros(size=[self.__max_other_num], dtype=torch.float32) 79 | return other_his_traj, other_feature, other_traj_mask 80 | 81 | def get_pred_feature( 82 | self, predicted_index: torch.Tensor, 83 | all_obs_his_traj: torch.Tensor, 84 | all_obs_future_traj: torch.Tensor, 85 | all_obs_feature: torch.Tensor 86 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 87 | 88 | predicted_his_traj = torch.gather( 89 | all_obs_his_traj, 0, 90 | predicted_index.repeat((1, self.__his_step, 5)) 91 | ) 92 | predicted_future_traj = torch.gather( 93 | 
all_obs_future_traj, 0, 94 | predicted_index.repeat((1, self.__future_step, 5)) 95 | ) 96 | predicted_his_traj = predicted_his_traj[:self.__max_pred_num] 97 | predicted_future_traj = predicted_future_traj[:self.__max_pred_num] 98 | # predicted feature 99 | predicted_feature = torch.gather( 100 | all_obs_feature, 0, 101 | predicted_index.view(-1, 1).repeat((1, 7)) 102 | ) 103 | predicted_feature = predicted_feature[:self.__max_pred_num] 104 | num_gap = self.__max_pred_num - predicted_his_traj.shape[0] 105 | if num_gap > 0: 106 | predicted_his_traj = torch.cat((predicted_his_traj, 107 | torch.zeros(size=(num_gap, self.__his_step, 5), 108 | dtype=torch.float32)), dim=0) 109 | predicted_future_traj = torch.cat((predicted_future_traj, 110 | torch.zeros(size=(num_gap, self.__future_step, 5), 111 | dtype=torch.float32)), dim=0) 112 | predicted_feature = torch.cat((predicted_feature, 113 | torch.zeros(size=(num_gap, 7), 114 | dtype=torch.float32)), dim=0) 115 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj.shape[0] - num_gap) + 116 | [0.0] * num_gap) 117 | return predicted_future_traj, predicted_his_traj, predicted_feature, predicted_traj_mask 118 | 119 | def get_traffic_light( 120 | self, traffic_light: torch.Tensor, 121 | traffic_light_pos: torch.Tensor 122 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 123 | traffic_light = traffic_light[:self.__max_traffic_light] 124 | traffic_light_pos = traffic_light_pos[:self.__max_traffic_light] 125 | num_gap = self.__max_traffic_light - traffic_light.shape[0] 126 | if num_gap > 0: 127 | traffic_light = torch.cat((traffic_light, 128 | torch.zeros(size=(num_gap, self.__his_step), 129 | dtype=torch.float32)), dim=0) 130 | traffic_light_pos = torch.cat((traffic_light_pos, 131 | torch.zeros(size=(num_gap, 2), 132 | dtype=torch.float32)), dim=0) 133 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) + 134 | [0.0] * num_gap) 135 | return traffic_light, traffic_light_pos, traffic_mask 136 | 137 | def get_lane_feature(self, lane_feature: List) -> torch.Tensor: 138 | lane_list = list() 139 | for lane_index, lane in enumerate(lane_feature): 140 | if lane_index >= self.__max_lane_num: 141 | break 142 | point_list = lane[:self.__max_point_num] 143 | point_gap = self.__max_point_num - len(point_list) 144 | # MAX_POINT_NUM, 2 145 | point_list = torch.Tensor(point_list) 146 | if point_gap > 0: 147 | point_list = torch.cat((point_list, 148 | torch.zeros(size=(point_gap, 2), 149 | dtype=torch.float32)), dim=0) 150 | lane_list.append(point_list) 151 | lane_gap = self.__max_lane_num - len(lane_list) 152 | lane_list = torch.stack(lane_list, dim=0) 153 | if lane_gap > 0: 154 | lane_list = torch.cat( 155 | (lane_list, torch.zeros(size=(lane_gap, self.__max_point_num, 2), 156 | dtype=torch.float32)), dim=0 157 | ) 158 | return lane_list 159 | 160 | def __getitem__(self, idx) -> Dict: 161 | with open(os.path.join(self.__dataset_dir, self.__pkl_list[idx]), "rb") as f: 162 | result = pickle.load(f) 163 | all_obs_index = set([i for i in range(result[1].shape[0])]) 164 | predicted_index = torch.Tensor(result[0]).to(torch.long).view(-1, 1, 1) 165 | all_obs_traj = torch.Tensor(result[1]) 166 | all_obs_feature = torch.Tensor(result[2]) 167 | all_obs_his_traj = all_obs_traj[:, :self.__his_step] 168 | all_obs_future_traj = all_obs_traj[:, self.__his_step:] 169 | other_obs_index = list(all_obs_index - set(result[0])) 170 | # 获取障碍物信息 171 | other_his_traj, other_feature, other_traj_mask = self.get_obs_feature( 172 | other_obs_index, 
all_obs_his_traj, all_obs_feature 173 | ) 174 | # 需要预测的障碍物历史轨迹 175 | (predicted_future_traj, predicted_his_traj, 176 | predicted_feature, predicted_traj_mask) = self.get_pred_feature( 177 | predicted_index, all_obs_his_traj, 178 | all_obs_future_traj, all_obs_feature 179 | ) 180 | # 交通灯信息 181 | traffic_light = torch.Tensor(result[3]) 182 | traffic_light_pos = torch.Tensor(result[4]) 183 | traffic_light, traffic_light_pos, traffic_mask = self.get_traffic_light( 184 | traffic_light, traffic_light_pos 185 | ) 186 | # 车道线信息 187 | lane_list = self.get_lane_feature(result[-1]) 188 | map_json = json.dumps(result[-1]) 189 | other_his_traj_delt = other_his_traj[:, 1:] - other_his_traj[:, :-1] 190 | other_his_pos = other_his_traj[:, -1, :2] 191 | predicted_his_traj_delt = predicted_his_traj[:, 1:] - predicted_his_traj[:, :-1] 192 | predicted_his_pos = predicted_his_traj[:, -1, :2] 193 | result = { 194 | "other_his_traj": other_his_traj, 195 | "other_feature": other_feature, 196 | "other_traj_mask": other_traj_mask, 197 | "other_his_traj_delt": other_his_traj_delt, 198 | "other_his_pos": other_his_pos, 199 | "predicted_future_traj": predicted_future_traj, 200 | "predicted_his_traj": predicted_his_traj, 201 | "predicted_traj_mask": predicted_traj_mask, 202 | "predicted_feature": predicted_feature, 203 | "predicted_his_traj_delt": predicted_his_traj_delt, 204 | "predicted_his_pos": predicted_his_pos, 205 | "traffic_light": traffic_light, 206 | "traffic_light_pos": traffic_light_pos, 207 | "traffic_mask": traffic_mask, 208 | "lane_list": lane_list, 209 | "map_json": map_json 210 | } 211 | return result 212 | -------------------------------------------------------------------------------- /net_works/diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: diffusion.py 6 | @Author: YangChen 7 | @Date: 2023/12/27 8 | """ 9 | from functools import partial 10 | from typing import List 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | 17 | from net_works.transformer import TransformerCrossAttention 18 | 19 | 20 | def extract(a, t, x_shape): 21 | t = t.to(torch.long) 22 | b, *_ = t.shape 23 | out = a.gather(-1, t) 24 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 25 | 26 | 27 | class LinearLayer(nn.Module): 28 | def __init__(self, input_dim: int, output_dim: int): 29 | super().__init__() 30 | self.layer = nn.Sequential( 31 | nn.Linear(input_dim, output_dim, bias=True, dtype=torch.float32), 32 | nn.LeakyReLU(inplace=True) 33 | ) 34 | self.normal = nn.BatchNorm1d(output_dim) 35 | 36 | def forward(self, x): 37 | linear_output = self.layer(x) 38 | linear_output = torch.transpose(linear_output, -1, -2) 39 | normal_output = self.normal(linear_output) 40 | return torch.transpose(normal_output, -1, -2) 41 | 42 | 43 | class Decoder(nn.Module): 44 | def __init__(self, input_dim: int, middle_dim: int, output_dim: int): 45 | super(Decoder, self).__init__() 46 | self.up = LinearLayer(input_dim, output_dim) 47 | self.cat_layer = LinearLayer(middle_dim, output_dim) 48 | 49 | def forward(self, x1, x2): 50 | up_output = self.up(x1) 51 | cat_output = torch.cat((up_output, x2), dim=-1) 52 | output = self.cat_layer(cat_output) 53 | return output 54 | 55 | 56 | class UnetDiffusionModel(nn.Module): 57 | def __init__( 58 | self, dims: List[int] = None, input_dim: int = 3, 59 | conditional_dim: int = 5, his_stp: int = 11 60 | ): 61 | 
super(UnetDiffusionModel, self).__init__() 62 | if dims is None: 63 | self.__dims = [64, 128, 256, 512] 64 | else: 65 | self.__dims = dims 66 | self.his_delt_step = his_stp - 1 67 | self.input_dim = input_dim 68 | input_tensor_dim = self.his_delt_step * input_dim + conditional_dim * his_stp + 1 69 | self.layer1 = LinearLayer(input_tensor_dim, self.__dims[0]) 70 | self.layer2 = LinearLayer(self.__dims[0], self.__dims[1]) 71 | self.layer3 = LinearLayer(self.__dims[1], self.__dims[2]) 72 | self.layer4 = LinearLayer(self.__dims[2], self.__dims[3]) 73 | self.layer5 = LinearLayer(self.__dims[3], 64) 74 | self.decode4 = Decoder(64, 768, self.__dims[2]) 75 | self.decode3 = Decoder(self.__dims[2], 384, self.__dims[1]) 76 | self.decode2 = Decoder(self.__dims[1], 192, self.__dims[0]) 77 | self.output_layer = nn.Sequential( 78 | nn.Linear( 79 | self.__dims[0], self.his_delt_step * input_dim, 80 | bias=True, dtype=torch.float32 81 | ), 82 | nn.Tanh() 83 | ) 84 | 85 | def forward(self, perturbed_x, t, predicted_his_traj): 86 | # batch, obs_num, 10, 5 87 | batch_size = perturbed_x.shape[0] 88 | obs_num = perturbed_x.shape[1] 89 | input_tensor = torch.flatten(perturbed_x, start_dim=2) 90 | his_traj = torch.flatten(predicted_his_traj, start_dim=2) 91 | t = t.view(-1, 1, 1).repeat((1, obs_num, 1)) 92 | # batch, obs_num, 50 + 50 + 1 93 | input_tensor = torch.cat([input_tensor, his_traj, t], dim=-1).to(torch.float32) 94 | # batch, 64 95 | e1 = self.layer1(input_tensor) 96 | e2 = self.layer2(e1) 97 | e3 = self.layer3(e2) 98 | e4 = self.layer4(e3) 99 | f = self.layer5(e4) 100 | d4 = self.decode4(f, e4) 101 | d3 = self.decode3(d4, e3) 102 | d2 = self.decode2(d3, e2) 103 | out = self.output_layer(d2) 104 | return out.view(batch_size, obs_num, self.his_delt_step, self.input_dim) 105 | 106 | 107 | class DitDiffusionModel(nn.Module): 108 | def __init__( 109 | self, input_dim: int = 3, conditional_dim: int = 5, 110 | his_stp: int = 11, num_dit_blocks: int = 4 111 | ): 112 | super(DitDiffusionModel, self).__init__() 113 | self.his_delt_step = his_stp - 1 114 | self.input_dim = (self.his_delt_step * input_dim) + 1 115 | self.conditional_dim = his_stp * conditional_dim 116 | self.dit_blocks = nn.ModuleList([]) 117 | for _ in range(num_dit_blocks): 118 | self.dit_blocks.append(TransformerCrossAttention(self.input_dim, self.conditional_dim)) 119 | self.linear_output = nn.Sequential( 120 | nn.Linear(self.input_dim, self.input_dim * 2), 121 | nn.LayerNorm(self.input_dim * 2), 122 | nn.ReLU(inplace=True), 123 | nn.Linear(self.input_dim * 2, self.input_dim), 124 | nn.LayerNorm(self.input_dim), 125 | nn.ReLU(inplace=True), 126 | nn.Linear(self.input_dim, self.his_delt_step * input_dim), 127 | nn.Tanh() 128 | ) 129 | 130 | def forward(self, perturbed_x, t, predicted_his_traj): 131 | batch_size = perturbed_x.shape[0] 132 | obs_num = perturbed_x.shape[1] 133 | # batch, pred_obs, his_stp * 3 134 | input_tensor = torch.flatten(perturbed_x, start_dim=2) 135 | # batch, 15, 1 136 | t = t.view(-1, 1, 1).repeat(1, obs_num, 1) 137 | # batch, pred_obs, his_stp * 5 138 | predicted_his_traj = torch.flatten(predicted_his_traj, start_dim=2) 139 | # batch, pred_obs, his_stp * 3 + 1 140 | input_tensor = torch.cat([input_tensor, t], dim=-1) 141 | for dit_block in self.dit_blocks: 142 | input_tensor = dit_block(input_tensor, predicted_his_traj) 143 | noise_output = self.linear_output(input_tensor) 144 | return noise_output.view(batch_size, obs_num, self.his_delt_step, -1) 145 | 146 | 147 | class GaussianDiffusion(nn.Module): 148 | def __init__( 
149 | self, input_dim: int = 5, conditional_dim: int = 5, 150 | his_stp: int = 11, betas: np.ndarray = None, 151 | loss_type: str = "l2", num_dit_blocks: int = 4, 152 | diffusion_type: str = "none" 153 | ): 154 | super(GaussianDiffusion, self).__init__() 155 | if betas is None: 156 | betas = [] 157 | # l1或者l2损失 158 | if loss_type not in ["l1", "l2"]: 159 | raise ValueError(f"get unknown loss type: {loss_type}") 160 | if diffusion_type not in ["dit", "unet", "none"]: 161 | raise ValueError(f"get unknown diffusion type: {diffusion_type}") 162 | 163 | self.loss_type = loss_type 164 | self.diffusion_type = diffusion_type 165 | self.num_time_steps = len(betas) 166 | 167 | alphas = 1.0 - betas 168 | alphas_cum_prod = np.cumprod(alphas) 169 | # 转换成torch.tensor来处理 170 | to_torch = partial(torch.tensor, dtype=torch.float32) 171 | 172 | # betas [0.0001, 0.00011992, 0.00013984 ... , 0.02] 173 | self.register_buffer("betas", to_torch(betas)) 174 | # alphas [0.9999, 0.99988008, 0.99986016 ... , 0.98] 175 | self.register_buffer("alphas", to_torch(alphas)) 176 | # alphas_cum_prod [9.99900000e-01, 9.99780092e-01, 9.99640283e-01 ... , 4.03582977e-05] 177 | self.register_buffer("alphas_cum_prod", to_torch(alphas_cum_prod)) 178 | # sqrt(alphas_cum_prod) 179 | self.register_buffer("sqrt_alphas_cum_prod", to_torch(np.sqrt(alphas_cum_prod))) 180 | # sqrt(1 - alphas_cum_prod) 181 | self.register_buffer("sqrt_one_minus_alphas_cum_prod", to_torch(np.sqrt(1 - alphas_cum_prod))) 182 | # sqrt(1 / alphas) 183 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas))) 184 | self.register_buffer("sigma", to_torch(np.sqrt(betas))) 185 | alphas_cum_prod_prev = np.append(1, alphas_cum_prod[:-1]) 186 | # self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cum_prod))) 187 | self.register_buffer("remove_noise_coeff", 188 | to_torch(betas * (1 - alphas_cum_prod_prev / np.sqrt(1 - alphas_cum_prod)))) 189 | # 初始化model 190 | self.his_delt_step = his_stp - 1 191 | if self.diffusion_type == "dit": 192 | self.diffusion_model = DitDiffusionModel( 193 | input_dim=input_dim, 194 | conditional_dim=conditional_dim, 195 | num_dit_blocks=num_dit_blocks 196 | ) 197 | elif self.diffusion_type == "unet": 198 | self.diffusion_model = UnetDiffusionModel( 199 | input_dim=input_dim, 200 | conditional_dim=conditional_dim, 201 | ) 202 | else: 203 | self.diffusion_model = None 204 | self.input_dim = input_dim 205 | self.norm_output = nn.BatchNorm2d(5) 206 | 207 | def remove_noise(self, noise, t_batch, predicted_his_traj): 208 | model_output = self.diffusion_model(noise, t_batch, predicted_his_traj) 209 | return ( 210 | (noise - extract(self.remove_noise_coeff, t_batch, noise.shape) * model_output) * 211 | extract(self.reciprocal_sqrt_alphas, t_batch, noise.shape) 212 | ) 213 | 214 | def sample(self, noise, predicted_his_traj): 215 | if self.diffusion_type != "none": 216 | batch_size = predicted_his_traj.shape[0] 217 | device = predicted_his_traj.device 218 | for t in range(self.num_time_steps - 1, -1, -1): 219 | t_batch = torch.tensor([t], device=device).repeat(batch_size) 220 | noise = self.remove_noise(noise, t_batch, predicted_his_traj) 221 | if t > 0: 222 | noise += extract(self.sigma, t_batch, noise.shape) * torch.randn_like(noise) 223 | noise = torch.transpose(noise, 1, -1) 224 | noise = self.norm_output(noise) 225 | noise = torch.transpose(noise, 1, -1) 226 | return noise 227 | 228 | def perturb_x(self, future_traj, t, noise): 229 | return ( 230 | extract(self.sqrt_alphas_cum_prod, t, 
future_traj.shape) * future_traj + 231 | extract(self.sqrt_one_minus_alphas_cum_prod, t, future_traj.shape) * noise 232 | ) 233 | 234 | def get_losses(self, predicted_his_traj_delt, predicted_his_traj, predicted_traj_mask, t): 235 | if self.diffusion_type != "none": 236 | noise = torch.randn_like(predicted_his_traj_delt) 237 | perturbed_x = self.perturb_x(predicted_his_traj_delt, t, noise) 238 | estimated_noise = self.diffusion_model(perturbed_x, t, predicted_his_traj) 239 | batch_size, obs_num = perturbed_x.shape[0], perturbed_x.shape[1] 240 | # diffusion_loss 241 | diffusion_loss = F.mse_loss(estimated_noise, noise, reduction="none") 242 | diffusion_loss_mask = (predicted_traj_mask.view(batch_size, obs_num, 1, 1) 243 | .repeat(1, 1, self.his_delt_step, self.input_dim)) 244 | diffusion_loss = torch.sum(diffusion_loss * diffusion_loss_mask) / torch.sum(diffusion_loss_mask) 245 | return diffusion_loss 246 | else: 247 | return torch.tensor(0).to(torch.float32) 248 | 249 | def forward(self, data: dict): 250 | # batch, pred_obs(8), his_step, 3 251 | # batch, pred_obs(8), his_step, 5 252 | predicted_his_traj_delt = data['predicted_his_traj_delt'] 253 | predicted_his_traj = data['predicted_his_traj'] 254 | predicted_traj_mask = data['predicted_traj_mask'] 255 | batch_size = predicted_his_traj_delt.shape[0] 256 | device = predicted_his_traj.device 257 | t = torch.randint(0, self.num_time_steps, (batch_size,), device=device) 258 | return self.get_losses(predicted_his_traj_delt, predicted_his_traj, predicted_traj_mask, t) 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tasks/show_result_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: show_result_task.py 6 | @Author: YangChen 7 | @Date: 2024/1/6 8 | """ 9 | import os.path 10 | import shutil 11 | from typing import Any, Dict 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import tensorflow as tf 16 | import torch 17 | from matplotlib import animation 18 | from matplotlib import patches 19 | from matplotlib.collections import LineCollection 20 | from matplotlib.colors import LinearSegmentedColormap 21 | from waymo_open_dataset.protos import scenario_pb2 22 | from waymo_open_dataset.protos.scenario_pb2 import Scenario 23 | from waymo_open_dataset.utils.sim_agents import visualizations, submission_specs 24 | 25 | from common import TaskType, LoadConfigResultDate 26 | from net_works import BackBone 27 | from tasks import BaseTask 28 | from utils import DataUtil, MathUtil, MapUtil 29 | 30 | RESULT_DIR = r"/home/haomo/yangchen/Scene-Diffusion/output/image" 31 | DATA_SET_PATH = r"/home/haomo/yangchen/Scene-Diffusion/data_set/train_set/training.tfrecord-00000-of-01000" 32 | MODEL_PATH = r"/home/haomo/yangchen/Scene-Diffusion/model_epoch152.pth" 33 | 34 | 35 | class ShowResultsTask(BaseTask): 36 | 37 | def __init__(self): 38 | super(ShowResultsTask, self).__init__() 39 | self.task_type = TaskType.SHOW_RESULTS 40 | self.cmap = LinearSegmentedColormap.from_list( 41 | 'my_cmap', 42 | [np.array([0., 232., 157.]) / 255, np.array([0., 120., 255.]) / 255], 43 | 100 44 | ) 45 | self.color_dict = { 46 | 0: np.array([0., 120., 255.]) / 255, 47 | 1: np.array([0., 232., 157.]) / 255, 48 | 2: np.array([255., 205., 85.]) / 255, 49 | 3: np.array([244., 175., 145.]) / 255, 50 | 4: np.array([145., 80., 200.]) / 255, 51 | 5: np.array([0., 51., 102.]) / 255, 52 | 6: np.array([1, 0, 0]), 53 | 7: np.array([0, 1, 0]), 54 | } 55 | 56 | def execute(self, result_info: LoadConfigResultDate): 57 | if os.path.exists(RESULT_DIR): 58 | shutil.rmtree(RESULT_DIR) 59 | os.makedirs(RESULT_DIR, exist_ok=True) 60 | self.show_result(result_info) 61 | 62 | @staticmethod 63 | def load_pretrain_model(result_info: LoadConfigResultDate) -> BackBone: 64 | betas = MathUtil.generate_linear_schedule(result_info.train_model_config.time_steps) 65 | model = BackBone(betas).eval() 66 | device = torch.device("cpu") 67 | pretrained_dict = torch.load(MODEL_PATH, map_location=device) 68 | model_dict = model.state_dict() 69 | # 模型参数赋值 70 | new_model_dict = dict() 71 | for key in model_dict.keys(): 72 | if ("module." + key) in pretrained_dict: 73 | new_model_dict[key] = pretrained_dict["module." 
+ key] 74 | elif key in pretrained_dict: 75 | new_model_dict[key] = pretrained_dict[key] 76 | else: 77 | print("key: ", key, ", not in pretrained") 78 | model.load_state_dict(new_model_dict) 79 | print("load_pretrain_model success") 80 | return model 81 | 82 | def show_result(self, result_info: LoadConfigResultDate): 83 | model = self.load_pretrain_model(result_info) 84 | match_filenames = tf.io.matching_files([DATA_SET_PATH]) 85 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data").take(100) 86 | dataset_iterator = dataset.as_numpy_iterator() 87 | for index, scenario_bytes in enumerate(dataset_iterator): 88 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes) 89 | data_dict = DataUtil.transform_data_to_input(scenario, result_info) 90 | for key, value in data_dict.items(): 91 | if isinstance(value, torch.Tensor): 92 | data_dict[key] = value.to(torch.float32).unsqueeze(dim=0) 93 | 94 | predict_traj = model(data_dict)[-1] 95 | predicted_traj_mask = data_dict['predicted_traj_mask'][0] 96 | predicted_future_traj = data_dict['predicted_future_traj'][0] 97 | predicted_his_traj = data_dict['predicted_his_traj'][0] 98 | predicted_num = 0 99 | for i in range(predicted_traj_mask.shape[0]): 100 | if int(predicted_traj_mask[i]) == 1: 101 | predicted_num += 1 102 | generate_traj = predict_traj[:predicted_num] 103 | predicted_future_traj = predicted_future_traj[:predicted_num] 104 | predicted_his_traj = predicted_his_traj[:predicted_num] 105 | real_traj = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, :2].detach().numpy() 106 | real_yaw = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, 2].detach().numpy() 107 | model_output = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, :2].detach().numpy() 108 | model_yaw = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, 2].detach().numpy() 109 | # 可视化输入 110 | image_path = os.path.join(RESULT_DIR, f"{index}_input.png") 111 | self.draw_input(scenario, image_path) 112 | # 可视化ground truth 113 | image_path = os.path.join(RESULT_DIR, f"{index}_ground_truth.png") 114 | self.draw_scene(predicted_num, real_traj, data_dict, scenario, image_path) 115 | # 可视化model output 116 | image_path = os.path.join(RESULT_DIR, f"{index}_model_output.png") 117 | self.draw_scene(predicted_num, model_output, data_dict, scenario, image_path) 118 | # 可视化动态图 119 | image_path = os.path.join(RESULT_DIR, f"{index}_ground_truth.gif") 120 | self.draw_gif(predicted_num, real_traj, real_yaw, data_dict, scenario, image_path) 121 | image_path = os.path.join(RESULT_DIR, f"{index}_model_output.gif") 122 | self.draw_gif(predicted_num, model_output, model_yaw, data_dict, scenario, image_path) 123 | 124 | fig, axis = plt.subplots(1, 1, figsize=(10, 11)) 125 | num = np.array([i for i in range(91)]) 126 | segments = np.stack((num, num), axis=-1)[np.newaxis, :] 127 | image_path = os.path.join(RESULT_DIR, f"color_bar.png") 128 | line_segments = LineCollection(segments=segments, linewidths=1, 129 | linestyles='solid', cmap=self.cmap) 130 | cbar = fig.colorbar(line_segments, cmap=self.cmap, orientation='horizontal') 131 | cbar.set_ticks(np.linspace(0, 1, 10)) 132 | cbar.set_ticklabels([str(i) for i in range(0, 91, 10)]) 133 | plt.savefig(image_path) 134 | plt.close('all') # 避免内存泄漏 135 | 136 | @staticmethod 137 | def draw_input(scenario: Scenario, image_path: str): 138 | fig, axis = plt.subplots(1, 1, figsize=(10, 10)) 139 | visualizations.add_map(axis, scenario) 140 | predicted_obs_ids = 
submission_specs.get_evaluation_sim_agent_ids(scenario) 141 | current_time_index = scenario.current_time_index 142 | for track in scenario.tracks: 143 | if track.id not in predicted_obs_ids: 144 | continue 145 | param_dict = { 146 | "x": track.states[current_time_index].center_x, 147 | "y": track.states[current_time_index].center_y, 148 | "bbox_yaw": track.states[current_time_index].heading, 149 | "length": track.states[current_time_index].length, 150 | "width": track.states[current_time_index].width, 151 | } 152 | rect = visualizations.get_bbox_patch(**param_dict) 153 | axis.add_patch(rect) 154 | plt.savefig(image_path) 155 | plt.close('all') # 避免内存泄漏 156 | 157 | def draw_scene( 158 | self, predicted_num: int, traj: np.ndarray, 159 | data_dict: Dict[str, Any], scenario: Scenario, image_path: str 160 | ): 161 | fig, axis = plt.subplots(1, 1, figsize=(10, 10)) 162 | visualizations.add_map(axis, scenario) 163 | # visualizations.get_bbox_patch() 164 | # axis.axis('equal') # 横纵坐标比例相等 165 | curr_x, curr_y, curr_heading, _ = data_dict['curr_loc'] 166 | for i in range(predicted_num): 167 | real_traj_x, real_traj_y = MapUtil.local_to_global(curr_heading, traj[i, :, 0], 168 | traj[i, :, 1], curr_x, curr_y) 169 | num = np.linspace(0, 1, len(real_traj_x)) 170 | for j in range(2, len(real_traj_x)): 171 | axis.plot( 172 | real_traj_x[j - 2:j], 173 | real_traj_y[j - 2:j], 174 | linewidth=5, 175 | color=self.cmap(num[j]), 176 | ) 177 | axis.set_xticks([]) 178 | axis.set_yticks([]) 179 | # plt.show() 180 | plt.savefig(image_path) 181 | plt.close('all') # 避免内存泄漏 182 | print(f"{image_path} save success") 183 | 184 | def draw_gif( 185 | self, predicted_num: int, traj: np.ndarray, real_yaw: np.ndarray, 186 | data_dict: Dict[str, Any], scenario: Scenario, image_path: str 187 | ): 188 | fig, axis = plt.subplots(1, 1, figsize=(10, 10)) 189 | visualizations.add_map(axis, scenario) 190 | # visualizations.get_bbox_patch() 191 | # axis.axis('equal') # 横纵坐标比例相等 192 | curr_x, curr_y, curr_heading, _ = data_dict['curr_loc'] 193 | x_list = list() 194 | y_list = list() 195 | yaw_list = list() 196 | for i in range(predicted_num): 197 | real_traj_x, real_traj_y = MapUtil.local_to_global(curr_heading, traj[i, :, 0], 198 | traj[i, :, 1], curr_x, curr_y) 199 | real_traj_yaw = MapUtil.theta_local_to_global(curr_heading, real_yaw[i]) 200 | x_list.append(real_traj_x) 201 | y_list.append(real_traj_y) 202 | yaw_list.append(real_traj_yaw) 203 | # [num, step] 204 | x_list = np.stack(x_list, axis=0) 205 | y_list = np.stack(y_list, axis=0) 206 | yaw_list = np.stack(yaw_list, axis=0) 207 | predicted_feature = data_dict['predicted_feature'].squeeze()[:, :2] 208 | 209 | def animate(t: int) -> list[patches.Rectangle]: 210 | # At each animation step, we need to remove the existing patches. This can 211 | # only be done using the `pop()` operation. 
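            # (Note: depending on the Matplotlib version, Axes.patches may no longer
            #  support pop(); calling patch.remove() on each existing patch is an
            #  equivalent alternative if this loop raises.)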
212 | for _ in range(len(axis.patches)): 213 | axis.patches.pop() 214 | bboxes = list() 215 | for j in range(x_list.shape[0]): 216 | bboxes.append(axis.add_patch( 217 | self.get_bbox_patch( 218 | x_list[:, t][j], y_list[:, t][j], yaw_list[:, t][j], 219 | predicted_feature[j, 1], predicted_feature[j, 0], self.color_dict[j] 220 | ) 221 | )) 222 | return bboxes 223 | 224 | animations = animation.FuncAnimation( 225 | fig, animate, frames=x_list.shape[1], interval=100, 226 | blit=True) 227 | axis.set_xticks([]) 228 | axis.set_yticks([]) 229 | # plt.show() 230 | animations.save(image_path, writer='ffmpeg', fps=30) 231 | plt.close('all') # 避免内存泄漏 232 | print(f"{image_path} save success") 233 | 234 | @staticmethod 235 | def get_bbox_patch( 236 | x: float, y: float, bbox_yaw: float, length: float, width: float, 237 | color: np.ndarray 238 | ) -> patches.Rectangle: 239 | left_rear_object = np.array([-length / 2, -width / 2]) 240 | 241 | rotation_matrix = np.array([[np.cos(bbox_yaw), -np.sin(bbox_yaw)], 242 | [np.sin(bbox_yaw), np.cos(bbox_yaw)]]) 243 | left_rear_rotated = rotation_matrix.dot(left_rear_object) 244 | left_rear_global = np.array([x, y]) + left_rear_rotated 245 | color = list(color) + [0.5] 246 | rect = patches.Rectangle( 247 | left_rear_global, length, width, angle=np.rad2deg(bbox_yaw), color=color) 248 | return rect 249 | 250 | 251 | if __name__ == "__main__": 252 | # show_result() 253 | pass 254 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | @Project: WcDT 5 | @Name: data_utils.py 6 | @Author: YangChen 7 | @Date: 2023/12/24 8 | """ 9 | import json 10 | import math 11 | from typing import List, Dict, Any, Tuple 12 | 13 | import numpy as np 14 | import torch 15 | from waymo_open_dataset.protos import scenario_pb2 16 | from waymo_open_dataset.utils.sim_agents import submission_specs 17 | 18 | from common import ObjectType, LoadConfigResultDate 19 | from utils.map_utils import MapUtil 20 | 21 | 22 | class DataUtil: 23 | 24 | @classmethod 25 | def load_scenario_data(cls, scenario: scenario_pb2.Scenario) -> Dict[str, Any]: 26 | result_dict = dict() 27 | # 获取自车当前位置把所有轨迹和车道线转化成自车坐标系 28 | sdc_track_index = scenario.sdc_track_index 29 | current_time_index = scenario.current_time_index 30 | curr_state = scenario.tracks[scenario.sdc_track_index].states[current_time_index] 31 | ego_curr_x, ego_curr_y = curr_state.center_x, curr_state.center_y 32 | ego_curr_heading = curr_state.heading 33 | # 需要预测的障碍物id 34 | predicted_obs_ids = submission_specs.get_evaluation_sim_agent_ids(scenario) 35 | obs_tracks = cls.load_obs_tracks(scenario, ego_curr_x, ego_curr_y, ego_curr_heading) 36 | map_features = cls.load_map_features(scenario, ego_curr_x, ego_curr_y, ego_curr_heading) 37 | traffic_lights = cls.load_traffic_light(scenario, ego_curr_x, ego_curr_y, ego_curr_heading) 38 | if len(obs_tracks) <= 1: 39 | return result_dict 40 | result_dict['predicted_obs_ids'] = predicted_obs_ids 41 | result_dict['obs_tracks'] = obs_tracks 42 | result_dict['map_features'] = map_features 43 | result_dict['dynamic_states'] = traffic_lights 44 | result_dict['curr_loc'] = (ego_curr_x, ego_curr_y, ego_curr_heading, sdc_track_index) 45 | return result_dict 46 | 47 | @classmethod 48 | def load_obs_tracks(cls, scenario: scenario_pb2.Scenario, 49 | ego_curr_x: float, ego_curr_y: float, 50 | ego_curr_heading: float) -> List[Dict[str, 
Any]]: 51 | obs_tracks = list() 52 | # one obstacle's state (id, type, trajectory) 53 | for track in scenario.tracks: 54 | one_obs_track = dict() 55 | obs_id = track.id 56 | if hasattr(track, "object_type"): 57 | object_type = track.object_type 58 | else: 59 | object_type = 4 60 | obs_traj = list() 61 | for state in track.states: 62 | if not state.valid: 63 | continue 64 | if 'height' not in one_obs_track and 'length' not in one_obs_track \ 65 | and 'width' not in one_obs_track: 66 | one_obs_track['height'] = state.height 67 | one_obs_track['length'] = state.length 68 | one_obs_track['width'] = state.width 69 | center_x, center_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 70 | state.center_x, state.center_y) 71 | center_heading = MapUtil.theta_global_to_local(ego_curr_heading, state.heading) 72 | # there are two ways to transform the velocity: 73 | # 1. project it onto the new coordinate frame 74 | # 2. compute the velocity heading in the new frame and multiply the speed by cos and sin 75 | curr_v = math.sqrt(math.pow(state.velocity_x, 2) + 76 | math.pow(state.velocity_y, 2)) 77 | curr_v_heading = math.atan2(state.velocity_y, state.velocity_x) 78 | curr_v_heading = MapUtil.theta_global_to_local(ego_curr_heading, curr_v_heading) 79 | obs_traj.append( 80 | (center_x, center_y, center_heading, 81 | curr_v * math.cos(curr_v_heading), curr_v * math.sin(curr_v_heading)) 82 | ) 83 | one_obs_track['obs_id'] = obs_id 84 | one_obs_track['object_type'] = object_type 85 | one_obs_track['obs_traj'] = obs_traj 86 | # obstacles with missing trajectory points are not needed 87 | if len(obs_traj) == 91: 88 | obs_tracks.append(one_obs_track) 89 | return obs_tracks 90 | 91 | @staticmethod 92 | def load_map_features(scenario: scenario_pb2.Scenario, 93 | ego_curr_x: float, ego_curr_y: float, 94 | ego_curr_heading: float) -> List[Dict[str, Any]]: 95 | map_features = list() 96 | for map_feature in scenario.map_features: 97 | one_map_dict = dict() 98 | map_id = map_feature.id 99 | polygon_points = list() 100 | if hasattr(map_feature, "road_edge") and map_feature.road_edge.polyline: 101 | map_type = "road_edge" 102 | polygon_list = map_feature.road_edge.polyline 103 | elif hasattr(map_feature, "road_line") and map_feature.road_line.polyline: 104 | map_type = "road_line" 105 | polygon_list = map_feature.road_line.polyline 106 | else: 107 | continue 108 | for polygon_point in polygon_list: 109 | polygon_point_x, polygon_point_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 110 | polygon_point.x, polygon_point.y) 111 | polygon_points.append((polygon_point_x, polygon_point_y)) 112 | one_map_dict['map_id'] = map_id 113 | one_map_dict['map_type'] = map_type 114 | one_map_dict['polygon_points'] = polygon_points 115 | map_features.append(one_map_dict) 116 | return map_features 117 | 118 | @staticmethod 119 | def load_traffic_light(scenario: scenario_pb2.Scenario, 120 | ego_curr_x: float, ego_curr_y: float, 121 | ego_curr_heading: float) -> Dict[str, Any]: 122 | dynamic_states = dict() 123 | for dynamic_state in scenario.dynamic_map_states: 124 | for lane_state in dynamic_state.lane_states: 125 | lane_id = lane_state.lane 126 | lane_x, lane_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 127 | lane_state.stop_point.x, 128 | lane_state.stop_point.y) 129 | if lane_id not in dynamic_states: 130 | dynamic_states[lane_id] = list() 131 | dynamic_states[lane_id].append((lane_x, lane_y)) 132 | state = lane_state.state 133 | dynamic_states[lane_id].append(state) 134 | return dynamic_states 135 | 136 | @staticmethod 137 | def split_pkl_data(one_pkl_dict: Dict[str, Any], his_step: int) -> Tuple: 138 | map_points = [feature['polygon_points'] for feature in
one_pkl_dict['map_features']] 139 | predicted_obs_ids = one_pkl_dict['predicted_obs_ids'] 140 | # initialize the containers that will be saved 141 | index = 0 142 | traj_list = list() 143 | obs_feature_list = list() 144 | predicted_obs_index = list() 145 | # obstacle information 146 | for one_obs_info in one_pkl_dict['obs_tracks']: 147 | if len(one_obs_info['obs_traj']) < 91 or None in one_obs_info['obs_traj']: 148 | continue 149 | # obstacle size and type 150 | obs_feature = list() 151 | obs_feature.append(one_obs_info['width']) 152 | obs_feature.append(one_obs_info['length']) 153 | type_onehot = [0] * len(ObjectType) 154 | type_onehot[one_obs_info['object_type']] = 1 155 | obs_feature += type_onehot 156 | obs_feature_list.append(obs_feature) 157 | # record the index of each obstacle that needs to be predicted 158 | if one_obs_info['obs_id'] in predicted_obs_ids: 159 | predicted_obs_index.append(index) 160 | traj = np.array(one_obs_info['obs_traj']) 161 | traj_list.append(traj) 162 | index += 1 163 | if len(predicted_obs_index) < 1: 164 | return tuple() 165 | # dynamic map (traffic light) information 166 | dynamic_states = list() 167 | dynamic_pos = list() 168 | for key, value in one_pkl_dict['dynamic_states'].items(): 169 | dynamic_pos.append(value[0]) 170 | dynamic_state = value[1:] 171 | dynamic_state = dynamic_state[:his_step] 172 | if len(dynamic_state) < his_step: 173 | dynamic_state = dynamic_state + ([0] * (his_step - len(dynamic_state))) 174 | dynamic_states.append(dynamic_state) 175 | traj_arr = np.stack(traj_list, axis=0) 176 | obs_feature_list = np.stack(obs_feature_list, axis=0) 177 | dynamic_states = np.array(dynamic_states) 178 | dynamic_pos = np.array(dynamic_pos) 179 | one_pkl_data = (predicted_obs_index, traj_arr, 180 | obs_feature_list, dynamic_states, 181 | dynamic_pos, map_points) 182 | return one_pkl_data 183 | 184 | @staticmethod 185 | def get_obs_feature( 186 | config_data: LoadConfigResultDate, 187 | other_obs_index: List[int], 188 | all_obs_his_traj: torch.Tensor, 189 | all_obs_feature: torch.Tensor 190 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 191 | his_step = config_data.train_model_config.his_step 192 | max_other_num = config_data.train_model_config.max_other_num 193 | # history trajectories of the other (non-predicted) obstacles 194 | if len(other_obs_index) > 0: 195 | other_obs_traj_index = (torch.Tensor(other_obs_index) 196 | .to(torch.long).view(-1, 1, 1) 197 | .repeat((1, his_step, 5))) 198 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index) 199 | other_obs_feature_index = (torch.Tensor(other_obs_index) 200 | .to(torch.long).view(-1, 1) 201 | .repeat((1, 7))) 202 | # other feature 203 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index) 204 | other_his_traj = other_his_traj[:max_other_num] 205 | other_feature = other_feature[:max_other_num] 206 | num_gap = max_other_num - other_his_traj.shape[0] 207 | if num_gap > 0: 208 | other_his_traj = torch.cat((other_his_traj, 209 | torch.zeros(size=(num_gap, his_step, 5), 210 | dtype=torch.float32)), dim=0) 211 | other_feature = torch.cat((other_feature, 212 | torch.zeros(size=(num_gap, 7), 213 | dtype=torch.float32)), dim=0) 214 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] * 215 | num_gap) 216 | else: 217 | other_his_traj = torch.zeros( 218 | size=(max_other_num, his_step, 5), 219 | dtype=torch.float32 220 | ) 221 | other_feature = torch.zeros(size=(max_other_num, 7), dtype=torch.float32) 222 | other_traj_mask = torch.zeros(size=[max_other_num], dtype=torch.float32) 223 | return other_his_traj, other_feature, other_traj_mask 224 | 225 | @staticmethod 226 | def get_pred_feature( 227 |
config_data: LoadConfigResultDate, 228 | predicted_index: torch.Tensor, 229 | all_obs_his_traj: torch.Tensor, 230 | all_obs_future_traj: torch.Tensor, 231 | all_obs_feature: torch.Tensor 232 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 233 | his_step = config_data.train_model_config.his_step 234 | future_step = 91 - his_step 235 | max_pred_num = config_data.train_model_config.max_pred_num 236 | predicted_his_traj = torch.gather( 237 | all_obs_his_traj, 0, 238 | predicted_index.repeat((1, his_step, 5)) 239 | ) 240 | predicted_future_traj = torch.gather( 241 | all_obs_future_traj, 0, 242 | predicted_index.repeat((1, future_step, 5)) 243 | ) 244 | predicted_his_traj = predicted_his_traj[:max_pred_num] 245 | predicted_future_traj = predicted_future_traj[:max_pred_num] 246 | # predicted feature 247 | predicted_feature = torch.gather( 248 | all_obs_feature, 0, 249 | predicted_index.view(-1, 1).repeat((1, 7)) 250 | ) 251 | predicted_feature = predicted_feature[:max_pred_num] 252 | num_gap = max_pred_num - predicted_his_traj.shape[0] 253 | if num_gap > 0: 254 | predicted_his_traj = torch.cat((predicted_his_traj, 255 | torch.zeros(size=(num_gap, his_step, 5), 256 | dtype=torch.float32)), dim=0) 257 | predicted_future_traj = torch.cat((predicted_future_traj, 258 | torch.zeros(size=(num_gap, future_step, 5), 259 | dtype=torch.float32)), dim=0) 260 | predicted_feature = torch.cat((predicted_feature, 261 | torch.zeros(size=(num_gap, 7), 262 | dtype=torch.float32)), dim=0) 263 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj.shape[0] - num_gap) + 264 | [0.0] * num_gap) 265 | return predicted_future_traj, predicted_his_traj, predicted_feature, predicted_traj_mask 266 | 267 | @staticmethod 268 | def get_traffic_light( 269 | config_data: LoadConfigResultDate, 270 | traffic_light: torch.Tensor, 271 | traffic_light_pos: torch.Tensor 272 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 273 | his_step = config_data.train_model_config.his_step 274 | max_traffic_light = config_data.train_model_config.max_traffic_light 275 | traffic_light = traffic_light[:max_traffic_light] 276 | traffic_light_pos = traffic_light_pos[:max_traffic_light] 277 | num_gap = max_traffic_light - traffic_light.shape[0] 278 | if num_gap > 0: 279 | traffic_light = torch.cat((traffic_light, 280 | torch.zeros(size=(num_gap, his_step), 281 | dtype=torch.float32)), dim=0) 282 | traffic_light_pos = torch.cat((traffic_light_pos, 283 | torch.zeros(size=(num_gap, 2), 284 | dtype=torch.float32)), dim=0) 285 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) + 286 | [0.0] * num_gap) 287 | return traffic_light, traffic_light_pos, traffic_mask 288 | 289 | @staticmethod 290 | def get_lane_feature(config_data: LoadConfigResultDate, lane_feature: List) -> torch.Tensor: 291 | max_lane_num = config_data.train_model_config.max_lane_num 292 | max_point_num = config_data.train_model_config.max_point_num 293 | lane_list = list() 294 | for lane_index, lane in enumerate(lane_feature): 295 | if lane_index >= max_lane_num: 296 | break 297 | point_list = lane[:max_point_num] 298 | point_gap = max_point_num - len(point_list) 299 | # MAX_POINT_NUM, 2 300 | point_list = torch.Tensor(point_list) 301 | if point_gap > 0: 302 | point_list = torch.cat((point_list, 303 | torch.zeros(size=(point_gap, 2), 304 | dtype=torch.float32)), dim=0) 305 | lane_list.append(point_list) 306 | lane_gap = max_lane_num - len(lane_list) 307 | lane_list = torch.stack(lane_list, dim=0) 308 | if lane_gap > 0: 309 | 
lane_list = torch.cat( 310 | (lane_list, torch.zeros(size=(lane_gap, max_point_num, 2), 311 | dtype=torch.float32)), dim=0 312 | ) 313 | return lane_list 314 | 315 | @classmethod 316 | def transform_data_to_input(cls, scenario: scenario_pb2.Scenario, 317 | config_data: LoadConfigResultDate) -> Dict[str, Any]: 318 | data_dict = cls.load_scenario_data(scenario) 319 | if len(data_dict) == 0: 320 | return dict() 321 | his_step = config_data.train_model_config.his_step 322 | pkl_data = DataUtil.split_pkl_data(data_dict, his_step) 323 | all_obs_index = set([i for i in range(pkl_data[1].shape[0])]) 324 | predicted_index = torch.Tensor(pkl_data[0]).to(torch.long).view(-1, 1, 1) 325 | all_obs_traj = torch.Tensor(pkl_data[1]) 326 | all_obs_feature = torch.Tensor(pkl_data[2]) 327 | all_obs_his_traj = all_obs_traj[:, :his_step] 328 | all_obs_future_traj = all_obs_traj[:, his_step:] 329 | other_obs_index = list(all_obs_index - set(pkl_data[0])) 330 | # gather features of the other obstacles 331 | other_his_traj, other_feature, other_traj_mask = cls.get_obs_feature( 332 | config_data, other_obs_index, all_obs_his_traj, all_obs_feature 333 | ) 334 | # history trajectories of the obstacles that need to be predicted 335 | (predicted_future_traj, predicted_his_traj, 336 | predicted_feature, predicted_traj_mask) = cls.get_pred_feature( 337 | config_data, predicted_index, all_obs_his_traj, 338 | all_obs_future_traj, all_obs_feature 339 | ) 340 | # traffic light information 341 | traffic_light = torch.Tensor(pkl_data[3]) 342 | traffic_light_pos = torch.Tensor(pkl_data[4]) 343 | traffic_light, traffic_light_pos, traffic_mask = cls.get_traffic_light( 344 | config_data, traffic_light, traffic_light_pos 345 | ) 346 | # lane line information 347 | lane_list = cls.get_lane_feature(config_data, pkl_data[-1]) 348 | map_json = json.dumps(pkl_data[-1]) 349 | other_his_traj_delt = other_his_traj[:, 1:] - other_his_traj[:, :-1] 350 | other_his_pos = other_his_traj[:, -1, :2] 351 | predicted_his_traj_delt = predicted_his_traj[:, 1:] - predicted_his_traj[:, :-1] 352 | predicted_his_pos = predicted_his_traj[:, -1, :2] 353 | result = { 354 | "other_his_traj": other_his_traj, 355 | "other_feature": other_feature, 356 | "other_traj_mask": other_traj_mask, 357 | "other_his_traj_delt": other_his_traj_delt, 358 | "other_his_pos": other_his_pos, 359 | "predicted_future_traj": predicted_future_traj, 360 | "predicted_his_traj": predicted_his_traj, 361 | "predicted_traj_mask": predicted_traj_mask, 362 | "predicted_feature": predicted_feature, 363 | "predicted_his_traj_delt": predicted_his_traj_delt, 364 | "predicted_his_pos": predicted_his_pos, 365 | "traffic_light": traffic_light, 366 | "traffic_light_pos": traffic_light_pos, 367 | "traffic_mask": traffic_mask, 368 | "lane_list": lane_list, 369 | "map_json": map_json, 370 | "curr_loc": data_dict['curr_loc'] 371 | } 372 | return result 373 | -------------------------------------------------------------------------------- /gene_submission.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import tarfile 5 | from typing import Tuple, Dict, Any, List 6 | 7 | import cv2 8 | import numpy as np 9 | import tensorflow as tf 10 | import torch 11 | import tqdm 12 | from waymo_open_dataset.protos import scenario_pb2 13 | from waymo_open_dataset.utils.sim_agents import submission_specs 14 | from waymo_open_dataset.protos import sim_agents_submission_pb2 15 | from waymo_open_dataset.wdl_limited.sim_agents_metrics import metrics 16 | from waymo_open_dataset.utils import trajectory_utils 17 | 18 | from nets import
(GaussianDiffusion, generate_linear_schedule, SimpleViT, TrajDecorder, MapEncoder) 19 | 20 | SHIFT_H = 125 21 | SHIFT_W = 125 22 | PIX_SIZE = 0.8 23 | MAX_LANE_NUM = 50  # overridden below (MAX_LANE_NUM = 32 takes effect) 24 | MAX_POINT_NUM = 250  # overridden below (MAX_POINT_NUM = 128 takes effect) 25 | HIS_STEP = 11 26 | FUTRUE_STEP = 80 27 | MAX_PRED_NUM = 8 28 | MAX_OTHER_NUM = 6 29 | MAX_TRAFFIC_LIGHT_NUM = 8 30 | MAX_LANE_NUM = 32 31 | MAX_POINT_NUM = 128 32 | 33 | transform_arr = np.array([SHIFT_W, SHIFT_H]) 34 | pix_size = PIX_SIZE 35 | object_type = { 36 | 'TYPE_UNSET': 0, 37 | 'TYPE_VEHICLE': 1, 38 | 'TYPE_PEDESTRIAN': 2, 39 | 'TYPE_CYCLIST': 3, 40 | 'TYPE_OTHER': 4 41 | } 42 | 43 | ObjectType = {0: "TYPE_UNSET", 1: "TYPE_VEHICLE", 2: "TYPE_PEDESTRIAN", 3: "TYPE_CYCLIST", 4: "TYPE_OTHER"} 44 | MapState = {0: "LANE_STATE_UNKNOWN", 1: "LANE_STATE_ARROW_STOP", 2: "LANE_STATE_ARROW_CAUTION", 45 | 3: "LANE_STATE_ARROW_GO", 4: "LANE_STATE_STOP", 5: "LANE_STATE_CAUTION", 46 | 6: "LANE_STATE_GO", 7: "LANE_STATE_FLASHING_STOP", 8: "LANE_STATE_FLASHING_CAUTION"} 47 | 48 | VALID_PATH = r"/mnt/share_disk/yangchen/Scene-Diffusion/valid_set" 49 | TEST_PATH = r"/mnt/share_disk/waymo_dataset/waymodata_1_1/testing" 50 | MODEL_PATH = r"/mnt/share_disk/yangchen/Scene-Diffusion/diffusion_model_last_epoch_weights.pth" 51 | 52 | 53 | def global_to_local(curr_x: float, curr_y: float, curr_heading: float, 54 | point_x: float, point_y: float) -> Tuple[float, float]: 55 | delta_x = point_x - curr_x 56 | delta_y = point_y - curr_y 57 | return delta_x * math.cos(curr_heading) + delta_y * math.sin(curr_heading), \ 58 | delta_y * math.cos(curr_heading) - delta_x * math.sin(curr_heading) 59 | 60 | 61 | def theta_global_to_local(curr_heading: float, heading: float) -> float: 62 | """ 63 | Convert a heading from the world frame to the ego (vehicle) frame 64 | Args: 65 | curr_heading: ego heading in the world frame 66 | heading: the heading to convert 67 | Returns: the heading in the ego frame 68 | """ 69 | return normalize_angle(-curr_heading + heading) 70 | 71 | 72 | def normalize_angle(angle: float) -> float: 73 | """ 74 | Normalize an angle in radians to the range [-pi, pi] 75 | Args: 76 | angle: input angle in radians 77 | Returns: the normalized angle 78 | """ 79 | angle = (angle + math.pi) % (2 * math.pi) 80 | if angle < .0: 81 | angle += 2 * math.pi 82 | return angle - math.pi 83 | 84 | 85 | def draw_one_rect(traj_list: List, image: np.ndarray, 86 | box_points: np.ndarray, color: int) -> Tuple[np.ndarray, np.ndarray]: 87 | one_step_color = color / 11 88 | image_list = list() 89 | for i in range(0, HIS_STEP): 90 | yaw = traj_list[i][2] 91 | rotate_matrix = np.array([np.cos(yaw), -np.sin(yaw), 92 | np.sin(yaw), np.cos(yaw)]).reshape(2, 2) 93 | 94 | # coordinate transformation 95 | box_points_temp = box_points @ -rotate_matrix 96 | coord = np.array([traj_list[i][0], traj_list[i][1]]) 97 | box_points_temp = box_points_temp + coord 98 | box_points_temp = box_points_temp / pix_size + transform_arr 99 | box_points_temp = np.round(box_points_temp).astype(int).reshape(1, -1, 2) 100 | color_temp = int(one_step_color * (i + 1)) 101 | temp_image = cv2.fillPoly( 102 | np.zeros_like(image), 103 | box_points_temp, 104 | color=[color_temp], 105 | shift=0, 106 | ) 107 | image_list.append(temp_image) 108 | image_list.append(image) 109 | image = np.concatenate(image_list, axis=2) 110 | image = np.expand_dims(np.max(image, axis=2), axis=2) 111 | return image, np.array(traj_list[:HIS_STEP]) 112 | 113 | 114 | def gene_model_input_step_three(one_pkl_data: Tuple, predicted_obs_id: List) -> List: 115 | result = list() 116 | if len(one_pkl_data[1]) != len(predicted_obs_id): 117 | raise ValueError("data preprocess predicted_obs_id size != predicted_obs_id size") 118 | map_image =
torch.Tensor(one_pkl_data[0]) 119 | all_obs_index = set([i for i in range(one_pkl_data[2].shape[0])]) 120 | predicted_index = torch.Tensor(one_pkl_data[1]).to(torch.long).view(-1, 1, 1) 121 | all_obs_traj = torch.Tensor(one_pkl_data[2]) 122 | all_obs_feature = torch.Tensor(one_pkl_data[3]) 123 | all_obs_his_traj = all_obs_traj[:, :HIS_STEP] 124 | other_obs_index = list(all_obs_index - set(one_pkl_data[1])) 125 | if len(other_obs_index) > 0: 126 | # history trajectories of the other (non-predicted) obstacles 127 | other_obs_traj_index = torch.Tensor(other_obs_index).to(torch.long).view(-1, 1, 1).repeat((1, HIS_STEP, 5)) 128 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index) 129 | other_obs_feature_index = torch.Tensor(other_obs_index).to(torch.long).view(-1, 1).repeat((1, 7)) 130 | # other feature 131 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index) 132 | other_his_traj = other_his_traj[:MAX_OTHER_NUM] 133 | other_feature = other_feature[:MAX_OTHER_NUM] 134 | num_gap = MAX_OTHER_NUM - other_his_traj.shape[0] 135 | if num_gap > 0: 136 | other_his_traj = torch.cat((other_his_traj, 137 | torch.zeros(size=(num_gap, HIS_STEP, 5), 138 | dtype=torch.float32)), dim=0) 139 | other_feature = torch.cat((other_feature, 140 | torch.zeros(size=(num_gap, 7), 141 | dtype=torch.float32)), dim=0) 142 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] * num_gap) 143 | else: 144 | other_his_traj = torch.zeros(size=(MAX_OTHER_NUM, HIS_STEP, 5), dtype=torch.float32) 145 | other_feature = torch.zeros(size=(MAX_OTHER_NUM, 7), dtype=torch.float32) 146 | other_traj_mask = torch.zeros(size=[MAX_OTHER_NUM], dtype=torch.float32) 147 | # traffic light information 148 | traffic_light = torch.Tensor(one_pkl_data[4]) 149 | traffic_light_pos = torch.Tensor(one_pkl_data[5]) 150 | traffic_light = traffic_light[:MAX_TRAFFIC_LIGHT_NUM] 151 | traffic_light_pos = traffic_light_pos[:MAX_TRAFFIC_LIGHT_NUM] 152 | num_gap = MAX_TRAFFIC_LIGHT_NUM - traffic_light.shape[0] 153 | if num_gap > 0: 154 | traffic_light = torch.cat((traffic_light, 155 | torch.zeros(size=(num_gap, 91), 156 | dtype=torch.float32)), dim=0) 157 | traffic_light_pos = torch.cat((traffic_light_pos, 158 | torch.zeros(size=(num_gap, 2), 159 | dtype=torch.float32)), dim=0) 160 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) + [0.0] * num_gap) 161 | # lane line information 162 | lane_list = list() 163 | for lane_index, lane in enumerate(one_pkl_data[6]): 164 | if lane_index >= MAX_LANE_NUM: 165 | break 166 | point_list = lane[:MAX_POINT_NUM] 167 | point_gap = MAX_POINT_NUM - len(point_list) 168 | # MAX_POINT_NUM, 2 169 | point_list = torch.Tensor(point_list) 170 | if point_gap > 0: 171 | point_list = torch.cat((point_list, 172 | torch.zeros(size=(point_gap, 2), 173 | dtype=torch.float32)), dim=0) 174 | lane_list.append(point_list) 175 | lane_gap = MAX_LANE_NUM - len(lane_list) 176 | if len(lane_list) > 0: 177 | lane_list = torch.stack(lane_list, dim=0) 178 | if lane_gap > 0: 179 | lane_list = torch.cat((lane_list, 180 | torch.zeros(size=(lane_gap, MAX_POINT_NUM, 2), 181 | dtype=torch.float32)), dim=0) 182 | else: 183 | lane_list = torch.zeros(size=(MAX_LANE_NUM, MAX_POINT_NUM, 2), dtype=torch.float32) 184 | map_json = json.dumps(one_pkl_data[6]) 185 | other_his_traj_delt = other_his_traj[:, 1:, :2] - other_his_traj[:, :-1, :2] 186 | other_his_traj_delt = torch.cat((other_his_traj_delt, other_his_traj[:, 1:, 2:]), dim=-1) 187 | other_his_pos = other_his_traj[:, -1, :2] 188 | # if there are more than MAX_PRED_NUM obstacles to predict, run inference in batches 189 | predicted_obs_batch_num =
(len(one_pkl_data[1]) // MAX_PRED_NUM) + 1 190 | predicted_his_traj = torch.gather(all_obs_his_traj, 0, predicted_index.repeat((1, HIS_STEP, 5))) 191 | predicted_feature = torch.gather(all_obs_feature, 0, predicted_index.view(-1, 1).repeat((1, 7))) 192 | for i in range(predicted_obs_batch_num): 193 | start_index = i * MAX_PRED_NUM 194 | end_index = start_index + MAX_PRED_NUM  # slice end is exclusive; subtracting 1 here would drop every MAX_PRED_NUM-th obstacle 195 | # history trajectories of the obstacles that need to be predicted 196 | predicted_his_traj_batch = predicted_his_traj[start_index:end_index] 197 | # predicted feature 198 | predicted_feature_batch = predicted_feature[start_index:end_index] 199 | # mask 200 | num_gap = MAX_PRED_NUM - predicted_his_traj_batch.shape[0] 201 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj_batch.shape[0]) + [0.0] * num_gap) 202 | if num_gap > 0: 203 | predicted_his_traj_batch = torch.cat((predicted_his_traj_batch, 204 | torch.zeros(size=(num_gap, HIS_STEP, 5), 205 | dtype=torch.float32)), dim=0) 206 | predicted_feature_batch = torch.cat((predicted_feature_batch, 207 | torch.zeros(size=(num_gap, 7), 208 | dtype=torch.float32)), dim=0) 209 | predicted_his_traj_delt = predicted_his_traj_batch[:, 1:, :2] - predicted_his_traj_batch[:, :-1, :2] 210 | predicted_his_traj_delt = torch.cat((predicted_his_traj_delt, predicted_his_traj_batch[:, 1:, 2:]), dim=-1) 211 | predicted_his_pos = predicted_his_traj_batch[:, -1, :2] 212 | result_dict = { 213 | "map_image": map_image, 214 | "other_his_traj": other_his_traj, 215 | "other_feature": other_feature, 216 | "other_traj_mask": other_traj_mask, 217 | "other_his_traj_delt": other_his_traj_delt, 218 | "other_his_pos": other_his_pos, 219 | "predicted_traj_mask": predicted_traj_mask, 220 | "predicted_his_traj": predicted_his_traj_batch, 221 | "predicted_feature": predicted_feature_batch, 222 | "predicted_his_traj_delt": predicted_his_traj_delt, 223 | "predicted_his_pos": predicted_his_pos, 224 | "traffic_light": traffic_light, 225 | "traffic_light_pos": traffic_light_pos, 226 | "traffic_mask": traffic_mask, 227 | "lane_list": lane_list, 228 | "map_json": map_json 229 | } 230 | for key, value in result_dict.items(): 231 | if not isinstance(value, str): 232 | result_dict[key] = value.unsqueeze(dim=0) 233 | 234 | result.append(result_dict) 235 | return result 236 | 237 | 238 | def gene_model_input_step_two(data_dict: Dict[str, Any]) -> Tuple: 239 | # rasterize the lane lines into an image 240 | image_roadmap = np.zeros((256, 256, 1), dtype=np.uint8) 241 | image_vru_traj = np.zeros((256, 256, 1), dtype=np.uint8) 242 | image_car_traj = np.zeros((256, 256, 1), dtype=np.uint8) 243 | for road_info in data_dict['map_features']: 244 | if len(road_info['polygon_points']) <= 0: 245 | continue 246 | lane_points = np.array(road_info['polygon_points']) 247 | lane_points = lane_points / pix_size + transform_arr 248 | lane_points = np.round(lane_points).astype(int) 249 | image_roadmap = cv2.polylines( 250 | image_roadmap, 251 | [lane_points], 252 | False, 253 | [255], 254 | shift=0, 255 | ) 256 | # obstacle trajectories: rasterize them into an image or build vector features 257 | traj_list = list() 258 | obs_feature_list = list() 259 | valid_index = list() 260 | index = 0 261 | predicted_obs_ids = data_dict['predicted_obs_ids'] 262 | predicted_obs_index = list() 263 | predicted_obs_real_index = list() 264 | for obs_index, one_obs_info in enumerate(data_dict['obs_tracks']): 265 | if len(one_obs_info['obs_traj']) < 11 or None in one_obs_info['obs_traj']: 266 | continue 267 | obs_feature = list() 268 | obs_feature.append(one_obs_info['width']) 269 | obs_feature.append(one_obs_info['length']) 270 | type_onehot = [0] * len(object_type)
271 | type_onehot[object_type[one_obs_info['object_type']]] = 1 272 | obs_feature += type_onehot 273 | obs_feature_list.append(obs_feature) 274 | if one_obs_info['obs_id'] in predicted_obs_ids: 275 | predicted_obs_index.append(index) 276 | predicted_obs_real_index.append(one_obs_info['obs_id']) 277 | obs_half_width = 0.5 * one_obs_info['width'] 278 | obs_half_length = 0.5 * one_obs_info['length'] 279 | box_points = np.array([-obs_half_length, -obs_half_width, 280 | obs_half_length, -obs_half_width, 281 | obs_half_length, obs_half_width, 282 | -obs_half_length, obs_half_width]).reshape(4, 2).astype(np.float32) 283 | 284 | if one_obs_info['object_type'] in ('TYPE_PEDESTRIAN', 'TYPE_CYCLIST'): 285 | image_vru_traj, traj = draw_one_rect(one_obs_info['obs_traj'], 286 | image_vru_traj, 287 | box_points, 255) 288 | 289 | elif one_obs_info['object_type'] == 'TYPE_VEHICLE': 290 | valid_index.append(index) 291 | image_car_traj, traj = draw_one_rect(one_obs_info['obs_traj'], 292 | image_car_traj, 293 | box_points, 255) 294 | 295 | else: 296 | image_car_traj, traj = draw_one_rect(one_obs_info['obs_traj'], 297 | image_car_traj, 298 | box_points, 125) 299 | traj_list.append(traj) 300 | index += 1 301 | # dynamic map (traffic light) information 302 | dynamic_states = list() 303 | dynamic_pos = list() 304 | for key, value in data_dict['dynamic_states'].items(): 305 | dynamic_pos.append(value[0]) 306 | dynamic_state = value[1:] 307 | if len(dynamic_state) < 91: 308 | dynamic_state = dynamic_state + [0] * (91 - len(dynamic_state)) 309 | dynamic_states.append(dynamic_state) 310 | if len(predicted_obs_index) < 1: 311 | raise ValueError("predicted_obs_index < 1") 312 | image = image_roadmap 313 | traj_arr = np.stack(traj_list, axis=0) 314 | obs_feature_list = np.stack(obs_feature_list, axis=0) 315 | 316 | dynamic_states = np.array(dynamic_states) 317 | dynamic_pos = np.array(dynamic_pos) 318 | map_feature_list = [feature['polygon_points'] for feature in data_dict['map_features']] 319 | one_pkl_data = (image.transpose(2, 1, 0), predicted_obs_index, traj_arr, 320 | obs_feature_list, dynamic_states, 321 | dynamic_pos, map_feature_list, predicted_obs_real_index) 322 | return one_pkl_data 323 | 324 | 325 | def gene_model_input_step_one(scenario: scenario_pb2.Scenario) -> Dict[str, Any]: 326 | data_dict = dict() 327 | sdc_track_index = scenario.sdc_track_index 328 | current_time_index = scenario.current_time_index 329 | obs_tracks = list() 330 | # get the ego vehicle's current pose and convert all trajectories and lane lines into the ego coordinate frame 331 | curr_state = scenario.tracks[scenario.sdc_track_index].states[current_time_index] 332 | ego_curr_x, ego_curr_y = curr_state.center_x, curr_state.center_y 333 | ego_curr_heading = curr_state.heading 334 | predicted_obs_ids = list() 335 | for predicted_obs in scenario.tracks_to_predict: 336 | predicted_obs_ids.append(scenario.tracks[predicted_obs.track_index].id) 337 | predicted_obs_ids.append(scenario.tracks[scenario.sdc_track_index].id) 338 | # one obstacle's state (id, type, trajectory) 339 | for track in scenario.tracks: 340 | one_obs_track = dict() 341 | obs_id = track.id 342 | if hasattr(track, "object_type"): 343 | object_type = ObjectType[track.object_type] 344 | else: 345 | object_type = ObjectType[4] 346 | obs_traj = list() 347 | # iterate over the obstacle's trajectory states 348 | for state in track.states: 349 | if not state.valid and obs_id not in predicted_obs_ids: 350 | continue 351 | if not state.valid and obs_id in predicted_obs_ids and len(obs_traj) > 0: 352 | obs_traj.append(obs_traj[-1]) 353 | continue 354 | if 'height' not in one_obs_track and 'length' not in one_obs_track \ 355 | and 'width' not in one_obs_track:
356 | one_obs_track['height'] = state.height 357 | one_obs_track['length'] = state.length 358 | one_obs_track['width'] = state.width 359 | center_x, center_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 360 | state.center_x, state.center_y) 361 | center_heading = theta_global_to_local(ego_curr_heading, state.heading) 362 | # there are two ways to transform the velocity: 363 | # 1. project it onto the new coordinate frame 364 | # 2. compute the velocity heading in the new frame and multiply the speed by cos and sin 365 | curr_v = math.sqrt(math.pow(state.velocity_x, 2) + 366 | math.pow(state.velocity_y, 2)) 367 | curr_v_heading = math.atan2(state.velocity_y, state.velocity_x) 368 | curr_v_heading = theta_global_to_local(ego_curr_heading, curr_v_heading) 369 | obs_traj.append( 370 | (center_x, center_y, center_heading, 371 | curr_v * math.cos(curr_v_heading), curr_v * math.sin(curr_v_heading)) 372 | ) 373 | one_obs_track['obs_id'] = obs_id 374 | one_obs_track['object_type'] = object_type 375 | one_obs_track['obs_traj'] = obs_traj 376 | # obstacles with missing trajectory points are not needed 377 | if len(obs_traj) >= 11: 378 | obs_tracks.append(one_obs_track) 379 | # global (static) map information 380 | map_features = list() 381 | for map_feature in scenario.map_features: 382 | one_map_dict = dict() 383 | map_id = map_feature.id 384 | polygon_points = list() 385 | if hasattr(map_feature, "road_edge") and map_feature.road_edge.polyline: 386 | map_type = "road_edge" 387 | polygon_list = map_feature.road_edge.polyline 388 | elif hasattr(map_feature, "road_line") and map_feature.road_line.polyline: 389 | map_type = "road_line" 390 | polygon_list = map_feature.road_line.polyline 391 | else: 392 | continue 393 | for polygon_point in polygon_list: 394 | polygon_point_x, polygon_point_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 395 | polygon_point.x, polygon_point.y) 396 | polygon_points.append((polygon_point_x, polygon_point_y)) 397 | one_map_dict['map_id'] = map_id 398 | one_map_dict['map_type'] = map_type 399 | one_map_dict['polygon_points'] = polygon_points 400 | map_features.append(one_map_dict) 401 | # dynamic map (traffic light) information 402 | dynamic_states = dict() 403 | for dynamic_state in scenario.dynamic_map_states: 404 | for lane_state in dynamic_state.lane_states: 405 | lane_id = lane_state.lane 406 | lane_x, lane_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading, 407 | lane_state.stop_point.x, 408 | lane_state.stop_point.y) 409 | if lane_id not in dynamic_states: 410 | dynamic_states[lane_id] = list() 411 | dynamic_states[lane_id].append((lane_x, lane_y)) 412 | state = lane_state.state 413 | dynamic_states[lane_id].append(state) 414 | data_dict['predicted_obs_ids'] = predicted_obs_ids 415 | data_dict['obs_tracks'] = obs_tracks 416 | data_dict['map_features'] = map_features 417 | data_dict['dynamic_states'] = dynamic_states 418 | data_dict['curr_loc'] = (ego_curr_x, ego_curr_y, ego_curr_heading, sdc_track_index) 419 | return data_dict 420 | 421 | 422 | def inference(input_batch: List) -> np.ndarray: 423 | input_shape = (256, 256) 424 | num_timesteps = 100 425 | schedule_low = 1e-4 426 | schedule_high = 0.008 427 | betas = generate_linear_schedule( 428 | num_timesteps, 429 | schedule_low * 1000 / num_timesteps, 430 | schedule_high * 1000 / num_timesteps, 431 | ) 432 | diffusion_model = GaussianDiffusion(SimpleViT(), MapEncoder(), TrajDecorder(), 433 | input_shape, 3, betas=betas) 434 | if torch.cuda.is_available() and torch.cuda.device_count() > 1 and False:  # 'and False' forces CPU inference 435 | device = torch.device('cuda') 436 | else: 437 | device = torch.device('cpu') 438 | diffusion_model = diffusion_model.to(device) 439 | model_dict = diffusion_model.state_dict() 440 | pretrained_dict =
torch.load(MODEL_PATH, map_location=device) 441 | # copy the pretrained weights into the model's state dict 442 | new_model_dict = dict() 443 | for key in model_dict.keys(): 444 | if ("module." + key) in pretrained_dict: 445 | new_model_dict[key] = pretrained_dict["module." + key] 446 | elif key in pretrained_dict: 447 | new_model_dict[key] = pretrained_dict[key] 448 | else: 449 | print("key: ", key, ", not in pretrained") 450 | diffusion_model.load_state_dict(new_model_dict) 451 | print("loaded parameters successfully") 452 | results = list() 453 | for one_batch_input in input_batch: 454 | result = diffusion_model.sample(one_batch_input)[0] 455 | results.append(result) 456 | results = torch.cat(results, dim=0) 457 | return results.numpy() 458 | 459 | 460 | def local_to_global(ego_heading: float, position_x: np.ndarray, position_y: np.ndarray, ego_local_x: float, 461 | ego_local_y: float) -> Tuple[np.ndarray, np.ndarray]: 462 | """ 463 | Convert x, y coordinates from the ego (vehicle) frame to the world frame 464 | Args: 465 | ego_heading: ego heading in the world frame 466 | position_x: x coordinates to convert 467 | position_y: y coordinates to convert 468 | ego_local_x: ego x position in the world frame 469 | ego_local_y: ego y position in the world frame 470 | Returns: x coordinates in the world frame, y coordinates in the world frame 471 | """ 472 | yaw = ego_heading 473 | # global_x = [(ego_local_x + x * math.cos(yaw) - y * math.sin(yaw)) for x, y in zip(position_x.tolist(), position_y.tolist())] 474 | # global_y = [(ego_local_y + x * math.sin(yaw) + y * math.cos(yaw)) for x, y in zip(position_x.tolist(), position_y.tolist())] 475 | # return np.array(global_x), np.array(global_y) 476 | global_x = ego_local_x + position_x * math.cos(yaw) - position_y * math.sin(yaw) 477 | global_y = ego_local_y + position_x * math.sin(yaw) + position_y * math.cos(yaw) 478 | return global_x, global_y 479 | 480 | 481 | def theta_local_to_global(ego_heading: float, heading: np.ndarray) -> np.ndarray: 482 | """ 483 | Convert headings from the ego (vehicle) frame to the world frame 484 | Args: 485 | ego_heading: ego heading in the world frame 486 | heading: the headings to convert 487 | Returns: the headings in the world frame 488 | """ 489 | heading = heading.tolist() 490 | heading_list = [] 491 | for one_heading in heading: 492 | heading_list.append(normalize_angle(ego_heading + one_heading)) 493 | return np.array(heading_list) 494 | 495 | 496 | def simulate_with_extrapolation_new( 497 | scenario: scenario_pb2.Scenario, 498 | print_verbose_comments: bool = True) -> tf.Tensor: 499 | vprint = print if print_verbose_comments else lambda arg: None 500 | 501 | # To load the data, we create a simple tensorized version of the object tracks. 502 | logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario) 503 | # Using `ObjectTrajectories` we can select just the objects that we need to 504 | # simulate and remove the "future" part of the Scenario. 505 | vprint(f'Original shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)') 506 | logged_trajectories = logged_trajectories.gather_objects_by_id( 507 | tf.convert_to_tensor(submission_specs.get_sim_agent_ids(scenario))) 508 | logged_trajectories = logged_trajectories.slice_time( 509 | start_index=0, end_index=submission_specs.CURRENT_TIME_INDEX + 1) 510 | vprint(f'Modified shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)') 511 | # We can verify that all of these objects are valid at the last step.
512 | vprint(f'Are all agents valid: {tf.reduce_all(logged_trajectories.valid[:, -1]).numpy()}') 513 | # data preprocessing and model inference 514 | all_logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario) 515 | all_logged_trajectories = all_logged_trajectories.slice_time( 516 | start_index=0, end_index=submission_specs.N_FULL_SCENARIO_STEPS + 1) 517 | logged_pred = all_logged_trajectories.gather_objects_by_id( 518 | tf.convert_to_tensor(submission_specs.get_evaluation_sim_agent_ids(scenario))) 519 | # if not tf.reduce_all(logged_pred.valid): 520 | # print("logged_pred include invalid state") 521 | predicted_obs_id = submission_specs.get_evaluation_sim_agent_ids(scenario) 522 | data_dict = gene_model_input_step_one(scenario) 523 | one_pkl_data = gene_model_input_step_two(data_dict) 524 | input_batch = gene_model_input_step_three(one_pkl_data, predicted_obs_id) 525 | predicted_obs_traj = inference(input_batch) 526 | predicted_obs_id_in_pkl = one_pkl_data[7] 527 | predicted_obs_traj = predicted_obs_traj[:len(predicted_obs_id_in_pkl)] 528 | # map each predicted trajectory to its obstacle id 529 | predicted_obs_id_traj = {obs_id: predicted_obs_traj[index] for index, obs_id in enumerate(predicted_obs_id_in_pkl)} 530 | # ego pose at the current time step 531 | curr_loc = data_dict['curr_loc'] 532 | 533 | simulated_states = list() 534 | for index, obs_id in enumerate(submission_specs.get_sim_agent_ids(scenario)): 535 | if obs_id not in predicted_obs_id_traj.keys(): 536 | simulated_states.append(np.zeros(shape=(80, 4))) 537 | else: 538 | one_predicted_obs_traj = predicted_obs_id_traj[obs_id] 539 | one_predicted_obs_x = one_predicted_obs_traj[:, 0] 540 | one_predicted_obs_y = one_predicted_obs_traj[:, 1] 541 | one_predicted_obs_z = np.array([float(logged_trajectories.z[:, -1][index])] * 80) 542 | one_predicted_obs_x, one_predicted_obs_y = local_to_global(curr_loc[2], one_predicted_obs_x, 543 | one_predicted_obs_y, curr_loc[0], curr_loc[1]) 544 | one_predicted_obs_heading = theta_local_to_global(curr_loc[2], one_predicted_obs_traj[:, 2]) 545 | one_simulated_state = np.stack((one_predicted_obs_x, one_predicted_obs_y, 546 | one_predicted_obs_z, one_predicted_obs_heading), axis=-1) 547 | simulated_states.append(one_simulated_state) 548 | simulated_states = np.stack(simulated_states, axis=0) 549 | simulated_states = np.stack([simulated_states] * submission_specs.N_ROLLOUTS, axis=0) 550 | simulated_states = tf.convert_to_tensor(simulated_states) 551 | return logged_trajectories, simulated_states 552 | 553 | 554 | def simulate_with_extrapolation( 555 | scenario: scenario_pb2.Scenario, 556 | print_verbose_comments: bool = True) -> tf.Tensor: 557 | vprint = print if print_verbose_comments else lambda arg: None 558 | 559 | # To load the data, we create a simple tensorized version of the object tracks. 560 | logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario) 561 | # Using `ObjectTrajectories` we can select just the objects that we need to 562 | # simulate and remove the "future" part of the Scenario.
563 | vprint(f'Original shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)') 564 | logged_trajectories = logged_trajectories.gather_objects_by_id( 565 | tf.convert_to_tensor(submission_specs.get_sim_agent_ids(scenario))) 566 | logged_trajectories = logged_trajectories.slice_time( 567 | start_index=0, end_index=submission_specs.CURRENT_TIME_INDEX + 1) 568 | vprint(f'Modified shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)') 569 | 570 | # We can verify that all of these objects are valid at the last step. 571 | vprint(f'Are all agents valid: {tf.reduce_all(logged_trajectories.valid[:, -1]).numpy()}') 572 | 573 | # We extract the speed of the sim agents (in the x/y/z components) ready for 574 | # extrapolation (this will be our policy). 575 | states = tf.stack([logged_trajectories.x, logged_trajectories.y, 576 | logged_trajectories.z, logged_trajectories.heading], 577 | axis=-1) 578 | n_objects, n_steps, _ = states.shape 579 | last_velocities = states[:, -1, :3] - states[:, -2, :3] 580 | # We also make the heading constant, so concatenate 0. as angular speed. 581 | last_velocities = tf.concat( 582 | [last_velocities, tf.zeros((n_objects, 1))], axis=-1) 583 | # It can happen that the second to last state of these sim agents might be 584 | # invalid, so we will set a zero speed for them. 585 | vprint(f'Is any 2nd to last state invalid: {tf.reduce_any(tf.logical_not(logged_trajectories.valid[:, -2]))}') 586 | vprint(f'This will result in either min or max speed to be really large: {tf.reduce_max(tf.abs(last_velocities))}') 587 | valid_diff = tf.logical_and(logged_trajectories.valid[:, -1], 588 | logged_trajectories.valid[:, -2]) 589 | # `last_velocities` shape: (n_objects, 4). 590 | last_velocities = tf.where(valid_diff[:, tf.newaxis], 591 | last_velocities, 592 | tf.zeros_like(last_velocities)) 593 | vprint(f'Now this should be back to a normal value: {tf.reduce_max(tf.abs(last_velocities))}') 594 | 595 | # Now we carry over a simulation. As we discussed, we actually want 32 parallel 596 | # simulations, so we make this batched from the very beginning. We add some 597 | # random noise on top of our actions to make sure the behaviours are different. 598 | # To properly scale the noise, we get the max velocities (average over all 599 | # objects, corresponding to axis 0) in each of the dimensions (x/y/z/heading). 600 | NOISE_SCALE = 0.01 601 | # `max_action` shape: (4,). 602 | max_action = tf.reduce_max(last_velocities, axis=0) 603 | # We create `simulated_states` with shape (n_rollouts, n_objects, n_steps, 4). 604 | simulated_states = tf.tile(states[tf.newaxis, :, -1:, :], [submission_specs.N_ROLLOUTS, 1, 1, 1]) 605 | vprint(f'Shape: {simulated_states.shape}') 606 | 607 | for step in range(submission_specs.N_SIMULATION_STEPS): 608 | current_state = simulated_states[:, :, -1, :] 609 | # Random actions, take a normal and normalize by min/max actions 610 | action_noise = tf.random.normal( 611 | current_state.shape, mean=0.0, stddev=NOISE_SCALE) 612 | actions_with_noise = last_velocities[None, :, :] + (action_noise * max_action) 613 | next_state = current_state + actions_with_noise 614 | simulated_states = tf.concat( 615 | [simulated_states, next_state[:, :, None, :]], axis=2) 616 | 617 | # We also need to remove the first time step from `simulated_states` (it was 618 | # still history). 619 | # `simulated_states` shape before: (n_rollouts, n_objects, 81, 4). 
620 | # `simulated_states` shape after: (n_rollouts, n_objects, 80, 4). 621 | simulated_states = simulated_states[:, :, 1:, :] 622 | vprint(f'Final simulated states shape: {simulated_states.shape}') 623 | 624 | return logged_trajectories, simulated_states 625 | 626 | 627 | def joint_scene_from_states( 628 | states: tf.Tensor, object_ids: tf.Tensor 629 | ) -> sim_agents_submission_pb2.JointScene: 630 | # States shape: (num_objects, num_steps, 4). 631 | # Objects IDs shape: (num_objects,). 632 | states = states.numpy() 633 | simulated_trajectories = [] 634 | for i_object in range(len(object_ids)): 635 | simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory( 636 | center_x=states[i_object, :, 0], center_y=states[i_object, :, 1], 637 | center_z=states[i_object, :, 2], heading=states[i_object, :, 3], 638 | object_id=object_ids[i_object] 639 | )) 640 | return sim_agents_submission_pb2.JointScene( 641 | simulated_trajectories=simulated_trajectories) 642 | 643 | 644 | # Now we can replicate this strategy to export all the parallel simulations. 645 | def scenario_rollouts_from_states( 646 | scenario: scenario_pb2.Scenario, 647 | states: tf.Tensor, object_ids: tf.Tensor 648 | ) -> sim_agents_submission_pb2.ScenarioRollouts: 649 | # States shape: (num_rollouts, num_objects, num_steps, 4). 650 | # Objects IDs shape: (num_objects,). 651 | joint_scenes = [] 652 | for i_rollout in range(states.shape[0]): 653 | joint_scenes.append(joint_scene_from_states(states[i_rollout], object_ids)) 654 | return sim_agents_submission_pb2.ScenarioRollouts( 655 | # Note: remember to include the Scenario ID in the proto message. 656 | joint_scenes=joint_scenes, scenario_id=scenario.scenario_id) 657 | 658 | 659 | def inference_valid_set(): 660 | file_names = os.path.join(VALID_PATH, "validation*") 661 | match_filenames = tf.io.matching_files(file_names) 662 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data") 663 | dataset_iterator = dataset.as_numpy_iterator() 664 | for data in dataset_iterator: 665 | scenario = scenario_pb2.Scenario.FromString(data) 666 | valid_list = [index for index, track in enumerate(scenario.tracks) if track.states[10].valid] 667 | predicted_list = [obs.track_index for obs in scenario.tracks_to_predict] 668 | print(valid_list) 669 | print(predicted_list) 670 | # logged_trajectories, simulated_states = simulate_with_extrapolation( 671 | # scenario, print_verbose_comments=True) 672 | 673 | logged_trajectories, simulated_states = simulate_with_extrapolation_new( 674 | scenario, print_verbose_comments=True) 675 | # # Package the first simulation into a `JointScene` 676 | joint_scene = joint_scene_from_states(simulated_states[0, :, :, :], 677 | logged_trajectories.object_id) 678 | # Validate the joint scene. Should raise an exception if it's invalid. 679 | submission_specs.validate_joint_scene(joint_scene, scenario) 680 | scenario_rollouts = scenario_rollouts_from_states( 681 | scenario, simulated_states, logged_trajectories.object_id) 682 | # As before, we can validate the message we just generate. 683 | submission_specs.validate_scenario_rollouts(scenario_rollouts, scenario) 684 | # Compute the features for a single JointScene. 
685 | # single_scene_features = metric_features.compute_metric_features( 686 | # scenario, joint_scene) 687 | config = metrics.load_metrics_config() 688 | scenario_metrics = metrics.compute_scenario_metrics_for_bundle( 689 | config, scenario, scenario_rollouts) 690 | print(scenario_metrics) 691 | 692 | 693 | def inference_test_set(): 694 | OUTPUT_ROOT_DIRECTORY = r'waymo_output' 695 | os.makedirs(OUTPUT_ROOT_DIRECTORY, exist_ok=True) 696 | output_filenames = [] 697 | file_names = os.path.join(TEST_PATH, "testing*") 698 | match_filenames = tf.io.matching_files(file_names) 699 | for shard_filename in match_filenames: 700 | print(f"{shard_filename} start inference") 701 | # extract the suffix. 702 | shard_suffix = shard_filename.numpy().decode('utf8')[-len('-00000-of-00150'):] 703 | shard_dataset = tf.data.TFRecordDataset([shard_filename]) 704 | shard_iterator = shard_dataset.as_numpy_iterator() 705 | scenario_rollouts = [] 706 | for scenario_bytes in tqdm.tqdm(shard_iterator): 707 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes) 708 | logged_trajectories, simulated_states = simulate_with_extrapolation_new( 709 | scenario, print_verbose_comments=False) 710 | sr = scenario_rollouts_from_states( 711 | scenario, simulated_states, logged_trajectories.object_id) 712 | submission_specs.validate_scenario_rollouts(sr, scenario) 713 | scenario_rollouts.append(sr) 714 | shard_submission = sim_agents_submission_pb2.SimAgentsChallengeSubmission( 715 | scenario_rollouts=scenario_rollouts, 716 | submission_type=sim_agents_submission_pb2.SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION, 717 | account_name='your_account@test.com', 718 | unique_method_name='sim_agents_tutorial', 719 | authors=['test'], 720 | affiliation='waymo', 721 | description='Submission from the Sim Agents tutorial', 722 | method_link='https://waymo.com/open/' 723 | ) 724 | # Now we can export this message to a binproto, saved to local storage. 725 | output_filename = f'submission.binproto{shard_suffix}' 726 | with open(os.path.join(OUTPUT_ROOT_DIRECTORY, output_filename), 'wb') as f: 727 | f.write(shard_submission.SerializeToString()) 728 | output_filenames.append(output_filename) 729 | 730 | # Once we have created all the shards, we can package them directly into a 731 | # tar.gz archive, ready for submission. 732 | with tarfile.open( 733 | os.path.join(OUTPUT_ROOT_DIRECTORY, 'submission.tar.gz'), 'w:gz') as tar: 734 | for output_filename in output_filenames: 735 | tar.add(os.path.join(OUTPUT_ROOT_DIRECTORY, output_filename), 736 | arcname=output_filename) 737 | 738 | 739 | def cal_dynamic_map_states(is_test: bool = True): 740 | if is_test: 741 | file_names = os.path.join(TEST_PATH, "testing*") 742 | match_filenames = tf.io.matching_files(file_names) 743 | else: 744 | file_names = os.path.join(VALID_PATH, "validation*") 745 | match_filenames = tf.io.matching_files(file_names) 746 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data") 747 | dataset_iterator = dataset.as_numpy_iterator() 748 | for scenario_bytes in tqdm.tqdm(dataset_iterator): 749 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes) 750 | print("dynamic_map_states list") 751 | print([len(state.lane_states) for state in scenario.dynamic_map_states]) 752 | 753 | 754 | if __name__ == "__main__": 755 | # inference_valid_set() 756 | # inference_test_set() 757 | cal_dynamic_map_states(is_test=False) 758 | --------------------------------------------------------------------------------