├── run_main.sh
├── qualitative_results
│   ├── 0_ground_truth.gif
│   ├── 0_model_output.gif
│   ├── 1_ground_truth.gif
│   └── 1_model_output.gif
├── .gitignore
├── net_works
│   ├── __init__.py
│   ├── traj_decoder.py
│   ├── transformer.py
│   ├── back_bone.py
│   ├── attention.py
│   ├── scene_encoder.py
│   └── diffusion.py
├── common
│   ├── data_preprocess_config.py
│   ├── obs_type.py
│   ├── config_result.py
│   ├── data_train_model_config.py
│   ├── __init__.py
│   ├── data_config.py
│   ├── data.py
│   └── waymo_dataset.py
├── utils
│   ├── __init__.py
│   ├── visualize_utils.py
│   ├── math_utils.py
│   ├── map_utils.py
│   └── data_utils.py
├── tasks
│   ├── __init__.py
│   ├── base_task.py
│   ├── data_split_task.py
│   ├── data_count_task.py
│   ├── data_preprocess_task.py
│   ├── train_model_task.py
│   ├── load_config_task.py
│   └── show_result_task.py
├── config.yaml
├── main.py
├── environment.yml
├── README.md
├── LICENSE
└── gene_submission.py
/run_main.sh:
--------------------------------------------------------------------------------
1 | nohup python3 -u main.py >> main.log &
--------------------------------------------------------------------------------
/qualitative_results/0_ground_truth.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/0_ground_truth.gif
--------------------------------------------------------------------------------
/qualitative_results/0_model_output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/0_model_output.gif
--------------------------------------------------------------------------------
/qualitative_results/1_ground_truth.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/1_ground_truth.gif
--------------------------------------------------------------------------------
/qualitative_results/1_model_output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangchen1997/WcDT/HEAD/qualitative_results/1_model_output.gif
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | result_data/*
2 | */__pycache__
3 | .idea/*
4 | *.png
5 | *.pkl
6 | waymo_open_dataset_/*
7 | test_set/*
8 | valid_set/*
9 | *.pth
10 | output/*
11 | data_set/*
12 | data_output/*
--------------------------------------------------------------------------------
/net_works/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: __init__.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | from net_works.back_bone import BackBone
10 |
11 | BackBone = BackBone
12 |
--------------------------------------------------------------------------------
/common/data_preprocess_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_preprocess_config.py
6 | @Author: YangChen
7 | @Date: 2023/12/24
8 | """
9 | from common.data import BaseConfig
10 |
11 |
12 | class DataPreprocessConfig(BaseConfig):
13 | data_size: int = 0
14 | max_data_size: int = 0
15 | num_works: int = 1
16 |
17 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: __init__.py
6 | @Author: YangChen
7 | @Date: 2023/12/21
8 | """
9 | from utils.map_utils import MapUtil
10 | from utils.data_utils import DataUtil
11 | from utils.math_utils import MathUtil
12 | from utils.visualize_utils import VisualizeUtil
13 |
14 | MapUtil = MapUtil
15 | DataUtil = DataUtil
16 | MathUtil = MathUtil
17 | VisualizeUtil = VisualizeUtil
18 |
--------------------------------------------------------------------------------
/common/obs_type.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: obs_type.py
6 | @Author: YangChen
7 | @Date: 2023/12/24
8 | """
9 |
10 | ObjectType = {
11 | 0: "TYPE_UNSET",
12 | 1: "TYPE_VEHICLE",
13 | 2: "TYPE_PEDESTRIAN",
14 | 3: "TYPE_CYCLIST",
15 | 4: "TYPE_OTHER"
16 | }
17 | MapState = {
18 | 0: "LANE_STATE_UNKNOWN",
19 | 1: "LANE_STATE_ARROW_STOP",
20 | 2: "LANE_STATE_ARROW_CAUTION",
21 | 3: "LANE_STATE_ARROW_GO",
22 | 4: "LANE_STATE_STOP",
23 | 5: "LANE_STATE_CAUTION",
24 | 6: "LANE_STATE_GO",
25 | 7: "LANE_STATE_FLASHING_STOP",
26 | 8: "LANE_STATE_FLASHING_CAUTION"
27 | }
28 |
--------------------------------------------------------------------------------
/common/config_result.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: config_result.py
6 | @Author: YangChen
7 | @Date: 2023/12/21
8 | """
9 | from dataclasses import dataclass
10 |
11 | from common.data import TaskLogger
12 | from common.data_config import TaskConfig
13 | from common.data_preprocess_config import DataPreprocessConfig
14 | from common.data_train_model_config import TrainModelConfig
15 |
16 |
17 | @dataclass
18 | class LoadConfigResultDate:
19 | task_config: TaskConfig = None
20 | data_preprocess_config: DataPreprocessConfig = None
21 | train_model_config: TrainModelConfig = None
22 | task_id: str = ""
23 | task_logger: TaskLogger = None
24 |
--------------------------------------------------------------------------------
/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: __init__.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 |
10 | from tasks.base_task import BaseTask
11 | from tasks.load_config_task import LoadConfigTask
12 | from tasks.data_preprocess_task import DataPreprocessTask
13 | from tasks.data_split_task import DataSplitTask
14 | from tasks.data_count_task import DataCountTask
15 | from tasks.train_model_task import TrainModelTask
16 | from tasks.show_result_task import ShowResultsTask
17 |
18 | BaseTask = BaseTask
19 | LoadConfigTask = LoadConfigTask
20 | DataPreprocessTask = DataPreprocessTask
21 | DataSplitTask = DataSplitTask
22 | TrainModelTask = TrainModelTask
23 | ShowResultsTask = ShowResultsTask
24 |
--------------------------------------------------------------------------------
/common/data_train_model_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_train_model_config.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | from typing import List
10 |
11 | from common.data import BaseConfig
12 |
13 |
14 | class TrainModelConfig(BaseConfig):
15 | use_gpu: bool = False
16 | gpu_ids: List[int] = None
17 | batch_size: int = 0
18 | num_works: int = 20
19 | his_step: int = 11
20 | max_pred_num: int = 8
21 | max_other_num: int = 6
22 | max_traffic_light: int = 8
23 | max_lane_num: int = 32
24 | max_point_num: int = 128
25 | num_head: int = 8
26 | attention_dim: int = 128
27 | multimodal: int = 10
28 | time_steps: int = 100
29 | schedule: str = "cosine"
30 | num_epoch: int = 0
31 | init_lr: float = 0.00001
32 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: __init__.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | from common.data import TaskType, TaskLogger
10 | from common.data_config import TaskConfig
11 | from common.config_result import LoadConfigResultDate
12 | from common.data_train_model_config import TrainModelConfig
13 | from common.obs_type import ObjectType, MapState
14 | from common.data_preprocess_config import DataPreprocessConfig
15 | from common.waymo_dataset import WaymoDataset
16 |
17 |
18 | TaskType = TaskType
19 | TaskLogger = TaskLogger
20 | TaskConfig = TaskConfig
21 | LoadConfigResultDate = LoadConfigResultDate
22 | ObjectType = ObjectType
23 | MapState = MapState
24 | DataPreprocessConfig = DataPreprocessConfig
25 | TrainModelConfig = TrainModelConfig
26 | WaymoDataset = WaymoDataset
--------------------------------------------------------------------------------
/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: base_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | import os
10 | import shutil
11 | from abc import abstractmethod, ABC
12 |
13 | from common import TaskType, LoadConfigResultDate
14 |
15 |
16 | class BaseTask(ABC):
17 | def __init__(self):
18 | super(BaseTask, self).__init__()
19 | self.task_type = TaskType.UNKNOWN
20 |
21 | @abstractmethod
22 | def execute(self, result_info: LoadConfigResultDate):
23 | pass
24 |
25 | @staticmethod
26 | def check_dir_exist(input_dir: str):
27 | if not os.path.exists(input_dir) or \
28 | not os.path.isdir(input_dir) or \
29 | len(os.listdir(input_dir)) <= 0:
30 | raise FileNotFoundError(f"data_preprocess_dir error: {input_dir}, dir is None")
31 |
32 | @staticmethod
33 | def rebuild_dir(input_dir: str):
34 | if os.path.exists(input_dir):
35 | shutil.rmtree(input_dir)
36 | os.makedirs(input_dir, exist_ok=True)
37 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | task_config:
2 | task_list:
3 | # - "DATA_PREPROCESS"
4 | # - "DATA_SPLIT"
5 | # - "DATA_COUNT"
6 | - "TRAIN_MODEL"
7 | # - "SHOW_RESULTS"
8 | # - "EVAL_MODEL"
9 | # - "GENE_SUBMISSION"
10 | # root directory for all run-time outputs
11 | output_dir: "output"
12 | log_dir: "log"
13 | image_dir: "result_image"
14 | model_dir: "model"
15 | result_dir: "result"
16 | pre_train_model: ""
17 | waymo_train_dir: ""
18 | waymo_val_dir: ""
19 | waymo_test_dir: ""
20 | # root directory for data preprocessing and packaging outputs
21 | data_output: "data_output"
22 | data_preprocess_dir: "data_preprocess_dir"
23 | train_dir: "train_dir"
24 | val_dir: "val_dir"
25 | test_dir: "test_dir"
26 |
27 | data_preprocess_config:
28 | data_size: 100
29 | max_data_size: 2000
30 | num_works: 20
31 |
32 |
33 | train_model_config:
34 | use_gpu: False
35 | gpu_ids:
36 | - 6
37 | - 7
38 | batch_size: 4
39 | num_works: 0
40 | his_step: 11
41 | max_pred_num: 8
42 | max_other_num: 6
43 | max_traffic_light: 8
44 | max_lane_num: 32
45 | max_point_num: 128
46 | num_head: 8
47 | attention_dim: 128
48 | multimodal: 10
49 | time_steps: 50
50 | # cosine or linear
51 | schedule: "linear"
52 | num_epoch: 200
53 | init_lr: 0.0001
54 |
55 |
56 |
--------------------------------------------------------------------------------
/common/data_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_config.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | from typing import List
10 |
11 | from common.data import TaskType, BaseConfig
12 |
13 |
14 | class TaskConfig(BaseConfig):
15 | __task_list: List[TaskType] = None
16 | # root directory for all outputs
17 | output_dir: str = ""
18 | # directory where logs are saved
19 | log_dir: str = ""
20 | model_dir: str = ""
21 | result_dir: str = ""
22 | pre_train_model: str = ""
23 | # directories of the raw Waymo data
24 | waymo_train_dir: str = ""
25 | waymo_val_dir: str = ""
26 | waymo_test_dir: str = ""
27 | # directory where images produced during training are saved
28 | image_dir: str = ""
29 | # data output root
30 | data_output: str = ""
31 | # directory where preprocessed data is saved
32 | data_preprocess_dir: str = ""
33 | # directories of the built train/val/test sets
34 | train_dir: str = ""
35 | val_dir: str = ""
36 | test_dir: str = ""
37 |
38 | @property
39 | def task_list(self) -> List[TaskType]:
40 | return self.__task_list
41 |
42 | @task_list.setter
43 | def task_list(self, task_list: List[str]):
44 | self.__task_list = [TaskType(task_name) for task_name in task_list]
45 |
46 | def check_config(self):
47 | """
48 | Validate the values loaded from the config file
49 | @return:
50 | """
51 | if not self.task_list:
52 | raise Warning("task_list is None")
53 |
--------------------------------------------------------------------------------
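A small usage sketch of `TaskConfig` (the task names below are a hypothetical subset of the entries accepted under `task_config.task_list` in `config.yaml`; string-to-`TaskType` conversion happens in the `task_list` setter above):

```python
from common.data_config import TaskConfig

cfg = TaskConfig()
cfg.task_list = ["DATA_PREPROCESS", "TRAIN_MODEL"]  # setter converts the strings to TaskType members
print([t.value for t in cfg.task_list])             # ['DATA_PREPROCESS', 'TRAIN_MODEL']
cfg.check_config()                                  # raises a Warning only when the list is empty or missing
```
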
/net_works/traj_decoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: traj_decoder.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | import torch
10 | from torch import nn
11 |
12 |
13 | class TrajDecoder(nn.Module):
14 | def __init__(self, multimodal: int = 10, dim: int = 256, future_step: int = 80):
15 | super(TrajDecoder, self).__init__()
16 | self.multimodal = multimodal
17 | self.future_step = future_step
18 | self.one_modal = (future_step * 3 + 1)
19 | output_dim = multimodal * self.one_modal
20 | self.decoder = nn.Sequential(
21 | nn.Linear(dim, dim * 2),
22 | nn.LayerNorm(dim * 2),
23 | nn.ReLU(inplace=True),
24 | nn.Linear(dim * 2, dim),
25 | nn.LayerNorm(dim),
26 | nn.ReLU(inplace=True),
27 | nn.Linear(dim, dim),
28 | nn.LayerNorm(dim),
29 | nn.ReLU(inplace=True),
30 | nn.Linear(dim, output_dim)
31 | )
32 |
33 | def forward(self, input_x):
34 | batch_size, obs_num = input_x.shape[0], input_x.shape[1]
35 | decoder_output = self.decoder(input_x)
36 | decoder_output = decoder_output.view(batch_size, obs_num, self.multimodal, self.one_modal)
37 | traj = decoder_output[:, :, :, :-1].view(batch_size, obs_num, self.multimodal, self.future_step, 3)
38 | confidence = decoder_output[:, :, :, -1]
39 | return traj, confidence
40 |
--------------------------------------------------------------------------------
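A quick shape check for the decoder above, using hypothetical batch sizes; the defaults `multimodal=10`, `dim=256`, `future_step=80` come from the constructor signature:

```python
import torch

from net_works.traj_decoder import TrajDecoder

# Hypothetical batch: 4 scenes, 8 predicted agents, each encoded as a 256-d scene feature.
decoder = TrajDecoder(multimodal=10, dim=256, future_step=80)
scene_feature = torch.randn(4, 8, 256)
traj, confidence = decoder(scene_feature)
print(traj.shape)        # torch.Size([4, 8, 10, 80, 3]) -- per-step (dx, dy, dheading) for each modality
print(confidence.shape)  # torch.Size([4, 8, 10])        -- one logit per modality
```
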
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: main.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | from typing import List
10 |
11 | from common import TaskType, LoadConfigResultDate, TaskLogger
12 | from tasks import BaseTask
13 | from tasks import (LoadConfigTask, DataPreprocessTask, DataSplitTask,
14 | DataCountTask, TrainModelTask, ShowResultsTask)
15 |
16 |
17 | class TaskFactory:
18 |
19 | @staticmethod
20 | def init_config() -> LoadConfigResultDate:
21 | result_info = LoadConfigResultDate()
22 | load_config_task = LoadConfigTask()
23 | load_config_task.execute(result_info)
24 | return result_info
25 |
26 | @staticmethod
27 | def init_tasks(task_type_list: List[TaskType]) -> List[BaseTask]:
28 | task_list = list()
29 | for task_type in task_type_list:
30 | if task_type == TaskType.DATA_PREPROCESS:
31 | task_list.append(DataPreprocessTask())
32 | elif task_type == TaskType.DATA_SPLIT:
33 | task_list.append(DataSplitTask())
34 | elif task_type == TaskType.DATA_COUNT:
35 | task_list.append(DataCountTask())
36 | elif task_type == TaskType.TRAIN_MODEL:
37 | task_list.append(TrainModelTask())
38 | elif task_type == TaskType.SHOW_RESULTS:
39 | task_list.append(ShowResultsTask())
40 | return task_list
41 |
42 |
43 | def execute_tasks():
44 | load_config_result = TaskFactory.init_config()
45 | task_list = TaskFactory.init_tasks(load_config_result.task_config.task_list)
46 | task_logger: TaskLogger = load_config_result.task_logger
47 | for task in task_list:
48 | task_logger.logger.info(f"task type {task.task_type.value} start")
49 | task.execute(load_config_result)
50 | task_logger.logger.info(f"task type {task.task_type.value} success")
51 |
52 |
53 | if __name__ == "__main__":
54 | execute_tasks()
55 |
--------------------------------------------------------------------------------
/utils/visualize_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: visualize_utils.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | from typing import Any, Dict
10 |
11 | import torch
12 | from matplotlib import pyplot as plt
13 |
14 |
15 | class VisualizeUtil:
16 | @staticmethod
17 | def show_result(image_path: str, min_loss_traj: torch.Tensor, data: Dict):
18 | min_loss_traj = min_loss_traj[0]
19 | predicted_traj_mask = data['predicted_traj_mask'][0]
20 | predicted_future_traj = data['predicted_future_traj'][0]
21 | predicted_his_traj = data['predicted_his_traj'][0]
22 | predicted_num = 0
23 | for i in range(predicted_traj_mask.shape[0]):
24 | if int(predicted_traj_mask[i]) == 1:
25 | predicted_num += 1
26 | generate_traj = min_loss_traj[:predicted_num]
27 | predicted_future_traj = predicted_future_traj[:predicted_num]
28 | predicted_his_traj = predicted_his_traj[:predicted_num]
29 | map_feature_list = eval(data['map_json'][0])
30 | real_traj = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, :2].detach().numpy()
31 | model_output = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, :2].detach().numpy()
32 | fig, ax = plt.subplots(1, 2, figsize=(10, 5))
33 | # draw the map
34 | for map_feature in map_feature_list:
35 | x_list = [float(point[0]) for point in map_feature]
36 | y_list = [float(point[1]) for point in map_feature]
37 | ax[0].plot(x_list, y_list, color="grey")
38 | ax[1].plot(x_list, y_list, color="grey")
39 | # draw the ground truth and the model output
40 | for i in range(predicted_num):
41 | ax[0].plot(real_traj[i, :, 0], real_traj[i, :, 1])
42 | ax[1].plot(model_output[i, :, 0], model_output[i, :, 1])
43 |
44 | # label = 'Epoch {0}'.format(num_epoch)
45 | # plt.show()
46 | # fig.text(0.5, 0.04, label, ha='center')
47 | plt.savefig(image_path)
48 | plt.close('all') # avoid memory leaks
49 | print("save_image_success")
50 |
--------------------------------------------------------------------------------
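A toy call to `show_result`, sketching the dictionary keys it reads (all sizes and values here are hypothetical; in training the dict comes from the data loader):

```python
import torch

from utils.visualize_utils import VisualizeUtil

his_step, future_step = 11, 80
data = {
    "predicted_traj_mask": torch.ones(1, 2),                     # two valid agents in scene 0
    "predicted_his_traj": torch.zeros(1, 2, his_step, 3),        # (x, y, heading) history
    "predicted_future_traj": torch.ones(1, 2, future_step, 3),   # ground-truth future
    "map_json": [str([[(0.0, 0.0), (10.0, 0.0)]])],              # one polyline with two points
}
min_loss_traj = torch.ones(1, 2, future_step, 3)                 # best-modality model output
VisualizeUtil.show_result("demo.png", min_loss_traj, data)       # writes the side-by-side figure
```
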
/utils/math_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: math_utils.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | import numpy as np
10 | import torch
11 |
12 |
13 | class MathUtil:
14 |
15 | @staticmethod
16 | def generate_linear_schedule(time_steps: int, low: float = 1e-5, high: float = 2e-5):
17 | return np.linspace(low, high, time_steps)
18 |
19 | @staticmethod
20 | def step_cos(time_step: int, time_steps: int, s: float):
21 | return (np.cos((time_step / time_steps + s) / (1 + s) * np.pi / 2)) ** 2
22 |
23 | # @classmethod
24 | # def generate_cosine_schedule(cls, time_steps: int, s: float = 0.008):
25 | # alphas = []
26 | # f0 = cls.step_cos(0, time_steps, s)
27 | # for t in range(time_steps + 1):
28 | # alphas.append(cls.step_cos(t, time_steps, s) / f0)
29 | # betas = []
30 | # for t in range(1, time_steps + 1):
31 | # betas.append(min(1 - alphas[t] / alphas[t - 1], 0.999))
32 | # return np.array(betas)
33 |
34 | @classmethod
35 | def generate_cosine_schedule(cls, time_steps: int, s: float = 0.008):
36 | steps = time_steps + 1
37 | x = torch.linspace(0, time_steps, steps)
38 | alphas_cum_prod = torch.cos(((x / time_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
39 | alphas_cum_prod = alphas_cum_prod / alphas_cum_prod[0]
40 | betas = 1 - (alphas_cum_prod[1:] / alphas_cum_prod[:-1])
41 | return torch.clip(betas, 0.0001, 0.9999).detach().numpy()
42 |
43 | @staticmethod
44 | def post_process_output(generate_traj: torch.Tensor, predicted_his_traj: torch.Tensor) -> torch.Tensor:
45 | delt_t = 0.1
46 | batch_size = generate_traj.shape[0]
47 | num_obs = generate_traj.shape[1]
48 | vx = generate_traj[:, :, :, :, 0] / delt_t
49 | vy = generate_traj[:, :, :, :, 1] / delt_t
50 | start_x = predicted_his_traj[:, :, -1, 0].view(batch_size, num_obs, 1, 1)
51 | start_y = predicted_his_traj[:, :, -1, 1].view(batch_size, num_obs, 1, 1)
52 | start_heading = predicted_his_traj[:, :, -1, 2].view(batch_size, num_obs, 1, 1)
53 | x = torch.cumsum(generate_traj[:, :, :, :, 0], dim=-1) + start_x
54 | y = torch.cumsum(generate_traj[:, :, :, :, 1], dim=-1) + start_y
55 | heading = torch.cumsum(generate_traj[:, :, :, :, 2], dim=-1) + start_heading
56 | output = torch.stack((x, y, heading, vx, vy), dim=-1)
57 | return output
58 |
--------------------------------------------------------------------------------
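A brief sketch comparing the two beta schedules above (assumed `time_steps=50`, matching the `time_steps` field in `config.yaml`):

```python
import numpy as np

from utils.math_utils import MathUtil

time_steps = 50
betas_linear = MathUtil.generate_linear_schedule(time_steps)  # shape (50,), linearly spaced in [1e-5, 2e-5]
betas_cosine = MathUtil.generate_cosine_schedule(time_steps)  # shape (50,), clipped to [1e-4, 0.9999]

# The cumulative product of alpha_t = 1 - beta_t tells how much of the clean signal
# survives after t noising steps of the diffusion forward process.
print(np.cumprod(1.0 - betas_linear)[-1])
print(np.cumprod(1.0 - betas_cosine)[-1])
```
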
/common/data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | import logging
10 | from dataclasses import dataclass
11 | from enum import Enum, unique
12 | from typing import Any
13 |
14 |
15 | def object_dict_print(obj: Any) -> str:
16 | """
17 | Print all attributes of the given object
18 | @param obj:
19 | @return:
20 | """
21 | result_str = ""
22 | for key, value in obj.__dict__.items():
23 | if isinstance(value, list) and len(value) > 0:
24 | list_str = f"{key}:\n\t["
25 | for row in value:
26 | list_str += f"{str(row)}, "
27 | list_str = list_str[:-2]
28 | list_str += "]\n"
29 | result_str += list_str
30 | else:
31 | result_str += f"{key}: {value} \n"
32 | return result_str
33 |
34 |
35 | class TaskType(Enum):
36 | LOAD_CONFIG = "LOAD_CONFIG"
37 | DATA_PREPROCESS = "DATA_PREPROCESS"
38 | DATA_SPLIT = "DATA_SPLIT"
39 | DATA_COUNT = "DATA_COUNT"
40 | TRAIN_MODEL = "TRAIN_MODEL"
41 | SHOW_RESULTS = "SHOW_RESULTS"
42 | EVAL_MODEL = "EVAL_MODEL"
43 | GENE_SUBMISSION = "GENE_SUBMISSION"
44 | UNKNOWN = "UNKNOWN"
45 |
46 | def __str__(self):
47 | return self.value
48 |
49 |
50 | @dataclass
51 | class BaseConfig(object):
52 |
53 | def __str__(self) -> str:
54 | return object_dict_print(self)
55 |
56 |
57 | class TaskLogger(object):
58 | """
59 | Logs task output to the console and to a file
60 | Args:
61 | log_path: path of the log file
62 | """
63 |
64 | def __init__(self, log_path: str):
65 | super(TaskLogger, self).__init__()
66 | # create the logger
67 | self.logger = logging.getLogger("logger")
68 |
69 | # set the minimum log level; records below it are ignored
70 | self.logger.setLevel(logging.INFO)
71 |
72 | # create handlers: sh for the console, fh for the log file
73 | sh = logging.StreamHandler()
74 | fh = logging.FileHandler(log_path, encoding="UTF-8", mode='w')
75 |
76 | # create the formatter and attach it to sh and fh
77 | format_str = "%(asctime)s -%(name)s -%(levelname)-8s -%(filename)s(line: %(lineno)s): %(message)s"
78 | formatter = logging.Formatter(fmt=format_str, datefmt="%Y/%m/%d %X")
79 | sh.setFormatter(formatter)
80 | fh.setFormatter(formatter)
81 |
82 | # add both handlers to the logger
83 | self.logger.addHandler(sh)
84 | self.logger.addHandler(fh)
85 |
86 | def get_logger(self) -> logging.Logger:
87 | return self.logger
88 |
89 |
90 | if __name__ == "__main__":
91 | pass
92 |
--------------------------------------------------------------------------------
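A minimal usage sketch of `TaskLogger` and `BaseConfig` (hypothetical file name and values; `LoadConfigTask` normally builds these from `config.yaml`):

```python
from common.data import BaseConfig, TaskLogger, TaskType

logger = TaskLogger("demo.log").get_logger()         # logs to the console and to demo.log
logger.info(f"running task {TaskType.TRAIN_MODEL}")  # the enum prints its string value

class DemoConfig(BaseConfig):
    batch_size: int = 4

cfg = DemoConfig()
cfg.batch_size = 8   # object_dict_print only walks instance attributes
print(cfg)           # -> "batch_size: 8"
```
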
/tasks/data_split_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_split_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/24
8 | """
9 | import os
10 | import pickle
11 | from multiprocessing import Pool
12 | from typing import Any, Dict
13 |
14 | from tqdm import tqdm
15 |
16 | from common import TaskType, LoadConfigResultDate
17 | from tasks import BaseTask
18 | from utils import DataUtil
19 |
20 |
21 | class DataSplitTask(BaseTask):
22 | def __init__(self):
23 | super(DataSplitTask, self).__init__()
24 | self.task_type = TaskType.DATA_SPLIT
25 |
26 | def execute(self, result_info: LoadConfigResultDate):
27 | logger = result_info.task_logger.get_logger()
28 | train_dir = result_info.task_config.train_dir
29 | self.rebuild_dir(train_dir)
30 | data_preprocess_dir = result_info.task_config.data_preprocess_dir
31 | self.check_dir_exist(data_preprocess_dir)
32 | pkl_list = sorted(os.listdir(data_preprocess_dir),
33 | key=lambda x: int(x[:-4].split('_')[-1]))
34 | data_set_num = 0
35 | his_step = result_info.train_model_config.his_step
36 | for pkl_path in pkl_list:
37 | pool = Pool(result_info.data_preprocess_config.num_works)
38 | with open(os.path.join(data_preprocess_dir, pkl_path), "rb") as f:
39 | pkl_obj = pickle.load(f)
40 | process_list = list()
41 | for one_pkl_dict in pkl_obj:
42 | data_set_num += 1
43 | one_pkl_path = os.path.join(train_dir, f"dataset_{data_set_num}.pkl")
44 | process_list.append(
45 | pool.apply_async(
46 | self.save_split_data,
47 | kwds=dict(
48 | one_pkl_dict=one_pkl_dict,
49 | his_step=his_step,
50 | one_pkl_path=one_pkl_path
51 | )
52 | )
53 | )
54 | for process in tqdm(process_list):
55 | try:
56 | process.get()
57 | except Exception as e:
58 | logger.error(e)
59 | finally:
60 | continue
61 | pool.close()
62 |
63 | @staticmethod
64 | def save_split_data(one_pkl_dict: Dict[str, Any], his_step: int, one_pkl_path: str):
65 | pkl_data = DataUtil.split_pkl_data(one_pkl_dict, his_step)
66 | if pkl_data and len(pkl_data) > 0:
67 | with open(one_pkl_path, "wb") as f:
68 | pickle.dump(pkl_data, f)
69 |
--------------------------------------------------------------------------------
/utils/map_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: map_utils.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | import math
10 | from typing import Tuple
11 |
12 | import numpy as np
13 |
14 |
15 | class MapUtil:
16 | @staticmethod
17 | def local_to_global(ego_heading: float, position_x: np.ndarray,
18 | position_y: np.ndarray, ego_local_x: float,
19 | ego_local_y: float) -> Tuple[np.ndarray, np.ndarray]:
20 | """
21 | Convert coordinates from the ego frame to the world frame
22 | @param ego_heading: heading of the ego vehicle in the world frame
23 | @param position_x: x coordinates to convert
24 | @param position_y: y coordinates to convert
25 | @param ego_local_x: x position of the ego vehicle in the world frame
26 | @param ego_local_y: y position of the ego vehicle in the world frame
27 | @return: x coordinates in the world frame, y coordinates in the world frame
28 | """
29 | yaw = ego_heading
30 | global_x = ego_local_x + position_x * math.cos(yaw) - position_y * math.sin(yaw)
31 | global_y = ego_local_y + position_x * math.sin(yaw) + position_y * math.cos(yaw)
32 | return global_x, global_y
33 |
34 | @classmethod
35 | def theta_local_to_global(cls, ego_heading: float, heading: np.ndarray) -> np.ndarray:
36 | """
37 | Convert headings from the ego frame to the world frame
38 | @param ego_heading: heading of the ego vehicle in the world frame
39 | @param heading: headings to convert
40 | @return: headings in the world frame
41 | """
42 | heading = heading.tolist()
43 | heading_list = []
44 | for one_heading in heading:
45 | heading_list.append(cls.normalize_angle(ego_heading + one_heading))
46 | return np.array(heading_list)
47 |
48 | @staticmethod
49 | def global_to_local(curr_x: float, curr_y: float, curr_heading: float,
50 | point_x: float, point_y: float) -> Tuple[float, float]:
51 | """
52 | Convert a point from the world frame to the ego frame
53 | @param curr_x:
54 | @param curr_y:
55 | @param curr_heading:
56 | @param point_x:
57 | @param point_y:
58 | @return:
59 | """
60 | delta_x = point_x - curr_x
61 | delta_y = point_y - curr_y
62 | return delta_x * math.cos(curr_heading) + delta_y * math.sin(curr_heading),\
63 | delta_y * math.cos(curr_heading) - delta_x * math.sin(curr_heading)
64 |
65 | @classmethod
66 | def theta_global_to_local(cls, curr_heading: float, heading: float) -> float:
67 | """
68 | Convert a heading from the world frame to the ego frame
69 | @param curr_heading: heading of the ego vehicle in the world frame
70 | @param heading: heading to convert
71 | @return: heading in the ego frame
72 | """
73 | return cls.normalize_angle(-curr_heading + heading)
74 |
75 | @staticmethod
76 | def normalize_angle(angle: float) -> float:
77 | """
78 | Normalize an angle in radians to the range [-pi, pi]
79 | @param angle: input angle in radians
80 | @return: normalized angle
81 | """
82 | angle = (angle + math.pi) % (2 * math.pi)
83 | if angle < .0:
84 | angle += 2 * math.pi
85 | return angle - math.pi
86 |
--------------------------------------------------------------------------------
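A round-trip sanity check of the frame conversions above (the ego pose and offsets are hypothetical):

```python
import numpy as np

from utils.map_utils import MapUtil

# Hypothetical ego pose and a point 10 m ahead, 2 m to the left in the ego frame.
ego_x, ego_y, ego_heading = 100.0, 50.0, np.pi / 4
gx, gy = MapUtil.local_to_global(ego_heading, np.array([10.0]), np.array([2.0]), ego_x, ego_y)
lx, ly = MapUtil.global_to_local(ego_x, ego_y, ego_heading, float(gx[0]), float(gy[0]))
print(round(lx, 6), round(ly, 6))  # 10.0 2.0 -- global_to_local inverts local_to_global
```
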
/tasks/data_count_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_count_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/26
8 | """
9 | import glob
10 | import os
11 | import pickle
12 | import shutil
13 |
14 | from matplotlib import pyplot as plt
15 | from sklearn.decomposition import PCA
16 | from sklearn.cluster import KMeans
17 | from sklearn.preprocessing import MinMaxScaler
18 |
19 | import numpy as np
20 |
21 | from common import TaskType, LoadConfigResultDate
22 | from tasks import BaseTask
23 |
24 |
25 | class DataCountTask(BaseTask):
26 |
27 | def __init__(self):
28 | super(DataCountTask, self).__init__()
29 | self.task_type = TaskType.DATA_COUNT
30 |
31 | def execute(self, result_info: LoadConfigResultDate):
32 | logger = result_info.task_logger.get_logger()
33 | result_dir = result_info.task_config.result_dir
34 | self.rebuild_dir(result_dir)
35 | pkl_dir = result_info.task_config.train_dir
36 | self.check_dir_exist(pkl_dir)
37 | pkl_path_list = glob.glob(os.path.join(pkl_dir, "*.pkl"))
38 | sum_obs = list()
39 | num_predicted_obs = list()
40 | num_other_obs = list()
41 | num_lanes = list()
42 | num_traffic_light = list()
43 | for pkl_path in pkl_path_list:
44 | with open(pkl_path, "rb") as f:
45 | pkl_obj = pickle.load(f)
46 | sum_obs.append(len(pkl_obj[1]))
47 | num_predicted_obs.append(len(pkl_obj[0]))
48 | num_other_obs.append(len(pkl_obj[1]) - len(pkl_obj[0]))
49 | num_lanes.append(len(pkl_obj[4]))
50 | num_traffic_light.append(len(pkl_obj[5]))
51 | logger.info("data count success")
52 | result = {
53 | "sum_obs": sum_obs,
54 | "num_predicted_obs": num_predicted_obs,
55 | "num_other_obs": num_other_obs,
56 | "num_lanes": num_lanes,
57 | "num_traffic_light": num_traffic_light
58 | }
59 | with open(os.path.join(result_dir, f"{result_info.task_id}.pkl"), "wb") as f:
60 | pickle.dump(result, f)
61 | sum_obs = np.array(sum_obs)
62 | num_predicted_obs = np.array(num_predicted_obs)
63 | num_other_obs = np.array(num_other_obs)
64 | num_lanes = np.array(num_lanes)
65 | num_traffic_light = np.array(num_traffic_light)
66 | #
67 | feature_array = np.stack(
68 | (sum_obs, num_predicted_obs, num_other_obs, num_lanes, num_traffic_light),
69 | axis=-1
70 | )
71 | scaler = MinMaxScaler()
72 | # fit the scaler and transform the data
73 | feature_array = scaler.fit_transform(feature_array)
74 | # create the PCA model
75 | pca = PCA(n_components=2)
76 | feature_array = pca.fit_transform(feature_array)
77 | # clustering
78 | kmeans = KMeans(n_clusters=4, random_state=5)
79 | kmeans.fit(feature_array)
80 | y_pred = kmeans.predict(feature_array)
81 | plt.scatter(feature_array[:, 0], feature_array[:, 1], c=y_pred, s=5, cmap='viridis')
82 | plt.show()
83 |
84 |
--------------------------------------------------------------------------------
/net_works/transformer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: transformer.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | from torch import nn
10 |
11 | from net_works.attention import MultiHeadSelfAttention, MultiHeadCrossAttention
12 |
13 |
14 | class FeedForward(nn.Module):
15 | def __init__(self, dim, hidden_dim=256):
16 | super().__init__()
17 | self.net = nn.Sequential(
18 | nn.LayerNorm(dim),
19 | nn.Linear(dim, hidden_dim),
20 | nn.GELU(),
21 | nn.Linear(hidden_dim, dim),
22 | )
23 |
24 | def forward(self, x):
25 | return self.net(x)
26 |
27 |
28 | class TransformerCrossAttention(nn.Module):
29 |
30 | def __init__(
31 | self, input_dim: int, conditional_dim: int,
32 | head_dim: int = 64, num_heads: int = 8, drop: float = 0.1,
33 | ):
34 | super(TransformerCrossAttention, self).__init__()
35 | self.self_attention = MultiHeadSelfAttention(input_dim, head_dim, num_heads)
36 | self.cross_attention = MultiHeadCrossAttention(input_dim, conditional_dim, head_dim, num_heads)
37 | self.feed_forward = FeedForward(input_dim)
38 | self.drop_self_attention = nn.Dropout(drop)
39 | self.drop_cross_attention = nn.Dropout(drop)
40 | self.norm_self_attention = nn.LayerNorm(input_dim)
41 | self.norm_cross_attention = nn.LayerNorm(input_dim)
42 | self.norm_feed_forward = nn.LayerNorm(input_dim)
43 |
44 | def forward(self, input_x, input_conditional):
45 | # self attention
46 | norm_input_x = self.norm_self_attention(input_x)
47 | self_attention_output = self.self_attention(norm_input_x)
48 | res_output = input_x + self_attention_output
49 | res_output = self.drop_self_attention(res_output)
50 | # cross attention
51 | norm_input_x = self.norm_cross_attention(res_output)
52 | cross_attention_output = self.cross_attention(norm_input_x, input_conditional)
53 | res_output = res_output + cross_attention_output
54 | res_output = self.drop_cross_attention(res_output)
55 | # feed_forward
56 | norm_input_x = self.norm_feed_forward(res_output)
57 | feed_forward_output = self.feed_forward(norm_input_x)
58 | norm_input_x = norm_input_x + feed_forward_output
59 | return norm_input_x
60 |
61 |
62 | class TransformerSelfAttention(nn.Module):
63 |
64 | def __init__(self, input_dim: int, head_dim: int = 64, num_heads: int = 8, drop: float = 0.1):
65 | super(TransformerSelfAttention, self).__init__()
66 | self.self_attention = MultiHeadSelfAttention(input_dim, head_dim, num_heads)
67 | self.feed_forward = FeedForward(input_dim)
68 | self.drop_self_attention = nn.Dropout(drop)
69 | self.norm_self_attention = nn.LayerNorm(input_dim)
70 | self.norm_feed_forward = nn.LayerNorm(input_dim)
71 |
72 | def forward(self, input_x):
73 | # self attention
74 | norm_input_x = self.norm_self_attention(input_x)
75 | self_attention_output = self.self_attention(norm_input_x)
76 | res_output = input_x + self_attention_output
77 | res_output = self.drop_self_attention(res_output)
78 | # feed_forward
79 | norm_input_x = self.norm_feed_forward(res_output)
80 | feed_forward_output = self.feed_forward(norm_input_x)
81 | res_output = res_output + feed_forward_output
82 | return res_output
83 |
--------------------------------------------------------------------------------
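A small shape sketch of the pre-norm residual blocks above (hypothetical token counts; `conditional_dim=32` mirrors the map-point embedding width used in `SceneEncoder`):

```python
import torch

from net_works.transformer import TransformerCrossAttention, TransformerSelfAttention

cross_block = TransformerCrossAttention(input_dim=256, conditional_dim=32, head_dim=64, num_heads=8)
fusion_block = TransformerSelfAttention(input_dim=256)

agent_tokens = torch.randn(2, 8, 256)  # 2 scenes, 8 agent queries
map_tokens = torch.randn(2, 80, 32)    # 80 conditioning tokens per scene
out = fusion_block(cross_block(agent_tokens, map_tokens))
print(out.shape)  # torch.Size([2, 8, 256]) -- the residual blocks preserve the query shape
```
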
/tasks/data_preprocess_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_preprocess_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/23
8 | """
9 | import os
10 | import pickle
11 | import shutil
12 |
13 | import tensorflow as tf
14 | import tqdm
15 | from waymo_open_dataset.protos import scenario_pb2
16 |
17 | from common import LoadConfigResultDate, TaskType
18 | from tasks.base_task import BaseTask
19 | from utils import DataUtil
20 |
21 |
22 | class DataPreprocessTask(BaseTask):
23 | def __init__(self):
24 | super(DataPreprocessTask, self).__init__()
25 | self.task_type = TaskType.DATA_PREPROCESS
26 |
27 | def execute(self, result_info: LoadConfigResultDate):
28 | # clear the old output directory
29 | data_preprocess_dir = result_info.task_config.data_preprocess_dir
30 | if os.path.exists(data_preprocess_dir):
31 | shutil.rmtree(data_preprocess_dir)
32 | os.makedirs(data_preprocess_dir, exist_ok=True)
33 | self.load_waymo_train_data(result_info)
34 |
35 | @staticmethod
36 | def check_waymo_dir(result_info: LoadConfigResultDate):
37 | # check the Waymo data directories
38 | task_config = result_info.task_config
39 | dataset_names = ["train_set", "val_set", "test_set"]
40 | dataset_dirs = [task_config.waymo_train_dir,
41 | task_config.waymo_val_dir,
42 | task_config.waymo_test_dir]
43 | for dataset_name, dataset_dir in zip(dataset_names, dataset_dirs):
44 | if not os.path.isdir(dataset_dir):
45 | error_info = f"waymo {dataset_name} error, {dataset_dir} is not a dir"
46 | result_info.task_logger.logger.error(error_info)
47 | raise ValueError(error_info)
48 | if len(os.listdir(dataset_dir)) == 0 and dataset_names == "train_set":
49 | warn_info = f"waymo {dataset_name} warn, {dataset_dir} size = 0"
50 | result_info.task_logger.logger.warn(warn_info)
51 |
52 | @classmethod
53 | def load_waymo_train_data(cls, result_info: LoadConfigResultDate):
54 | # read parameters
55 | data_size = result_info.data_preprocess_config.data_size
56 | max_data_size = result_info.data_preprocess_config.max_data_size
57 | waymo_train_dir = result_info.task_config.waymo_train_dir
58 | preprocess_dir = result_info.task_config.data_preprocess_dir
59 | # load the data
60 | file_names = os.path.join(waymo_train_dir, "*training.*")
61 | match_filenames = tf.io.matching_files(file_names)
62 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data")
63 | dataset_iterator = dataset.as_numpy_iterator()
64 | bar = tqdm.tqdm(dataset_iterator, desc="load waymo train data: ")
65 | all_data = list()
66 | result_info.task_logger.logger.info("load_waymo_train_data start")
67 | number = 0
68 | for index, data in enumerate(bar):
69 | scenario = scenario_pb2.Scenario.FromString(data)
70 | data_dict = DataUtil.load_scenario_data(scenario)
71 | if len(data_dict) == 0:
72 | result_info.task_logger.logger.warn(f"scenario: {index} obs track is none")
73 | continue
74 | all_data.append(data_dict)
75 | number += 1
76 | if number % data_size == 0:
77 | file_name = f"result_{number}.pkl"
78 | with open(os.path.join(preprocess_dir, file_name), 'wb') as file:
79 | pickle.dump(all_data, file)
80 | all_data.clear()
81 | result_info.task_logger.logger.warn(f"file: {file_name} save success")
82 | if number > max_data_size:
83 | break
84 | result_info.task_logger.logger.info("load_waymo_train_data success")
85 |
--------------------------------------------------------------------------------
/net_works/back_bone.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: back_bone.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | from typing import Dict
10 |
11 | import numpy as np
12 | import torch
13 | from torch import nn
14 |
15 | from net_works.diffusion import GaussianDiffusion
16 | from net_works.scene_encoder import SceneEncoder
17 | from net_works.traj_decoder import TrajDecoder
18 | from utils import MathUtil
19 |
20 |
21 | class MultiModalLoss(nn.Module):
22 | def __init__(self):
23 | super(MultiModalLoss, self).__init__()
24 | self.huber_loss = nn.HuberLoss(reduction="none")
25 | self.confidence_loss = nn.CrossEntropyLoss(reduction="none")
26 |
27 | def forward(self, traj, confidence, predicted_future_traj, predicted_traj_mask):
28 | batch = traj.shape[0]
29 | obs_num = traj.shape[1]
30 | multimodal = traj.shape[2]
31 | future_step = traj.shape[3]
32 | output_dim = traj.shape[4]
33 | predicted_future_extend = predicted_future_traj.unsqueeze(dim=-3)
34 | loss = self.huber_loss(traj, predicted_future_extend)
35 | loss = loss.view(batch, obs_num, multimodal, -1)
36 | loss = torch.mean(loss, dim=-1)
37 | min_loss, _ = torch.min(loss, dim=-1)
38 | min_loss_modal = torch.argmin(loss, dim=-1)
39 | min_loss_index = min_loss_modal.view(batch, obs_num, 1, 1, 1).repeat(
40 | (1, 1, 1, future_step, output_dim)
41 | )
42 | min_loss_traj = torch.gather(traj, -3, min_loss_index).squeeze()
43 | confidence = confidence.view(-1, multimodal)
44 | min_loss_modal = min_loss_modal.view(-1)
45 | confidence_loss = self.confidence_loss(confidence, min_loss_modal).view(batch, obs_num)
46 | traj_loss = torch.sum(min_loss * predicted_traj_mask) / (torch.sum(predicted_traj_mask) + 0.00001)
47 | confidence_loss = torch.sum(confidence_loss * predicted_traj_mask) / (torch.sum(predicted_traj_mask) + 0.00001)
48 | return traj_loss, confidence_loss, min_loss_traj
49 |
50 |
51 | class BackBone(nn.Module):
52 | def __init__(self, betas: np.ndarray):
53 | super(BackBone, self).__init__()
54 | self.diffusion = GaussianDiffusion(betas=betas)
55 | self.scene_encoder = SceneEncoder()
56 | self.traj_decoder = TrajDecoder()
57 | self.multi_modal_loss = MultiModalLoss()
58 |
59 | def forward(self, data: Dict):
60 | # batch, other_obs(10), 40, 7
61 | predicted_feature = data['predicted_feature']
62 | # batch, other_obs(10), 40, 5
63 | other_his_pos = data['other_his_pos']
64 | other_his_traj_delt = data['other_his_traj_delt']
65 | other_feature = data['other_feature']
66 | other_traj_mask = data['other_traj_mask']
67 | # batch, pred_obs(15), 40, 5
68 | predicted_his_pos = data['predicted_his_pos']
69 | predicted_his_traj_delt = data['predicted_his_traj_delt']
70 | predicted_his_traj = data['predicted_his_traj']
71 | # batch, pred_obs(15), 50, 5
72 | predicted_future_traj = data['predicted_future_traj']
73 | predicted_traj_mask = data['predicted_traj_mask']
74 | # batch, tl_num(10), 2
75 | traffic_light = data['traffic_light']
76 | traffic_light_pos = data['traffic_light_pos']
77 | # batch, num_lane(32), num_point(128), 2
78 | lane_list = data['lane_list']
79 | # diffusion training
80 | noise = torch.randn_like(predicted_his_traj_delt)
81 | diffusion_loss = self.diffusion(data)
82 | noise = self.diffusion.sample(noise, predicted_his_traj)
83 | # scene encoder
84 | scene_feature = self.scene_encoder(
85 | noise, lane_list,
86 | other_his_traj_delt, other_his_pos, other_feature,
87 | predicted_his_traj_delt, predicted_his_pos, predicted_feature,
88 | traffic_light, traffic_light_pos
89 | )
90 | # traj_decoder
91 | traj, confidence = self.traj_decoder(scene_feature)
92 | traj = MathUtil.post_process_output(traj, predicted_his_traj)
93 | traj_loss, confidence_loss, min_loss_traj = self.multi_modal_loss(traj, confidence, predicted_future_traj,
94 | predicted_traj_mask)
95 | return diffusion_loss, traj_loss, confidence_loss, min_loss_traj
96 |
--------------------------------------------------------------------------------
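A minimal construction sketch for the backbone (assumed hyper-parameters mirroring `train_model_config`; the actual wiring lives in `TrainModelTask`, which is not part of this listing):

```python
from net_works.back_bone import BackBone
from utils.math_utils import MathUtil

time_steps, schedule = 50, "linear"  # assumed values taken from config.yaml
if schedule == "cosine":
    betas = MathUtil.generate_cosine_schedule(time_steps)
else:
    betas = MathUtil.generate_linear_schedule(time_steps)
model = BackBone(betas)
# model(batch) expects a dict with the keys read in forward() above and returns
# (diffusion_loss, traj_loss, confidence_loss, min_loss_traj).
```
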
/net_works/attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: attention.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class MultiHeadSelfAttention(nn.Module):
16 | def __init__(self, input_dim: int, head_dim: int = 64, num_heads: int = 8):
17 | super(MultiHeadSelfAttention, self).__init__()
18 | self.embed_dim = head_dim * num_heads
19 | self.num_heads = num_heads
20 | # We assume d_v is the same as d_k here aka head_dim == embed_dim // num_heads
21 | self.head_dim = head_dim
22 | # Learnable parameters
23 | self.query_proj = nn.Linear(input_dim, self.embed_dim)
24 | self.key_proj = nn.Linear(input_dim, self.embed_dim)
25 | self.value_proj = nn.Linear(input_dim, self.embed_dim)
26 | self.out_proj = nn.Linear(self.embed_dim, input_dim)
27 |
28 | def forward(self, input_feature, attn_mask=None):
29 | batch_size, obs_num = input_feature.shape[0], input_feature.shape[1]
30 | # Linear transformations
31 | q = self.query_proj(input_feature)
32 | k = self.key_proj(input_feature)
33 | v = self.value_proj(input_feature)
34 |
35 | # Split the embedding dimension into number of heads
36 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim)
37 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim)
38 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim)
39 |
40 | # Transpose to (batch_size, heads, seq_len, head_dim)
41 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
42 |
43 | # Calculate attention scores
44 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
45 | if attn_mask is not None:
46 | scores = scores.masked_fill(attn_mask, -float('inf'))
47 |
48 | # Apply attention score scaling and softmax
49 | attn = nn.functional.softmax(scores, dim=-1)
50 |
51 | # Multiply attention scores with value and sum to get the final output
52 | out = torch.matmul(attn, v)
53 |
54 | # Transpose back to (batch_size, seq_len, heads, head_dim)
55 | out = out.transpose(1, 2).contiguous().reshape(batch_size, obs_num, self.embed_dim)
56 |
57 | # Final linear transformation
58 | return self.out_proj(out)
59 |
60 |
61 | class MultiHeadCrossAttention(nn.Module):
62 | def __init__(self, input_q_dim: int, input_kv_dim: int, head_dim: int = 64, num_heads: int = 8):
63 | super(MultiHeadCrossAttention, self).__init__()
64 | self.embed_dim = head_dim * num_heads
65 | self.num_heads = num_heads
66 | # We assume d_v is the same as d_k here aka head_dim == embed_dim // num_heads
67 | self.head_dim = head_dim
68 | # Learnable parameters
69 | self.query_proj = nn.Linear(input_q_dim, self.embed_dim)
70 | self.key_proj = nn.Linear(input_kv_dim, self.embed_dim)
71 | self.value_proj = nn.Linear(input_kv_dim, self.embed_dim)
72 | self.out_proj = nn.Linear(self.embed_dim, input_q_dim)
73 |
74 | def forward(self, input_q, input_kv, attn_mask=None):
75 | batch_size, obs_num = input_q.shape[0], input_q.shape[1]
76 | # Linear transformations
77 | q = self.query_proj(input_q)
78 | k = self.key_proj(input_kv)
79 | v = self.value_proj(input_kv)
80 |
81 | # Split the embedding dimension into number of heads
82 | q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim)
83 | k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim)
84 | v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim)
85 |
86 | # Transpose to (batch_size, heads, seq_len, head_dim)
87 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
88 |
89 | # Calculate attention scores
90 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
91 | if attn_mask is not None:
92 | scores = scores.masked_fill(attn_mask, -float('inf'))
93 |
94 | # Apply attention score scaling and softmax
95 | attn = nn.functional.softmax(scores, dim=-1)
96 |
97 | # Multiply attention scores with value and sum to get the final output
98 | out = torch.matmul(attn, v)
99 |
100 | # Transpose back to (batch_size, seq_len, heads, head_dim)
101 | out = out.transpose(1, 2).contiguous().reshape(batch_size, obs_num, self.embed_dim)
102 |
103 | # Final linear transformation
104 | return self.out_proj(out)
--------------------------------------------------------------------------------
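A shape check for the attention modules above (hypothetical sizes; the 32-wide key/value tokens match the flattened lane embedding produced by `SceneEncoder`):

```python
import torch

from net_works.attention import MultiHeadCrossAttention, MultiHeadSelfAttention

self_attn = MultiHeadSelfAttention(input_dim=256, head_dim=64, num_heads=8)
cross_attn = MultiHeadCrossAttention(input_q_dim=256, input_kv_dim=32, head_dim=64, num_heads=8)

agent_tokens = torch.randn(4, 8, 256)      # 4 scenes, 8 predicted agents as queries
map_tokens = torch.randn(4, 32 * 128, 32)  # max_lane_num * max_point_num flattened map points
print(self_attn(agent_tokens).shape)               # torch.Size([4, 8, 256])
print(cross_attn(agent_tokens, map_tokens).shape)  # torch.Size([4, 8, 256])
```
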
/environment.yml:
--------------------------------------------------------------------------------
1 | name: WcDT
2 | channels:
3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
10 | - defaults
11 | dependencies:
12 | - _libgcc_mutex=0.1=conda_forge
13 | - _openmp_mutex=4.5=2_gnu
14 | - bzip2=1.0.8=hd590300_5
15 | - ca-certificates=2023.11.17=hbcca054_0
16 | - ld_impl_linux-64=2.40=h41732ed_0
17 | - libffi=3.4.2=h7f98852_5
18 | - libgcc-ng=13.2.0=h807b86a_3
19 | - libgomp=13.2.0=h807b86a_3
20 | - libnsl=2.0.1=hd590300_0
21 | - libsqlite=3.44.2=h2797004_0
22 | - libuuid=2.38.1=h0b41bf4_0
23 | - libzlib=1.2.13=hd590300_5
24 | - ncurses=6.4=h59595ed_2
25 | - openssl=3.2.0=hd590300_1
26 | - pip=23.3.1=pyhd8ed1ab_0
27 | - python=3.10.1=h543edf9_2_cpython
28 | - readline=8.2=h8228510_1
29 | - sqlite=3.44.2=h2c6b66d_0
30 | - tk=8.6.13=noxft_h4845f30_101
31 | - tzdata=2023c=h71feb2d_0
32 | - wheel=0.42.0=pyhd8ed1ab_0
33 | - xz=5.2.6=h166bdaf_0
34 | - pip:
35 | - absl-py==1.4.0
36 | - array-record==0.5.0
37 | - astunparse==1.6.3
38 | - cachetools==5.3.2
39 | - certifi==2023.11.17
40 | - charset-normalizer==3.3.2
41 | - click==8.1.7
42 | - cloudpickle==3.0.0
43 | - cmake==3.27.9
44 | - contourpy==1.2.0
45 | - cycler==0.12.1
46 | - dask==2023.3.1
47 | - decorator==5.1.1
48 | - dm-tree==0.1.8
49 | - einops==0.7.0
50 | - einsum==0.3.0
51 | - etils==1.5.2
52 | - filelock==3.13.1
53 | - flatbuffers==23.5.26
54 | - fonttools==4.46.0
55 | - fsspec==2023.12.1
56 | - gast==0.4.0
57 | - google-auth==2.16.2
58 | - google-auth-oauthlib==0.4.6
59 | - google-pasta==0.2.0
60 | - googleapis-common-protos==1.62.0
61 | - grpcio==1.60.0
62 | - h5py==3.10.0
63 | - idna==3.6
64 | - imageio==2.33.0
65 | - immutabledict==2.2.0
66 | - importlib-resources==6.1.1
67 | - jinja2==3.1.2
68 | - joblib==1.3.2
69 | - keras==2.11.0
70 | - kiwisolver==1.4.5
71 | - lazy-loader==0.3
72 | - libclang==16.0.6
73 | - lit==17.0.6
74 | - locket==1.0.0
75 | - markdown==3.5.1
76 | - markupsafe==2.1.3
77 | - matplotlib==3.6.1
78 | - mpmath==1.3.0
79 | - networkx==3.2.1
80 | - numpy==1.23.0
81 | - nvidia-cublas-cu11==11.10.3.66
82 | - nvidia-cuda-cupti-cu11==11.7.101
83 | - nvidia-cuda-nvrtc-cu11==11.7.99
84 | - nvidia-cuda-runtime-cu11==11.7.99
85 | - nvidia-cudnn-cu11==8.5.0.96
86 | - nvidia-cufft-cu11==10.9.0.58
87 | - nvidia-curand-cu11==10.2.10.91
88 | - nvidia-cusolver-cu11==11.4.0.1
89 | - nvidia-cusparse-cu11==11.7.4.91
90 | - nvidia-nccl-cu11==2.14.3
91 | - nvidia-nvtx-cu11==11.7.91
92 | - oauthlib==3.2.2
93 | - opencv-python==4.8.0.76
94 | - openexr==1.3.9
95 | - opt-einsum==3.3.0
96 | - packaging==23.2
97 | - pandas==1.5.3
98 | - partd==1.4.1
99 | - pillow==9.2.0
100 | - plotly==5.13.1
101 | - promise==2.3
102 | - protobuf==3.19.6
103 | - psutil==5.9.6
104 | - pyarrow==10.0.0
105 | - pyasn1==0.5.1
106 | - pyasn1-modules==0.3.0
107 | - pyparsing==3.1.1
108 | - python-dateutil==2.8.2
109 | - pytz==2023.3.post1
110 | - pywavelets==1.4.1
111 | - pyyaml==6.0.1
112 | - requests==2.31.0
113 | - requests-oauthlib==1.3.1
114 | - rsa==4.9
115 | - scikit-image==0.20.0
116 | - scikit-learn==1.2.2
117 | - scipy==1.10.1
118 | - setuptools==67.6.0
119 | - six==1.16.0
120 | - sympy==1.12
121 | - tenacity==8.2.3
122 | - tensorboard==2.11.2
123 | - tensorboard-data-server==0.6.1
124 | - tensorboard-plugin-wit==1.8.1
125 | - tensorflow==2.11.0
126 | - tensorflow-addons==0.23.0
127 | - tensorflow-datasets==4.9.0
128 | - tensorflow-estimator==2.11.0
129 | - tensorflow-graphics==2021.12.3
130 | - tensorflow-io-gcs-filesystem==0.34.0
131 | - tensorflow-metadata==1.13.0
132 | - tensorflow-probability==0.19.0
133 | - termcolor==2.4.0
134 | - threadpoolctl==3.2.0
135 | - tifffile==2023.12.9
136 | - toml==0.10.2
137 | - toolz==0.12.0
138 | - torch==2.0.0
139 | - torchaudio==2.0.1
140 | - torchvision==0.15.1
141 | - tqdm==4.66.1
142 | - trimesh==4.0.5
143 | - triton==2.0.0
144 | - typeguard==2.13.3
145 | - typing-extensions==4.9.0
146 | - urllib3==2.1.0
147 | - waymo-open-dataset-tf-2-11-0==1.5.2
148 | - werkzeug==3.0.1
149 | - wrapt==1.16.0
150 | - zipp==3.17.0
151 | prefix: /home/haomo/anaconda3/envs/TSDiT
152 |
--------------------------------------------------------------------------------
/net_works/scene_encoder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: scene_encoder.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | import torch
10 | from torch import nn
11 |
12 | from net_works.transformer import TransformerCrossAttention, TransformerSelfAttention
13 |
14 |
15 | class OtherFeatureFormer(nn.Module):
16 | def __init__(
17 | self, block_num: int, input_dim: int, conditional_dim: int,
18 | head_dim: int = 64, num_heads: int = 8
19 | ):
20 | super(OtherFeatureFormer, self).__init__()
21 | self.blocks = nn.ModuleList([])
22 | for _ in range(block_num):
23 | self.blocks.append(TransformerCrossAttention(input_dim, conditional_dim, head_dim, num_heads))
24 | self.norm = nn.LayerNorm(input_dim)
25 |
26 | def forward(self, input_x, other_feature):
27 | for block in self.blocks:
28 | input_x = block(input_x, other_feature)
29 | return self.norm(input_x)
30 |
31 |
32 | class SelfFeatureFormer(nn.Module):
33 | def __init__(self, block_num: int, input_dim: int, head_dim: int = 64, num_heads: int = 8):
34 | super(SelfFeatureFormer, self).__init__()
35 | self.blocks = nn.ModuleList([])
36 | for _ in range(block_num):
37 | self.blocks.append(TransformerSelfAttention(input_dim, head_dim, num_heads))
38 | self.norm = nn.LayerNorm(input_dim)
39 |
40 | def forward(self, input_x):
41 | for block in self.blocks:
42 | input_x = block(input_x)
43 | return self.norm(input_x)
44 |
45 |
46 | class SceneEncoder(nn.Module):
47 | def __init__(
48 | self, dim: int = 256, embedding_dim: int = 32,
49 | his_step: int = 11, other_agent_depth: int = 4,
50 | map_feature_depth: int = 4, traffic_light_depth: int = 2,
51 | self_attention_depth: int = 4
52 | ):
53 | super(SceneEncoder, self).__init__()
54 | self.embedding_dim = embedding_dim
55 | self.pos_embedding = nn.Sequential(
56 | nn.Linear(2, dim),
57 | nn.LayerNorm(dim),
58 | nn.ReLU(inplace=True),
59 | nn.Linear(dim, dim),
60 | nn.LayerNorm(dim),
61 | nn.ReLU(inplace=True),
62 | nn.Linear(dim, embedding_dim)
63 | )
64 | self.feature_embedding = nn.Sequential(
65 | nn.Linear(7, dim),
66 | nn.LayerNorm(dim),
67 | nn.ReLU(inplace=True),
68 | nn.Linear(dim, dim),
69 | nn.LayerNorm(dim),
70 | nn.ReLU(inplace=True),
71 | nn.Linear(dim, embedding_dim)
72 | )
73 | linear_input_dim = (his_step - 1) * 5 + embedding_dim + embedding_dim
74 | self.linear_input = nn.Sequential(
75 | nn.Linear(linear_input_dim, dim),
76 | nn.LayerNorm(dim),
77 | nn.ReLU(inplace=True),
78 | nn.Linear(dim, dim)
79 | )
80 | self.other_agent_former = OtherFeatureFormer(block_num=other_agent_depth, input_dim=dim, conditional_dim=114)
81 | self.map_former = OtherFeatureFormer(block_num=map_feature_depth, input_dim=dim, conditional_dim=embedding_dim)
82 | self.traffic_light_former = OtherFeatureFormer(block_num=traffic_light_depth, input_dim=dim,
83 | conditional_dim=43)
84 | self.fusion_block = SelfFeatureFormer(block_num=self_attention_depth, input_dim=dim, num_heads=16)
85 |
86 | def forward(
87 | self, noise, lane_list,
88 | other_his_traj_delt, other_his_pos, other_feature,
89 | predicted_his_traj_delt, predicted_his_pos, predicted_feature,
90 | traffic_light, traffic_light_pos
91 | ):
92 | batch_size, obs_num = noise.shape[0], noise.shape[1]
93 | # batch, obs_num(8), his_step, 3
94 | x = predicted_his_traj_delt + (noise * 0.001)
95 | x = torch.flatten(x, start_dim=2)
96 | other_his_traj_delt = torch.flatten(other_his_traj_delt, start_dim=2)
97 | # positional encoding of each position
98 | lane_list = self.pos_embedding(lane_list)
99 | lane_list = lane_list.view(batch_size, -1, self.embedding_dim)
100 | traffic_light_pos = self.pos_embedding(traffic_light_pos)
101 | other_his_pos = self.pos_embedding(other_his_pos)
102 | predicted_his_pos = self.pos_embedding(predicted_his_pos)
103 | # encode the agent attributes
104 | other_feature = self.feature_embedding(other_feature)
105 | predicted_feature = self.feature_embedding(predicted_feature)
106 | # combine the input features
107 | x = torch.cat((x, predicted_his_pos, predicted_feature), dim=-1)
108 | # batch, obs_num(15), 256
109 | x = self.linear_input(x)
110 | # other agent former
111 | other_obs_feature = torch.cat((other_his_traj_delt, other_his_pos, other_feature), dim=-1)
112 | x = self.other_agent_former(x, other_obs_feature)
113 | # map_point_transformer
114 | x = self.map_former(x, lane_list)
115 | # traffic_light_transformer
116 | traffic_light = torch.cat((traffic_light, traffic_light_pos), dim=-1)
117 | x = self.traffic_light_former(x, traffic_light)
118 | x = self.fusion_block(x)
119 | return x
120 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WcDT: World-centric Diffusion Transformer for Traffic Scene Generation
2 |
3 | This repository contains the official implementation of WcDT: World-centric Diffusion Transformer for Traffic Scene Generation [[Paper Link](https://arxiv.org/abs/2404.02082)]
4 |
5 | ## Getting Started
6 |
7 | First of all, we recommend reading the description of the [Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/) and the [Motion Prediction](https://waymo.com/open/data/motion/) dataset on the Waymo website.
8 |
9 | 1. Clone this repository:
10 |
11 | git clone https://github.com/yangchen1997/WcDT.git
12 |
13 | 2. Install the dependencies:
14 |
15 | conda env create -f environment.yml
16 |
17 | 3. Download the [Waymo Open Dataset](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_2_0?pli=1) (please note: if you are downloading Waymo datasets for the first time, you need to click "Download" on the Waymo website and register your account). After downloading, the dataset directory should be organized as follows:
18 |
19 | /path/to/dataset_root/
20 | ├── train_set/
21 | │   ├── training.tfrecord-00000-of-01000
22 | │   ├── training.tfrecord-00001-of-01000
23 | │   └── ...
24 | ├── val_set/
25 | │   ├── validation.tfrecord-00000-of-00150
26 | │   ├── validation.tfrecord-00001-of-00150
27 | │   └── ...
28 | └── test_set/
29 |     ├── testing.tfrecord-00000-of-00150
30 |     ├── testing.tfrecord-00001-of-00150
31 |     └── ...
32 |
33 |
34 | ## Tasks
35 |
36 | This project includes the following tasks:
37 |
38 | 1. Data Preprocess: This task is divided into two subtasks, data compression (which removes redundancy from the Waymo dataset) and data splitting.
39 |
40 | 2. Training
41 |
42 | 3. Evaluating Models and Visualising Results
43 |
44 |
45 | Before running the project, you need to configure the tasks to be performed in config.yaml.
46 |
47 | task config:
48 |
49 | task_config:
50 | task_list:
51 | - "DATA_PREPROCESS"
52 | - "DATA_SPLIT"
53 | - "DATA_COUNT"
54 | - "TRAIN_MODEL"
55 | - "SHOW_RESULTS"
56 | - "EVAL_MODEL"
57 | - "GENE_SUBMISSION"
58 | output_dir: "output"
59 | log_dir: "log"
60 | image_dir: "result_image"
61 | model_dir: "model"
62 | result_dir: "result"
63 | pre_train_model: ""
64 | waymo_train_dir: "path to waymo train_set"
65 | waymo_val_dir: "path to waymo valid_set"
66 | waymo_test_dir: "path to waymo test_set"
67 | data_output: "data_output"
68 | data_preprocess_dir: "data_preprocess_dir"
69 | train_dir: "train_dir"
70 | val_dir: "val_dir"
71 | test_dir: "test_dir"
72 |
73 | start tasks:
74 |
75 | bash run_main.sh
76 |
77 | ### Data Preprocess
78 |
79 | task config:
80 |
81 | data_preprocess_config:
82 | data_size: 100
83 | max_data_size: 2000
84 | num_works: 20
85 |
86 | ### Training
87 |
88 | task config:
89 |
90 | train_model_config:
91 | use_gpu: False
92 | gpu_ids:
93 | - 6
94 | - 7
95 | batch_size: 4
96 | num_works: 0
97 | his_step: 11
98 | max_pred_num: 8
99 | max_other_num: 6
100 | max_traffic_light: 8
101 | max_lane_num: 32
102 | max_point_num: 128
103 | num_head: 8
104 | attention_dim: 128
105 | multimodal: 10
106 | time_steps: 50
107 | # cosine or linear
108 | schedule: "linear"
109 | num_epoch: 200
110 | init_lr: 0.0001
111 |
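The `schedule` field selects how the diffusion betas are generated; `tasks/train_model_task.py` calls `MathUtil.generate_linear_schedule` or `MathUtil.generate_cosine_schedule` accordingly. As a minimal sketch of a linear schedule, assuming evenly spaced betas between illustrative endpoints (the actual `MathUtil` implementation may differ in detail):

```python
import numpy as np

def linear_beta_schedule(time_steps: int, low: float, high: float) -> np.ndarray:
    # Evenly spaced per-step noise levels (betas); the endpoints here are illustrative.
    return np.linspace(low, high, time_steps)

betas = linear_beta_schedule(50, 1e-4, 0.02)  # time_steps matches the config above
```
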
112 | ### Evaluate
113 |
114 | | Model | ADE↓ | MinADE↓ |
115 | | --- | --- | --- |
116 | | WcDT-64 | 4.872 | 1.962 |
117 | | WcDT-128 | 4.563 | 1.669 |
118 |
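Here, ADE denotes the average displacement error between predicted and ground-truth trajectories, and MinADE is the best (minimum) ADE over the model's sampled futures; lower is better for both.
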
119 | ### Qualitative Results
120 |
121 | Demos for lane-changing scenarios:
122 |
123 | *(Side-by-side GIFs comparing ground truth with the WcDT-128 result for the lane-changing scenario; see the `qualitative_results` directory.)*
124 |
142 | Demos for more complex turning scenarios:
143 |
144 | *(Side-by-side GIFs comparing ground truth with the WcDT-128 result for the turning scenario; see the `qualitative_results` directory.)*
145 |
162 | ## Todo List
163 |
164 | - [x] Data Statistics
165 | - [x] Generate Submission
166 | - [ ] Factorized Attention for Temporal Features
167 | - [ ] Graph Attention Mechanisms for Transformer
168 | - [ ] Lane Loss (FDE Loss + Timestep-Weighted Loss)
169 | - [ ] Scene Label
170 | - [ ] Upgrade Decoder (Proposed + Refined Trajectory)
171 |
172 | ## Citation
173 |
174 | If you found this repository useful, please consider citing our paper:
175 |
176 | ```bibtex
177 | @article{yang2024wcdt,
178 | title={Wcdt: World-centric diffusion transformer for traffic scene generation},
179 | author={Yang, Chen and He, Yangfan and Tian, Aaron Xuxiang and Chen, Dong and Wang, Jianhui and Shi, Tianyu and Heydarian, Arsalan and Liu, Pei},
180 | journal={arXiv preprint arXiv:2404.02082},
181 | year={2024}
182 | }
183 | ```
184 |
185 |
186 |
--------------------------------------------------------------------------------
/tasks/train_model_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: train_model_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/26
8 | """
9 | import os
10 | import shutil
11 | from typing import Union
12 |
13 | import torch
14 | from torch import optim, nn
15 | from torch.optim import Optimizer
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 |
19 | from common import TaskType, LoadConfigResultDate
20 | from common import WaymoDataset
21 | from net_works import BackBone
22 | from tasks import BaseTask
23 | from utils import MathUtil, VisualizeUtil
24 |
25 |
26 | class TrainModelTask(BaseTask):
27 |
28 | def __init__(self):
29 | super(TrainModelTask, self).__init__()
30 | self.task_type = TaskType.TRAIN_MODEL
31 | self.device = torch.device("cpu")
32 | self.multi_gpus = False
33 | self.gpu_ids = list()
34 |
35 | def execute(self, result_info: LoadConfigResultDate):
36 | train_dir = result_info.task_config.train_dir
37 | train_model_config = result_info.train_model_config
38 | self.init_dirs(result_info)
39 |         # initialize the device
40 | if train_model_config.use_gpu:
41 | if train_model_config.gpu_ids:
42 | self.device = torch.device(f"cuda:{train_model_config.gpu_ids[0]}")
43 | self.multi_gpus = True
44 | self.gpu_ids = train_model_config.gpu_ids
45 | else:
46 | self.device = torch.device('cuda')
47 |         # initialize the dataloader
48 | waymo_dataset = WaymoDataset(
49 | train_dir, train_model_config.his_step, train_model_config.max_pred_num,
50 | train_model_config.max_other_num, train_model_config.max_traffic_light,
51 | train_model_config.max_lane_num, train_model_config.max_point_num
52 | )
53 | data_loader = DataLoader(
54 | waymo_dataset,
55 | shuffle=False,
56 | batch_size=train_model_config.batch_size,
57 | num_workers=train_model_config.num_works,
58 | pin_memory=True,
59 | drop_last=False
60 | )
61 | model = self.init_model(result_info)
62 | model_train = model.train()
63 | optimizer = optim.Adam(model_train.parameters(), lr=train_model_config.init_lr,
64 | betas=(0.9, 0.999), weight_decay=0)
65 | epoch_step = len(waymo_dataset) // train_model_config.batch_size
66 | if epoch_step == 0:
67 | raise ValueError("dataset is too small, epoch_step = 0")
68 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, train_model_config.num_epoch, eta_min=1e-6)
69 | for epoch in range(train_model_config.num_epoch):
70 | self.fit_one_epoch(epoch, epoch_step, optimizer, model, data_loader, result_info)
71 | scheduler.step()
72 |
73 | def fit_one_epoch(
74 | self, epoch_num: int, epoch_step: int,
75 | optimizer: Optimizer, model: Union[BackBone, nn.DataParallel],
76 | data_loader: DataLoader, result_info: LoadConfigResultDate
77 | ):
78 | diffusion_losses = 0
79 | traj_losses = 0
80 | confidence_losses = 0
81 | total_losses = 0
82 | train_model_config = result_info.train_model_config
83 | pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch_num + 1}/{train_model_config.num_epoch}', mininterval=0.3)
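        # note: anomaly detection helps trace NaN/Inf gradients during debugging but adds noticeable overhead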
84 | torch.autograd.set_detect_anomaly(True)
85 | for iteration, data in enumerate(data_loader):
86 | for key, value in data.items():
87 | if isinstance(value, torch.Tensor):
88 | data[key] = value.to(self.device).to(torch.float32)
89 | optimizer.zero_grad()
90 | diffusion_loss, traj_loss, confidence_loss, min_loss_traj = model(data)
91 | diffusion_loss = diffusion_loss.mean()
92 | traj_loss = traj_loss.mean()
93 | confidence_loss = confidence_loss.mean()
94 | total_loss = diffusion_loss + traj_loss + confidence_loss
95 | total_loss.backward()
96 | optimizer.step()
97 | total_losses += total_loss.item()
98 | diffusion_losses += diffusion_loss.item()
99 | traj_losses += traj_loss.item()
100 | confidence_losses += confidence_loss.item()
101 | pbar.set_postfix(**{'total_loss': total_losses / (iteration + 1),
102 | 'diffusion_loss': diffusion_losses / (iteration + 1),
103 | 'traj_losses': traj_losses / (iteration + 1),
104 | 'confidence_losses': confidence_losses / (iteration + 1)})
105 | pbar.update()
106 | if iteration % 10 == 0:
107 | image_path = os.path.join(result_info.task_config.image_dir,
108 | f"epoch_{epoch_num}_batch_num_{iteration}_image.png")
109 | VisualizeUtil.show_result(image_path, min_loss_traj, data)
110 |
111 | @staticmethod
112 | def init_dirs(result_info: LoadConfigResultDate):
113 | task_config = result_info.task_config
114 | if os.path.exists(task_config.image_dir):
115 | shutil.rmtree(task_config.image_dir)
116 | os.makedirs(task_config.image_dir, exist_ok=True)
117 | os.makedirs(task_config.model_dir, exist_ok=True)
118 |
119 | def init_model(self, result_info: LoadConfigResultDate) -> Union[BackBone, nn.DataParallel]:
120 | train_model_config = result_info.train_model_config
121 | task_config = result_info.task_config
122 |         # initialize the diffusion betas
123 | if train_model_config.schedule == "cosine":
124 | betas = MathUtil.generate_cosine_schedule(train_model_config.time_steps)
125 | else:
126 | schedule_low = 1e-4
127 | schedule_high = 0.008
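            # the endpoints are rescaled by 1000 / time_steps so the cumulative noise stays roughly comparable to a standard 1000-step schedule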
128 | betas = MathUtil.generate_linear_schedule(
129 | train_model_config.time_steps,
130 | schedule_low * 1000 / train_model_config.time_steps,
131 | schedule_high * 1000 / train_model_config.time_steps,
132 | )
133 | model = BackBone(betas)
134 | if task_config.pre_train_model:
135 | pre_train_model_path = task_config.pre_train_model
136 | model_dict = model.state_dict()
137 | pretrained_dict = torch.load(pre_train_model_path)
138 |             # copy matching parameters from the pretrained checkpoint
139 | new_model_dict = dict()
140 | for key in model_dict.keys():
141 | if ("module." + key) in pretrained_dict:
142 | new_model_dict[key] = pretrained_dict["module." + key]
143 | elif key in pretrained_dict:
144 | new_model_dict[key] = pretrained_dict[key]
145 | else:
146 | print("key: ", key, ", not in pretrained")
147 | model.load_state_dict(new_model_dict)
148 | result_info.task_logger.logger.info("load pre_train_model success")
149 | model = model.to(self.device)
150 | if self.multi_gpus:
151 | model = nn.DataParallel(model, device_ids=self.gpu_ids, output_device=self.gpu_ids[0])
152 | return model
153 |
--------------------------------------------------------------------------------
/tasks/load_config_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: load_config_task.py
6 | @Author: YangChen
7 | @Date: 2023/12/20
8 | """
9 | import os
10 | import random
11 | from datetime import datetime
12 |
13 | import torch
14 | import yaml
15 |
16 | from common import TaskType, LoadConfigResultDate, TaskLogger
17 | from common import TaskConfig, DataPreprocessConfig, TrainModelConfig
18 | from tasks.base_task import BaseTask
19 |
20 | # ConfigFactory may be called from different places, so config_file is resolved first
21 | path_list = os.path.realpath(__file__).split(os.path.sep)
22 | path_list = path_list[:-2]
23 | path_list.append("config.yaml")
24 | CONFIG_PATH = os.path.sep.join(path_list)
25 |
26 |
27 | class LoadConfigTask(BaseTask):
28 | """
29 |     Main responsibilities:
30 |     load the configuration file
31 |     check whether the config is reasonable
32 |     initialize the output directories based on the config
33 |     initialize logging based on the config
34 | """
35 |
36 | def __init__(self):
37 | super(LoadConfigTask, self).__init__()
38 | self.task_type = TaskType.LOAD_CONFIG
39 | with open(CONFIG_PATH, "rb") as file:
40 | self.__yaml_loader = yaml.safe_load(file)
41 |
42 | def execute(self, result_info: LoadConfigResultDate):
43 |         # initialize the result info
44 | task_config = self.get_task_config()
45 | data_preprocess_config = self.get_preprocess_config()
46 | train_model_config = self.get_train_model_config()
47 | self.init_dirs_and_log(task_config, result_info)
48 | self.check_preprocess_config(data_preprocess_config)
49 | self.check_train_model_config(train_model_config)
50 |         # populate the result info
51 | result_info.task_config = task_config
52 | result_info.data_preprocess_config = data_preprocess_config
53 | result_info.train_model_config = train_model_config
54 |
55 | def get_task_config(self) -> TaskConfig:
56 | task_config = TaskConfig()
57 | self.__init_config_object_attr(task_config, self.__yaml_loader['task_config'])
58 | return task_config
59 |
60 | def get_preprocess_config(self) -> DataPreprocessConfig:
61 | data_preprocess_config = DataPreprocessConfig()
62 | self.__init_config_object_attr(data_preprocess_config, self.__yaml_loader['data_preprocess_config'])
63 | return data_preprocess_config
64 |
65 | def get_train_model_config(self) -> TrainModelConfig:
66 | train_model_config = TrainModelConfig()
67 | self.__init_config_object_attr(train_model_config, self.__yaml_loader['train_model_config'])
68 | return train_model_config
69 |
70 | @staticmethod
71 | def __init_config_object_attr(instance: object, attrs: dict):
72 | """
73 |         Assign values from a dict to a config object
74 |         @param instance: config object
75 |         @param attrs: attributes read from the config file
76 | @return: None
77 | """
78 | if not instance or not attrs:
79 | return
80 | for name, value in attrs.items():
81 | if hasattr(instance, name):
82 | setattr(instance, name, value)
83 | else:
84 | raise ValueError(f"unknown config, config name is {name}")
85 |
86 | @staticmethod
87 | def init_dirs_and_log(task_config: TaskConfig, result_data: LoadConfigResultDate):
88 | """
89 |         Check the task_config parameters and initialize the directories
90 |         @param task_config:
91 |         @param result_data:
92 | @return:
93 | """
94 |         # check the output directory
95 | os.makedirs(task_config.output_dir, exist_ok=True)
96 | task_config.log_dir = os.path.join(task_config.output_dir, task_config.log_dir)
97 | os.makedirs(task_config.log_dir, exist_ok=True)
98 |         # create the log
99 | now = datetime.now()
100 | formatted_now = now.strftime('%Y%m%d-%H-%M-%S')
101 | random_num = random.randint(10000, 99999)
102 | result_data.task_id = f"{formatted_now}_{random_num}"
103 | log_file_name = os.path.join(task_config.log_dir, f"{result_data.task_id}.log")
104 | result_data.task_logger = TaskLogger(log_file_name)
105 | result_data.task_logger.logger.info(f"task id {result_data.task_id} start")
106 |         # tasks in the task list must be unique
107 |         if len(task_config.task_list) != len(set(task_config.task_list)):
108 |             raise ValueError("tasks in task_list must be unique")
109 |         # create the model save dir
110 | task_config.model_dir = os.path.join(task_config.output_dir, task_config.model_dir)
111 | # os.makedirs(task_config.model_dir, exist_ok=True)
112 |         # create the result dir
113 | task_config.result_dir = os.path.join(task_config.output_dir, task_config.result_dir)
114 |         # check the pretrained model path
115 | if task_config.pre_train_model:
116 | path_type = ".pth"
117 | if not os.path.isfile(task_config.pre_train_model) or \
118 | not os.path.exists(task_config.pre_train_model) or \
119 | task_config.pre_train_model[-len(path_type):] != path_type:
120 | raise ValueError("task_config.pre_train_model error")
121 | result_data.task_logger.logger.info(f"{task_config.pre_train_model} check success")
122 | else:
123 |             result_data.task_logger.logger.warning("pre_train_model path is None")
124 |
125 |         # initialize the image save directory
126 | task_config.image_dir = os.path.join(
127 | task_config.output_dir,
128 | task_config.image_dir,
129 | result_data.task_id
130 | )
131 | # os.makedirs(task_config.image_dir, exist_ok=True)
132 |         # data preprocessing output
133 | os.makedirs(task_config.data_output, exist_ok=True)
134 | task_config.data_preprocess_dir = os.path.join(task_config.data_output,
135 | task_config.data_preprocess_dir)
136 | # os.makedirs(task_config.data_preprocess_dir, exist_ok=True)
137 |         # initialize the train / val / test directories
138 | task_config.train_dir = os.path.join(task_config.data_output,
139 | task_config.train_dir)
140 | # os.makedirs(task_config.train_dir, exist_ok=True)
141 | task_config.val_dir = os.path.join(task_config.data_output,
142 | task_config.val_dir)
143 | # os.makedirs(task_config.val_dir, exist_ok=True)
144 | task_config.test_dir = os.path.join(task_config.data_output,
145 | task_config.test_dir)
146 | # os.makedirs(task_config.test_dir, exist_ok=True)
147 | result_data.task_logger.logger.info(str(task_config))
148 | result_data.task_logger.logger.info("task config init success")
149 |
150 | @staticmethod
151 | def check_preprocess_config(data_preprocess_config: DataPreprocessConfig):
152 | if data_preprocess_config.num_works <= 0:
153 | raise ValueError(f"num_works = {data_preprocess_config.num_works}, cannot <= 0")
154 |
155 | @staticmethod
156 | def check_train_model_config(train_model_config: TrainModelConfig):
157 | if train_model_config.his_step <= 0 or train_model_config.his_step >= 91:
158 | raise ValueError(f"his_step {train_model_config.his_step} is out of range")
159 |         if train_model_config.use_gpu and not torch.cuda.is_available():
160 |             raise ValueError("cuda is unavailable")
161 |         if train_model_config.use_gpu and len(train_model_config.gpu_ids) > 0 and torch.cuda.device_count() <= 1:
162 |             raise ValueError("gpu_ids is set but only one GPU is available")
163 | if train_model_config.schedule not in ("cosine", "linear"):
164 | raise ValueError(f"schedule: {train_model_config.schedule}, is not in (cosine, linear)")
165 |
--------------------------------------------------------------------------------
/common/waymo_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: waymo_dataset.py
6 | @Author: YangChen
7 | @Date: 2023/12/25
8 | """
9 |
10 | import json
11 | import os
12 | import pickle
13 | from typing import Dict, Tuple, List
14 |
15 | import torch
16 | from torch.utils.data import Dataset
17 |
18 |
19 | class WaymoDataset(Dataset):
20 |
21 | def __init__(
22 | self, dataset_dir: str, his_step: int, max_pred_num: int,
23 | max_other_num: int, max_traffic_light: int, max_lane_num: int,
24 | max_point_num: int, max_data_size: int = -1
25 | ):
26 | super().__init__()
27 | self.__dataset_dir = dataset_dir
28 | self.__pkl_list = sorted(os.listdir(dataset_dir),
29 | key=lambda x: int(x[:-4].split('_')[-1]))
30 | if max_data_size != -1:
31 | self.__pkl_list = self.__pkl_list[:max_data_size]
32 |         # store the configuration parameters
33 | self.__max_pred_num = max_pred_num
34 | self.__max_other_num = max_other_num
35 | self.__max_traffic_light = max_traffic_light
36 | self.__max_lane_num = max_lane_num
37 | self.__max_point_num = max_point_num
38 | self.__his_step = his_step
39 | self.__future_step = 91 - his_step
40 |
41 | def __len__(self) -> int:
42 | return len(self.__pkl_list)
43 |
44 | def get_obs_feature(
45 | self, other_obs_index: List[int],
46 | all_obs_his_traj: torch.Tensor,
47 | all_obs_feature: torch.Tensor
48 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
49 |         # history trajectories of the other (non-predicted) obstacles
50 | if len(other_obs_index) > 0:
51 | other_obs_traj_index = (torch.Tensor(other_obs_index)
52 | .to(torch.long).view(-1, 1, 1)
53 | .repeat((1, self.__his_step, 5)))
54 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index)
55 | other_obs_feature_index = (torch.Tensor(other_obs_index)
56 | .to(torch.long).view(-1, 1)
57 | .repeat((1, 7)))
58 | # other feature
59 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index)
60 | other_his_traj = other_his_traj[:self.__max_other_num]
61 | other_feature = other_feature[:self.__max_other_num]
62 | num_gap = self.__max_other_num - other_his_traj.shape[0]
63 | if num_gap > 0:
64 | other_his_traj = torch.cat((other_his_traj,
65 | torch.zeros(size=(num_gap, self.__his_step, 5),
66 | dtype=torch.float32)), dim=0)
67 | other_feature = torch.cat((other_feature,
68 | torch.zeros(size=(num_gap, 7),
69 | dtype=torch.float32)), dim=0)
70 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] *
71 | num_gap)
72 | else:
73 | other_his_traj = torch.zeros(
74 | size=(self.__max_other_num, self.__his_step, 5),
75 | dtype=torch.float32
76 | )
77 | other_feature = torch.zeros(size=(self.__max_other_num, 7), dtype=torch.float32)
78 | other_traj_mask = torch.zeros(size=[self.__max_other_num], dtype=torch.float32)
79 | return other_his_traj, other_feature, other_traj_mask
80 |
81 | def get_pred_feature(
82 | self, predicted_index: torch.Tensor,
83 | all_obs_his_traj: torch.Tensor,
84 | all_obs_future_traj: torch.Tensor,
85 | all_obs_feature: torch.Tensor
86 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
87 |
88 | predicted_his_traj = torch.gather(
89 | all_obs_his_traj, 0,
90 | predicted_index.repeat((1, self.__his_step, 5))
91 | )
92 | predicted_future_traj = torch.gather(
93 | all_obs_future_traj, 0,
94 | predicted_index.repeat((1, self.__future_step, 5))
95 | )
96 | predicted_his_traj = predicted_his_traj[:self.__max_pred_num]
97 | predicted_future_traj = predicted_future_traj[:self.__max_pred_num]
98 | # predicted feature
99 | predicted_feature = torch.gather(
100 | all_obs_feature, 0,
101 | predicted_index.view(-1, 1).repeat((1, 7))
102 | )
103 | predicted_feature = predicted_feature[:self.__max_pred_num]
104 | num_gap = self.__max_pred_num - predicted_his_traj.shape[0]
105 | if num_gap > 0:
106 | predicted_his_traj = torch.cat((predicted_his_traj,
107 | torch.zeros(size=(num_gap, self.__his_step, 5),
108 | dtype=torch.float32)), dim=0)
109 | predicted_future_traj = torch.cat((predicted_future_traj,
110 | torch.zeros(size=(num_gap, self.__future_step, 5),
111 | dtype=torch.float32)), dim=0)
112 | predicted_feature = torch.cat((predicted_feature,
113 | torch.zeros(size=(num_gap, 7),
114 | dtype=torch.float32)), dim=0)
115 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj.shape[0] - num_gap) +
116 | [0.0] * num_gap)
117 | return predicted_future_traj, predicted_his_traj, predicted_feature, predicted_traj_mask
118 |
119 | def get_traffic_light(
120 | self, traffic_light: torch.Tensor,
121 | traffic_light_pos: torch.Tensor
122 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
123 | traffic_light = traffic_light[:self.__max_traffic_light]
124 | traffic_light_pos = traffic_light_pos[:self.__max_traffic_light]
125 | num_gap = self.__max_traffic_light - traffic_light.shape[0]
126 | if num_gap > 0:
127 | traffic_light = torch.cat((traffic_light,
128 | torch.zeros(size=(num_gap, self.__his_step),
129 | dtype=torch.float32)), dim=0)
130 | traffic_light_pos = torch.cat((traffic_light_pos,
131 | torch.zeros(size=(num_gap, 2),
132 | dtype=torch.float32)), dim=0)
133 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) +
134 | [0.0] * num_gap)
135 | return traffic_light, traffic_light_pos, traffic_mask
136 |
137 | def get_lane_feature(self, lane_feature: List) -> torch.Tensor:
138 | lane_list = list()
139 | for lane_index, lane in enumerate(lane_feature):
140 | if lane_index >= self.__max_lane_num:
141 | break
142 | point_list = lane[:self.__max_point_num]
143 | point_gap = self.__max_point_num - len(point_list)
144 | # MAX_POINT_NUM, 2
145 | point_list = torch.Tensor(point_list)
146 | if point_gap > 0:
147 | point_list = torch.cat((point_list,
148 | torch.zeros(size=(point_gap, 2),
149 | dtype=torch.float32)), dim=0)
150 | lane_list.append(point_list)
151 | lane_gap = self.__max_lane_num - len(lane_list)
152 | lane_list = torch.stack(lane_list, dim=0)
153 | if lane_gap > 0:
154 | lane_list = torch.cat(
155 | (lane_list, torch.zeros(size=(lane_gap, self.__max_point_num, 2),
156 | dtype=torch.float32)), dim=0
157 | )
158 | return lane_list
159 |
160 | def __getitem__(self, idx) -> Dict:
161 | with open(os.path.join(self.__dataset_dir, self.__pkl_list[idx]), "rb") as f:
162 | result = pickle.load(f)
163 | all_obs_index = set([i for i in range(result[1].shape[0])])
164 | predicted_index = torch.Tensor(result[0]).to(torch.long).view(-1, 1, 1)
165 | all_obs_traj = torch.Tensor(result[1])
166 | all_obs_feature = torch.Tensor(result[2])
167 | all_obs_his_traj = all_obs_traj[:, :self.__his_step]
168 | all_obs_future_traj = all_obs_traj[:, self.__his_step:]
169 | other_obs_index = list(all_obs_index - set(result[0]))
170 |         # gather the other obstacles' features
171 | other_his_traj, other_feature, other_traj_mask = self.get_obs_feature(
172 | other_obs_index, all_obs_his_traj, all_obs_feature
173 | )
174 |         # history trajectories of the obstacles that need to be predicted
175 | (predicted_future_traj, predicted_his_traj,
176 | predicted_feature, predicted_traj_mask) = self.get_pred_feature(
177 | predicted_index, all_obs_his_traj,
178 | all_obs_future_traj, all_obs_feature
179 | )
180 |         # traffic light information
181 | traffic_light = torch.Tensor(result[3])
182 | traffic_light_pos = torch.Tensor(result[4])
183 | traffic_light, traffic_light_pos, traffic_mask = self.get_traffic_light(
184 | traffic_light, traffic_light_pos
185 | )
186 |         # lane information
187 | lane_list = self.get_lane_feature(result[-1])
188 | map_json = json.dumps(result[-1])
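        # per-step displacements between consecutive history frames, plus each agent's last observed (x, y) position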
189 | other_his_traj_delt = other_his_traj[:, 1:] - other_his_traj[:, :-1]
190 | other_his_pos = other_his_traj[:, -1, :2]
191 | predicted_his_traj_delt = predicted_his_traj[:, 1:] - predicted_his_traj[:, :-1]
192 | predicted_his_pos = predicted_his_traj[:, -1, :2]
193 | result = {
194 | "other_his_traj": other_his_traj,
195 | "other_feature": other_feature,
196 | "other_traj_mask": other_traj_mask,
197 | "other_his_traj_delt": other_his_traj_delt,
198 | "other_his_pos": other_his_pos,
199 | "predicted_future_traj": predicted_future_traj,
200 | "predicted_his_traj": predicted_his_traj,
201 | "predicted_traj_mask": predicted_traj_mask,
202 | "predicted_feature": predicted_feature,
203 | "predicted_his_traj_delt": predicted_his_traj_delt,
204 | "predicted_his_pos": predicted_his_pos,
205 | "traffic_light": traffic_light,
206 | "traffic_light_pos": traffic_light_pos,
207 | "traffic_mask": traffic_mask,
208 | "lane_list": lane_list,
209 | "map_json": map_json
210 | }
211 | return result
212 |
--------------------------------------------------------------------------------
/net_works/diffusion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: diffusion.py
6 | @Author: YangChen
7 | @Date: 2023/12/27
8 | """
9 | from functools import partial
10 | from typing import List
11 |
12 | import numpy as np
13 | import torch
14 | from torch import nn
15 | import torch.nn.functional as F
16 |
17 | from net_works.transformer import TransformerCrossAttention
18 |
19 |
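# extract: gather the per-sample coefficient a[t] and reshape it to (batch, 1, ..., 1) so it broadcasts over tensors of shape x_shape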
20 | def extract(a, t, x_shape):
21 | t = t.to(torch.long)
22 | b, *_ = t.shape
23 | out = a.gather(-1, t)
24 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
25 |
26 |
27 | class LinearLayer(nn.Module):
28 | def __init__(self, input_dim: int, output_dim: int):
29 | super().__init__()
30 | self.layer = nn.Sequential(
31 | nn.Linear(input_dim, output_dim, bias=True, dtype=torch.float32),
32 | nn.LeakyReLU(inplace=True)
33 | )
34 | self.normal = nn.BatchNorm1d(output_dim)
35 |
36 | def forward(self, x):
37 | linear_output = self.layer(x)
38 | linear_output = torch.transpose(linear_output, -1, -2)
39 | normal_output = self.normal(linear_output)
40 | return torch.transpose(normal_output, -1, -2)
41 |
42 |
43 | class Decoder(nn.Module):
44 | def __init__(self, input_dim: int, middle_dim: int, output_dim: int):
45 | super(Decoder, self).__init__()
46 | self.up = LinearLayer(input_dim, output_dim)
47 | self.cat_layer = LinearLayer(middle_dim, output_dim)
48 |
49 | def forward(self, x1, x2):
50 | up_output = self.up(x1)
51 | cat_output = torch.cat((up_output, x2), dim=-1)
52 | output = self.cat_layer(cat_output)
53 | return output
54 |
55 |
56 | class UnetDiffusionModel(nn.Module):
57 | def __init__(
58 | self, dims: List[int] = None, input_dim: int = 3,
59 | conditional_dim: int = 5, his_stp: int = 11
60 | ):
61 | super(UnetDiffusionModel, self).__init__()
62 | if dims is None:
63 | self.__dims = [64, 128, 256, 512]
64 | else:
65 | self.__dims = dims
66 | self.his_delt_step = his_stp - 1
67 | self.input_dim = input_dim
68 | input_tensor_dim = self.his_delt_step * input_dim + conditional_dim * his_stp + 1
69 | self.layer1 = LinearLayer(input_tensor_dim, self.__dims[0])
70 | self.layer2 = LinearLayer(self.__dims[0], self.__dims[1])
71 | self.layer3 = LinearLayer(self.__dims[1], self.__dims[2])
72 | self.layer4 = LinearLayer(self.__dims[2], self.__dims[3])
73 | self.layer5 = LinearLayer(self.__dims[3], 64)
74 | self.decode4 = Decoder(64, 768, self.__dims[2])
75 | self.decode3 = Decoder(self.__dims[2], 384, self.__dims[1])
76 | self.decode2 = Decoder(self.__dims[1], 192, self.__dims[0])
77 | self.output_layer = nn.Sequential(
78 | nn.Linear(
79 | self.__dims[0], self.his_delt_step * input_dim,
80 | bias=True, dtype=torch.float32
81 | ),
82 | nn.Tanh()
83 | )
84 |
85 | def forward(self, perturbed_x, t, predicted_his_traj):
86 | # batch, obs_num, 10, 5
87 | batch_size = perturbed_x.shape[0]
88 | obs_num = perturbed_x.shape[1]
89 | input_tensor = torch.flatten(perturbed_x, start_dim=2)
90 | his_traj = torch.flatten(predicted_his_traj, start_dim=2)
91 | t = t.view(-1, 1, 1).repeat((1, obs_num, 1))
92 | # batch, obs_num, 50 + 50 + 1
93 | input_tensor = torch.cat([input_tensor, his_traj, t], dim=-1).to(torch.float32)
94 | # batch, 64
95 | e1 = self.layer1(input_tensor)
96 | e2 = self.layer2(e1)
97 | e3 = self.layer3(e2)
98 | e4 = self.layer4(e3)
99 | f = self.layer5(e4)
100 | d4 = self.decode4(f, e4)
101 | d3 = self.decode3(d4, e3)
102 | d2 = self.decode2(d3, e2)
103 | out = self.output_layer(d2)
104 | return out.view(batch_size, obs_num, self.his_delt_step, self.input_dim)
105 |
106 |
107 | class DitDiffusionModel(nn.Module):
108 | def __init__(
109 | self, input_dim: int = 3, conditional_dim: int = 5,
110 | his_stp: int = 11, num_dit_blocks: int = 4
111 | ):
112 | super(DitDiffusionModel, self).__init__()
113 | self.his_delt_step = his_stp - 1
114 | self.input_dim = (self.his_delt_step * input_dim) + 1
115 | self.conditional_dim = his_stp * conditional_dim
116 | self.dit_blocks = nn.ModuleList([])
117 | for _ in range(num_dit_blocks):
118 | self.dit_blocks.append(TransformerCrossAttention(self.input_dim, self.conditional_dim))
119 | self.linear_output = nn.Sequential(
120 | nn.Linear(self.input_dim, self.input_dim * 2),
121 | nn.LayerNorm(self.input_dim * 2),
122 | nn.ReLU(inplace=True),
123 | nn.Linear(self.input_dim * 2, self.input_dim),
124 | nn.LayerNorm(self.input_dim),
125 | nn.ReLU(inplace=True),
126 | nn.Linear(self.input_dim, self.his_delt_step * input_dim),
127 | nn.Tanh()
128 | )
129 |
130 | def forward(self, perturbed_x, t, predicted_his_traj):
131 | batch_size = perturbed_x.shape[0]
132 | obs_num = perturbed_x.shape[1]
133 | # batch, pred_obs, his_stp * 3
134 | input_tensor = torch.flatten(perturbed_x, start_dim=2)
135 | # batch, 15, 1
136 | t = t.view(-1, 1, 1).repeat(1, obs_num, 1)
137 | # batch, pred_obs, his_stp * 5
138 | predicted_his_traj = torch.flatten(predicted_his_traj, start_dim=2)
139 | # batch, pred_obs, his_stp * 3 + 1
140 | input_tensor = torch.cat([input_tensor, t], dim=-1)
141 | for dit_block in self.dit_blocks:
142 | input_tensor = dit_block(input_tensor, predicted_his_traj)
143 | noise_output = self.linear_output(input_tensor)
144 | return noise_output.view(batch_size, obs_num, self.his_delt_step, -1)
145 |
146 |
147 | class GaussianDiffusion(nn.Module):
148 | def __init__(
149 | self, input_dim: int = 5, conditional_dim: int = 5,
150 | his_stp: int = 11, betas: np.ndarray = None,
151 | loss_type: str = "l2", num_dit_blocks: int = 4,
152 | diffusion_type: str = "none"
153 | ):
154 | super(GaussianDiffusion, self).__init__()
155 | if betas is None:
156 | betas = []
157 |         # l1 or l2 loss
158 | if loss_type not in ["l1", "l2"]:
159 | raise ValueError(f"get unknown loss type: {loss_type}")
160 | if diffusion_type not in ["dit", "unet", "none"]:
161 | raise ValueError(f"get unknown diffusion type: {diffusion_type}")
162 |
163 | self.loss_type = loss_type
164 | self.diffusion_type = diffusion_type
165 | self.num_time_steps = len(betas)
166 |
167 | alphas = 1.0 - betas
168 | alphas_cum_prod = np.cumprod(alphas)
169 |         # convert numpy arrays to torch tensors
170 | to_torch = partial(torch.tensor, dtype=torch.float32)
171 |
172 | # betas [0.0001, 0.00011992, 0.00013984 ... , 0.02]
173 | self.register_buffer("betas", to_torch(betas))
174 | # alphas [0.9999, 0.99988008, 0.99986016 ... , 0.98]
175 | self.register_buffer("alphas", to_torch(alphas))
176 | # alphas_cum_prod [9.99900000e-01, 9.99780092e-01, 9.99640283e-01 ... , 4.03582977e-05]
177 | self.register_buffer("alphas_cum_prod", to_torch(alphas_cum_prod))
178 | # sqrt(alphas_cum_prod)
179 | self.register_buffer("sqrt_alphas_cum_prod", to_torch(np.sqrt(alphas_cum_prod)))
180 | # sqrt(1 - alphas_cum_prod)
181 | self.register_buffer("sqrt_one_minus_alphas_cum_prod", to_torch(np.sqrt(1 - alphas_cum_prod)))
182 | # sqrt(1 / alphas)
183 | self.register_buffer("reciprocal_sqrt_alphas", to_torch(np.sqrt(1 / alphas)))
184 | self.register_buffer("sigma", to_torch(np.sqrt(betas)))
185 | alphas_cum_prod_prev = np.append(1, alphas_cum_prod[:-1])
186 | # self.register_buffer("remove_noise_coeff", to_torch(betas / np.sqrt(1 - alphas_cum_prod)))
187 | self.register_buffer("remove_noise_coeff",
188 | to_torch(betas * (1 - alphas_cum_prod_prev / np.sqrt(1 - alphas_cum_prod))))
189 |         # initialize the diffusion model
190 | self.his_delt_step = his_stp - 1
191 | if self.diffusion_type == "dit":
192 | self.diffusion_model = DitDiffusionModel(
193 | input_dim=input_dim,
194 | conditional_dim=conditional_dim,
195 | num_dit_blocks=num_dit_blocks
196 | )
197 | elif self.diffusion_type == "unet":
198 | self.diffusion_model = UnetDiffusionModel(
199 | input_dim=input_dim,
200 | conditional_dim=conditional_dim,
201 | )
202 | else:
203 | self.diffusion_model = None
204 | self.input_dim = input_dim
205 | self.norm_output = nn.BatchNorm2d(5)
206 |
207 | def remove_noise(self, noise, t_batch, predicted_his_traj):
208 | model_output = self.diffusion_model(noise, t_batch, predicted_his_traj)
209 | return (
210 | (noise - extract(self.remove_noise_coeff, t_batch, noise.shape) * model_output) *
211 | extract(self.reciprocal_sqrt_alphas, t_batch, noise.shape)
212 | )
213 |
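    # reverse diffusion: starting from Gaussian noise, iteratively denoise for t = T-1 .. 0, adding sigma_t-scaled noise except at the final step, then batch-normalize the result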
214 | def sample(self, noise, predicted_his_traj):
215 | if self.diffusion_type != "none":
216 | batch_size = predicted_his_traj.shape[0]
217 | device = predicted_his_traj.device
218 | for t in range(self.num_time_steps - 1, -1, -1):
219 | t_batch = torch.tensor([t], device=device).repeat(batch_size)
220 | noise = self.remove_noise(noise, t_batch, predicted_his_traj)
221 | if t > 0:
222 | noise += extract(self.sigma, t_batch, noise.shape) * torch.randn_like(noise)
223 | noise = torch.transpose(noise, 1, -1)
224 | noise = self.norm_output(noise)
225 | noise = torch.transpose(noise, 1, -1)
226 | return noise
227 |
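    # forward diffusion q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, using the precomputed square-root buffers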
228 | def perturb_x(self, future_traj, t, noise):
229 | return (
230 | extract(self.sqrt_alphas_cum_prod, t, future_traj.shape) * future_traj +
231 | extract(self.sqrt_one_minus_alphas_cum_prod, t, future_traj.shape) * noise
232 | )
233 |
234 | def get_losses(self, predicted_his_traj_delt, predicted_his_traj, predicted_traj_mask, t):
235 | if self.diffusion_type != "none":
236 | noise = torch.randn_like(predicted_his_traj_delt)
237 | perturbed_x = self.perturb_x(predicted_his_traj_delt, t, noise)
238 | estimated_noise = self.diffusion_model(perturbed_x, t, predicted_his_traj)
239 | batch_size, obs_num = perturbed_x.shape[0], perturbed_x.shape[1]
240 | # diffusion_loss
241 | diffusion_loss = F.mse_loss(estimated_noise, noise, reduction="none")
242 | diffusion_loss_mask = (predicted_traj_mask.view(batch_size, obs_num, 1, 1)
243 | .repeat(1, 1, self.his_delt_step, self.input_dim))
244 | diffusion_loss = torch.sum(diffusion_loss * diffusion_loss_mask) / torch.sum(diffusion_loss_mask)
245 | return diffusion_loss
246 | else:
247 | return torch.tensor(0).to(torch.float32)
248 |
249 | def forward(self, data: dict):
250 | # batch, pred_obs(8), his_step, 3
251 | # batch, pred_obs(8), his_step, 5
252 | predicted_his_traj_delt = data['predicted_his_traj_delt']
253 | predicted_his_traj = data['predicted_his_traj']
254 | predicted_traj_mask = data['predicted_traj_mask']
255 | batch_size = predicted_his_traj_delt.shape[0]
256 | device = predicted_his_traj.device
257 | t = torch.randint(0, self.num_time_steps, (batch_size,), device=device)
258 | return self.get_losses(predicted_his_traj_delt, predicted_his_traj, predicted_traj_mask, t)
259 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/tasks/show_result_task.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: show_result_task.py
6 | @Author: YangChen
7 | @Date: 2024/1/6
8 | """
9 | import os.path
10 | import shutil
11 | from typing import Any, Dict
12 |
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import tensorflow as tf
16 | import torch
17 | from matplotlib import animation
18 | from matplotlib import patches
19 | from matplotlib.collections import LineCollection
20 | from matplotlib.colors import LinearSegmentedColormap
21 | from waymo_open_dataset.protos import scenario_pb2
22 | from waymo_open_dataset.protos.scenario_pb2 import Scenario
23 | from waymo_open_dataset.utils.sim_agents import visualizations, submission_specs
24 |
25 | from common import TaskType, LoadConfigResultDate
26 | from net_works import BackBone
27 | from tasks import BaseTask
28 | from utils import DataUtil, MathUtil, MapUtil
29 |
30 | RESULT_DIR = r"/home/haomo/yangchen/Scene-Diffusion/output/image"
31 | DATA_SET_PATH = r"/home/haomo/yangchen/Scene-Diffusion/data_set/train_set/training.tfrecord-00000-of-01000"
32 | MODEL_PATH = r"/home/haomo/yangchen/Scene-Diffusion/model_epoch152.pth"
33 |
34 |
35 | class ShowResultsTask(BaseTask):
36 |
37 | def __init__(self):
38 | super(ShowResultsTask, self).__init__()
39 | self.task_type = TaskType.SHOW_RESULTS
40 | self.cmap = LinearSegmentedColormap.from_list(
41 | 'my_cmap',
42 | [np.array([0., 232., 157.]) / 255, np.array([0., 120., 255.]) / 255],
43 | 100
44 | )
45 | self.color_dict = {
46 | 0: np.array([0., 120., 255.]) / 255,
47 | 1: np.array([0., 232., 157.]) / 255,
48 | 2: np.array([255., 205., 85.]) / 255,
49 | 3: np.array([244., 175., 145.]) / 255,
50 | 4: np.array([145., 80., 200.]) / 255,
51 | 5: np.array([0., 51., 102.]) / 255,
52 | 6: np.array([1, 0, 0]),
53 | 7: np.array([0, 1, 0]),
54 | }
55 |
56 | def execute(self, result_info: LoadConfigResultDate):
57 | if os.path.exists(RESULT_DIR):
58 | shutil.rmtree(RESULT_DIR)
59 | os.makedirs(RESULT_DIR, exist_ok=True)
60 | self.show_result(result_info)
61 |
62 | @staticmethod
63 | def load_pretrain_model(result_info: LoadConfigResultDate) -> BackBone:
64 | betas = MathUtil.generate_linear_schedule(result_info.train_model_config.time_steps)
65 | model = BackBone(betas).eval()
66 | device = torch.device("cpu")
67 | pretrained_dict = torch.load(MODEL_PATH, map_location=device)
68 | model_dict = model.state_dict()
69 |         # copy matching parameters from the pretrained checkpoint
70 | new_model_dict = dict()
71 | for key in model_dict.keys():
72 | if ("module." + key) in pretrained_dict:
73 | new_model_dict[key] = pretrained_dict["module." + key]
74 | elif key in pretrained_dict:
75 | new_model_dict[key] = pretrained_dict[key]
76 | else:
77 | print("key: ", key, ", not in pretrained")
78 | model.load_state_dict(new_model_dict)
79 | print("load_pretrain_model success")
80 | return model
81 |
82 | def show_result(self, result_info: LoadConfigResultDate):
83 | model = self.load_pretrain_model(result_info)
84 | match_filenames = tf.io.matching_files([DATA_SET_PATH])
85 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data").take(100)
86 | dataset_iterator = dataset.as_numpy_iterator()
87 | for index, scenario_bytes in enumerate(dataset_iterator):
88 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes)
89 | data_dict = DataUtil.transform_data_to_input(scenario, result_info)
90 | for key, value in data_dict.items():
91 | if isinstance(value, torch.Tensor):
92 | data_dict[key] = value.to(torch.float32).unsqueeze(dim=0)
93 |
94 | predict_traj = model(data_dict)[-1]
95 | predicted_traj_mask = data_dict['predicted_traj_mask'][0]
96 | predicted_future_traj = data_dict['predicted_future_traj'][0]
97 | predicted_his_traj = data_dict['predicted_his_traj'][0]
98 | predicted_num = 0
99 | for i in range(predicted_traj_mask.shape[0]):
100 | if int(predicted_traj_mask[i]) == 1:
101 | predicted_num += 1
102 | generate_traj = predict_traj[:predicted_num]
103 | predicted_future_traj = predicted_future_traj[:predicted_num]
104 | predicted_his_traj = predicted_his_traj[:predicted_num]
105 | real_traj = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, :2].detach().numpy()
106 | real_yaw = torch.cat((predicted_his_traj, predicted_future_traj), dim=1)[:, :, 2].detach().numpy()
107 | model_output = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, :2].detach().numpy()
108 | model_yaw = torch.cat((predicted_his_traj, generate_traj), dim=1)[:, :, 2].detach().numpy()
109 |             # visualize the input
110 | image_path = os.path.join(RESULT_DIR, f"{index}_input.png")
111 | self.draw_input(scenario, image_path)
112 |             # visualize the ground truth
113 | image_path = os.path.join(RESULT_DIR, f"{index}_ground_truth.png")
114 | self.draw_scene(predicted_num, real_traj, data_dict, scenario, image_path)
115 |             # visualize the model output
116 | image_path = os.path.join(RESULT_DIR, f"{index}_model_output.png")
117 | self.draw_scene(predicted_num, model_output, data_dict, scenario, image_path)
118 |             # visualize the animated GIFs
119 | image_path = os.path.join(RESULT_DIR, f"{index}_ground_truth.gif")
120 | self.draw_gif(predicted_num, real_traj, real_yaw, data_dict, scenario, image_path)
121 | image_path = os.path.join(RESULT_DIR, f"{index}_model_output.gif")
122 | self.draw_gif(predicted_num, model_output, model_yaw, data_dict, scenario, image_path)
123 |
124 | fig, axis = plt.subplots(1, 1, figsize=(10, 11))
125 | num = np.array([i for i in range(91)])
126 | segments = np.stack((num, num), axis=-1)[np.newaxis, :]
127 | image_path = os.path.join(RESULT_DIR, f"color_bar.png")
128 | line_segments = LineCollection(segments=segments, linewidths=1,
129 | linestyles='solid', cmap=self.cmap)
130 | cbar = fig.colorbar(line_segments, cmap=self.cmap, orientation='horizontal')
131 | cbar.set_ticks(np.linspace(0, 1, 10))
132 | cbar.set_ticklabels([str(i) for i in range(0, 91, 10)])
133 | plt.savefig(image_path)
134 |         plt.close('all')  # avoid memory leaks
135 |
136 | @staticmethod
137 | def draw_input(scenario: Scenario, image_path: str):
138 | fig, axis = plt.subplots(1, 1, figsize=(10, 10))
139 | visualizations.add_map(axis, scenario)
140 | predicted_obs_ids = submission_specs.get_evaluation_sim_agent_ids(scenario)
141 | current_time_index = scenario.current_time_index
142 | for track in scenario.tracks:
143 | if track.id not in predicted_obs_ids:
144 | continue
145 | param_dict = {
146 | "x": track.states[current_time_index].center_x,
147 | "y": track.states[current_time_index].center_y,
148 | "bbox_yaw": track.states[current_time_index].heading,
149 | "length": track.states[current_time_index].length,
150 | "width": track.states[current_time_index].width,
151 | }
152 | rect = visualizations.get_bbox_patch(**param_dict)
153 | axis.add_patch(rect)
154 | plt.savefig(image_path)
155 |         plt.close('all')  # avoid memory leaks
156 |
157 | def draw_scene(
158 | self, predicted_num: int, traj: np.ndarray,
159 | data_dict: Dict[str, Any], scenario: Scenario, image_path: str
160 | ):
161 | fig, axis = plt.subplots(1, 1, figsize=(10, 10))
162 | visualizations.add_map(axis, scenario)
163 | # visualizations.get_bbox_patch()
164 |         # axis.axis('equal')  # equal aspect ratio for x and y axes
165 | curr_x, curr_y, curr_heading, _ = data_dict['curr_loc']
166 | for i in range(predicted_num):
167 | real_traj_x, real_traj_y = MapUtil.local_to_global(curr_heading, traj[i, :, 0],
168 | traj[i, :, 1], curr_x, curr_y)
169 | num = np.linspace(0, 1, len(real_traj_x))
170 | for j in range(2, len(real_traj_x)):
171 | axis.plot(
172 | real_traj_x[j - 2:j],
173 | real_traj_y[j - 2:j],
174 | linewidth=5,
175 | color=self.cmap(num[j]),
176 | )
177 | axis.set_xticks([])
178 | axis.set_yticks([])
179 | # plt.show()
180 | plt.savefig(image_path)
181 |         plt.close('all')  # avoid memory leaks
182 | print(f"{image_path} save success")
183 |
184 | def draw_gif(
185 | self, predicted_num: int, traj: np.ndarray, real_yaw: np.ndarray,
186 | data_dict: Dict[str, Any], scenario: Scenario, image_path: str
187 | ):
188 | fig, axis = plt.subplots(1, 1, figsize=(10, 10))
189 | visualizations.add_map(axis, scenario)
190 | # visualizations.get_bbox_patch()
191 |         # axis.axis('equal')  # equal aspect ratio for x and y axes
192 | curr_x, curr_y, curr_heading, _ = data_dict['curr_loc']
193 | x_list = list()
194 | y_list = list()
195 | yaw_list = list()
196 | for i in range(predicted_num):
197 | real_traj_x, real_traj_y = MapUtil.local_to_global(curr_heading, traj[i, :, 0],
198 | traj[i, :, 1], curr_x, curr_y)
199 | real_traj_yaw = MapUtil.theta_local_to_global(curr_heading, real_yaw[i])
200 | x_list.append(real_traj_x)
201 | y_list.append(real_traj_y)
202 | yaw_list.append(real_traj_yaw)
203 | # [num, step]
204 | x_list = np.stack(x_list, axis=0)
205 | y_list = np.stack(y_list, axis=0)
206 | yaw_list = np.stack(yaw_list, axis=0)
207 | predicted_feature = data_dict['predicted_feature'].squeeze()[:, :2]
208 |
209 | def animate(t: int) -> list[patches.Rectangle]:
210 | # At each animation step, we need to remove the existing patches. This can
211 | # only be done using the `pop()` operation.
212 | for _ in range(len(axis.patches)):
213 | axis.patches.pop()
214 | bboxes = list()
215 | for j in range(x_list.shape[0]):
216 | bboxes.append(axis.add_patch(
217 | self.get_bbox_patch(
218 | x_list[:, t][j], y_list[:, t][j], yaw_list[:, t][j],
219 | predicted_feature[j, 1], predicted_feature[j, 0], self.color_dict[j]
220 | )
221 | ))
222 | return bboxes
223 |
224 | animations = animation.FuncAnimation(
225 | fig, animate, frames=x_list.shape[1], interval=100,
226 | blit=True)
227 | axis.set_xticks([])
228 | axis.set_yticks([])
229 | # plt.show()
230 | animations.save(image_path, writer='ffmpeg', fps=30)
231 |         plt.close('all')  # avoid matplotlib memory leaks
232 |         print(f"{image_path} saved successfully")
233 |
234 | @staticmethod
235 | def get_bbox_patch(
236 | x: float, y: float, bbox_yaw: float, length: float, width: float,
237 | color: np.ndarray
238 | ) -> patches.Rectangle:
239 | left_rear_object = np.array([-length / 2, -width / 2])
240 |
241 | rotation_matrix = np.array([[np.cos(bbox_yaw), -np.sin(bbox_yaw)],
242 | [np.sin(bbox_yaw), np.cos(bbox_yaw)]])
243 | left_rear_rotated = rotation_matrix.dot(left_rear_object)
244 | left_rear_global = np.array([x, y]) + left_rear_rotated
245 | color = list(color) + [0.5]
246 | rect = patches.Rectangle(
247 | left_rear_global, length, width, angle=np.rad2deg(bbox_yaw), color=color)
248 | return rect
249 |
250 |
251 | if __name__ == "__main__":
252 | # show_result()
253 | pass
254 |
--------------------------------------------------------------------------------
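
Note on the drawing code above: `get_bbox_patch` places each agent box by rotating the
box's rear-left corner into the global frame and using that corner as the matplotlib
Rectangle anchor, so the box centre lands on the requested (x, y). A minimal,
self-contained sketch of the same idea (the helper name `bbox_patch_sketch` is
illustrative only and not part of the repository):

    import numpy as np
    import matplotlib.patches as patches

    def bbox_patch_sketch(x, y, yaw, length, width, color=(0.2, 0.4, 0.8)):
        # Rear-left corner of the box, expressed in the box frame.
        corner = np.array([-length / 2.0, -width / 2.0])
        rot = np.array([[np.cos(yaw), -np.sin(yaw)],
                        [np.sin(yaw), np.cos(yaw)]])
        # Rotate the corner by the box yaw and shift it to the box centre (x, y).
        anchor = np.array([x, y]) + rot @ corner
        # matplotlib rotates the rectangle about its anchor point, so the centre
        # of the drawn box coincides with (x, y).
        return patches.Rectangle(anchor, length, width,
                                 angle=np.rad2deg(yaw),
                                 color=list(color) + [0.5])

    rect = bbox_patch_sketch(10.0, 5.0, np.pi / 6, 4.5, 2.0)
    print(rect.get_xy(), rect.angle)
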
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | @Project: WcDT
5 | @Name: data_utils.py
6 | @Author: YangChen
7 | @Date: 2023/12/24
8 | """
9 | import json
10 | import math
11 | from typing import List, Dict, Any, Tuple
12 |
13 | import numpy as np
14 | import torch
15 | from waymo_open_dataset.protos import scenario_pb2
16 | from waymo_open_dataset.utils.sim_agents import submission_specs
17 |
18 | from common import ObjectType, LoadConfigResultDate
19 | from utils.map_utils import MapUtil
20 |
21 |
22 | class DataUtil:
23 |
24 | @classmethod
25 | def load_scenario_data(cls, scenario: scenario_pb2.Scenario) -> Dict[str, Any]:
26 | result_dict = dict()
27 |         # Get the ego pose at the current step; all trajectories and lanes are converted into the ego frame
28 | sdc_track_index = scenario.sdc_track_index
29 | current_time_index = scenario.current_time_index
30 | curr_state = scenario.tracks[scenario.sdc_track_index].states[current_time_index]
31 | ego_curr_x, ego_curr_y = curr_state.center_x, curr_state.center_y
32 | ego_curr_heading = curr_state.heading
33 |         # IDs of the obstacles that must be predicted
34 | predicted_obs_ids = submission_specs.get_evaluation_sim_agent_ids(scenario)
35 | obs_tracks = cls.load_obs_tracks(scenario, ego_curr_x, ego_curr_y, ego_curr_heading)
36 | map_features = cls.load_map_features(scenario, ego_curr_x, ego_curr_y, ego_curr_heading)
37 | traffic_lights = cls.load_traffic_light(scenario, ego_curr_x, ego_curr_y, ego_curr_heading)
38 | if len(obs_tracks) <= 1:
39 | return result_dict
40 | result_dict['predicted_obs_ids'] = predicted_obs_ids
41 | result_dict['obs_tracks'] = obs_tracks
42 | result_dict['map_features'] = map_features
43 | result_dict['dynamic_states'] = traffic_lights
44 | result_dict['curr_loc'] = (ego_curr_x, ego_curr_y, ego_curr_heading, sdc_track_index)
45 | return result_dict
46 |
47 | @classmethod
48 | def load_obs_tracks(cls, scenario: scenario_pb2.Scenario,
49 | ego_curr_x: float, ego_curr_y: float,
50 | ego_curr_heading: float) -> List[Dict[str, Any]]:
51 | obs_tracks = list()
52 |         # State of one obstacle (id, type, trajectory)
53 | for track in scenario.tracks:
54 | one_obs_track = dict()
55 | obs_id = track.id
56 | if hasattr(track, "object_type"):
57 | object_type = track.object_type
58 | else:
59 | object_type = 4
60 | obs_traj = list()
61 | for state in track.states:
62 | if not state.valid:
63 | continue
64 | if 'height' not in one_obs_track and 'length' not in one_obs_track \
65 | and 'width' not in one_obs_track:
66 | one_obs_track['height'] = state.height
67 | one_obs_track['length'] = state.length
68 | one_obs_track['width'] = state.width
69 | center_x, center_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
70 | state.center_x, state.center_y)
71 | center_heading = MapUtil.theta_global_to_local(ego_curr_heading, state.heading)
72 |                 # The velocity can be transformed in two ways:
73 |                 # 1. project it directly onto the new coordinate frame, or
74 |                 # 2. compute the heading in the new frame and multiply the speed by cos/sin
75 | curr_v = math.sqrt(math.pow(state.velocity_x, 2) +
76 | math.pow(state.velocity_y, 2))
77 | curr_v_heading = math.atan2(state.velocity_y, state.velocity_x)
78 | curr_v_heading = MapUtil.theta_global_to_local(ego_curr_heading, curr_v_heading)
79 | obs_traj.append(
80 | (center_x, center_y, center_heading,
81 | curr_v * math.cos(curr_v_heading), curr_v * math.sin(curr_v_heading))
82 | )
83 | one_obs_track['obs_id'] = obs_id
84 | one_obs_track['object_type'] = object_type
85 | one_obs_track['obs_traj'] = obs_traj
86 |             # Obstacles with missing trajectory points are dropped
87 | if len(obs_traj) == 91:
88 | obs_tracks.append(one_obs_track)
89 | return obs_tracks
90 |
91 | @staticmethod
92 | def load_map_features(scenario: scenario_pb2.Scenario,
93 | ego_curr_x: float, ego_curr_y: float,
94 | ego_curr_heading: float) -> List[Dict[str, Any]]:
95 | map_features = list()
96 | for map_feature in scenario.map_features:
97 | one_map_dict = dict()
98 | map_id = map_feature.id
99 | polygon_points = list()
100 | if hasattr(map_feature, "road_edge") and map_feature.road_edge.polyline:
101 | map_type = "road_edge"
102 | polygon_list = map_feature.road_edge.polyline
103 | elif hasattr(map_feature, "road_line") and map_feature.road_line.polyline:
104 | map_type = "road_line"
105 | polygon_list = map_feature.road_line.polyline
106 | else:
107 | continue
108 | for polygon_point in polygon_list:
109 | polygon_point_x, polygon_point_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
110 | polygon_point.x, polygon_point.y)
111 | polygon_points.append((polygon_point_x, polygon_point_y))
112 | one_map_dict['map_id'] = map_id
113 | one_map_dict['map_type'] = map_type
114 | one_map_dict['polygon_points'] = polygon_points
115 | map_features.append(one_map_dict)
116 | return map_features
117 |
118 | @staticmethod
119 | def load_traffic_light(scenario: scenario_pb2.Scenario,
120 | ego_curr_x: float, ego_curr_y: float,
121 | ego_curr_heading: float) -> Dict[str, Any]:
122 | dynamic_states = dict()
123 | for dynamic_state in scenario.dynamic_map_states:
124 | for lane_state in dynamic_state.lane_states:
125 | lane_id = lane_state.lane
126 | lane_x, lane_y = MapUtil.global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
127 | lane_state.stop_point.x,
128 | lane_state.stop_point.y)
129 | if lane_id not in dynamic_states:
130 | dynamic_states[lane_id] = list()
131 | dynamic_states[lane_id].append((lane_x, lane_y))
132 | state = lane_state.state
133 | dynamic_states[lane_id].append(state)
134 | return dynamic_states
135 |
136 | @staticmethod
137 | def split_pkl_data(one_pkl_dict: Dict[str, Any], his_step: int) -> Tuple:
138 | map_points = [feature['polygon_points'] for feature in one_pkl_dict['map_features']]
139 | predicted_obs_ids = one_pkl_dict['predicted_obs_ids']
140 |         # Initialize the information that will be saved
141 | index = 0
142 | traj_list = list()
143 | obs_feature_list = list()
144 | predicted_obs_index = list()
145 |         # Obstacle information
146 | for one_obs_info in one_pkl_dict['obs_tracks']:
147 | if len(one_obs_info['obs_traj']) < 91 or None in one_obs_info['obs_traj']:
148 | continue
149 |             # Obstacle size and type
150 | obs_feature = list()
151 | obs_feature.append(one_obs_info['width'])
152 | obs_feature.append(one_obs_info['length'])
153 | type_onehot = [0] * len(ObjectType)
154 | type_onehot[one_obs_info['object_type']] = 1
155 | obs_feature += type_onehot
156 | obs_feature_list.append(obs_feature)
157 |             # Record the index of each obstacle that has to be predicted
158 | if one_obs_info['obs_id'] in predicted_obs_ids:
159 | predicted_obs_index.append(index)
160 | traj = np.array(one_obs_info['obs_traj'])
161 | traj_list.append(traj)
162 | index += 1
163 | if len(predicted_obs_index) < 1:
164 | return tuple()
165 |         # Dynamic map (traffic light) information
166 | dynamic_states = list()
167 | dynamic_pos = list()
168 | for key, value in one_pkl_dict['dynamic_states'].items():
169 | dynamic_pos.append(value[0])
170 | dynamic_state = value[1:]
171 | dynamic_state = dynamic_state[:his_step]
172 | if len(dynamic_state) < his_step:
173 | dynamic_state = dynamic_state + ([0] * (his_step - len(dynamic_state)))
174 | dynamic_states.append(dynamic_state)
175 | traj_arr = np.stack(traj_list, axis=0)
176 | obs_feature_list = np.stack(obs_feature_list, axis=0)
177 | dynamic_states = np.array(dynamic_states)
178 | dynamic_pos = np.array(dynamic_pos)
179 | one_pkl_data = (predicted_obs_index, traj_arr,
180 | obs_feature_list, dynamic_states,
181 | dynamic_pos, map_points)
182 | return one_pkl_data
183 |
184 | @staticmethod
185 | def get_obs_feature(
186 | config_data: LoadConfigResultDate,
187 | other_obs_index: List[int],
188 | all_obs_his_traj: torch.Tensor,
189 | all_obs_feature: torch.Tensor
190 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
191 | his_step = config_data.train_model_config.his_step
192 | max_other_num = config_data.train_model_config.max_other_num
193 |         # History trajectories of the other (non-predicted) obstacles
194 | if len(other_obs_index) > 0:
195 | other_obs_traj_index = (torch.Tensor(other_obs_index)
196 | .to(torch.long).view(-1, 1, 1)
197 | .repeat((1, his_step, 5)))
198 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index)
199 | other_obs_feature_index = (torch.Tensor(other_obs_index)
200 | .to(torch.long).view(-1, 1)
201 | .repeat((1, 7)))
202 | # other feature
203 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index)
204 | other_his_traj = other_his_traj[:max_other_num]
205 | other_feature = other_feature[:max_other_num]
206 | num_gap = max_other_num - other_his_traj.shape[0]
207 | if num_gap > 0:
208 | other_his_traj = torch.cat((other_his_traj,
209 | torch.zeros(size=(num_gap, his_step, 5),
210 | dtype=torch.float32)), dim=0)
211 | other_feature = torch.cat((other_feature,
212 | torch.zeros(size=(num_gap, 7),
213 | dtype=torch.float32)), dim=0)
214 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] *
215 | num_gap)
216 | else:
217 | other_his_traj = torch.zeros(
218 | size=(max_other_num, his_step, 5),
219 | dtype=torch.float32
220 | )
221 | other_feature = torch.zeros(size=(max_other_num, 7), dtype=torch.float32)
222 | other_traj_mask = torch.zeros(size=[max_other_num], dtype=torch.float32)
223 | return other_his_traj, other_feature, other_traj_mask
224 |
225 | @staticmethod
226 | def get_pred_feature(
227 | config_data: LoadConfigResultDate,
228 | predicted_index: torch.Tensor,
229 | all_obs_his_traj: torch.Tensor,
230 | all_obs_future_traj: torch.Tensor,
231 | all_obs_feature: torch.Tensor
232 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
233 | his_step = config_data.train_model_config.his_step
234 | future_step = 91 - his_step
235 | max_pred_num = config_data.train_model_config.max_pred_num
236 | predicted_his_traj = torch.gather(
237 | all_obs_his_traj, 0,
238 | predicted_index.repeat((1, his_step, 5))
239 | )
240 | predicted_future_traj = torch.gather(
241 | all_obs_future_traj, 0,
242 | predicted_index.repeat((1, future_step, 5))
243 | )
244 | predicted_his_traj = predicted_his_traj[:max_pred_num]
245 | predicted_future_traj = predicted_future_traj[:max_pred_num]
246 | # predicted feature
247 | predicted_feature = torch.gather(
248 | all_obs_feature, 0,
249 | predicted_index.view(-1, 1).repeat((1, 7))
250 | )
251 | predicted_feature = predicted_feature[:max_pred_num]
252 | num_gap = max_pred_num - predicted_his_traj.shape[0]
253 | if num_gap > 0:
254 | predicted_his_traj = torch.cat((predicted_his_traj,
255 | torch.zeros(size=(num_gap, his_step, 5),
256 | dtype=torch.float32)), dim=0)
257 | predicted_future_traj = torch.cat((predicted_future_traj,
258 | torch.zeros(size=(num_gap, future_step, 5),
259 | dtype=torch.float32)), dim=0)
260 | predicted_feature = torch.cat((predicted_feature,
261 | torch.zeros(size=(num_gap, 7),
262 | dtype=torch.float32)), dim=0)
263 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj.shape[0] - num_gap) +
264 | [0.0] * num_gap)
265 | return predicted_future_traj, predicted_his_traj, predicted_feature, predicted_traj_mask
266 |
267 | @staticmethod
268 | def get_traffic_light(
269 | config_data: LoadConfigResultDate,
270 | traffic_light: torch.Tensor,
271 | traffic_light_pos: torch.Tensor
272 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
273 | his_step = config_data.train_model_config.his_step
274 | max_traffic_light = config_data.train_model_config.max_traffic_light
275 | traffic_light = traffic_light[:max_traffic_light]
276 | traffic_light_pos = traffic_light_pos[:max_traffic_light]
277 | num_gap = max_traffic_light - traffic_light.shape[0]
278 | if num_gap > 0:
279 | traffic_light = torch.cat((traffic_light,
280 | torch.zeros(size=(num_gap, his_step),
281 | dtype=torch.float32)), dim=0)
282 | traffic_light_pos = torch.cat((traffic_light_pos,
283 | torch.zeros(size=(num_gap, 2),
284 | dtype=torch.float32)), dim=0)
285 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) +
286 | [0.0] * num_gap)
287 | return traffic_light, traffic_light_pos, traffic_mask
288 |
289 | @staticmethod
290 | def get_lane_feature(config_data: LoadConfigResultDate, lane_feature: List) -> torch.Tensor:
291 | max_lane_num = config_data.train_model_config.max_lane_num
292 | max_point_num = config_data.train_model_config.max_point_num
293 | lane_list = list()
294 | for lane_index, lane in enumerate(lane_feature):
295 | if lane_index >= max_lane_num:
296 | break
297 | point_list = lane[:max_point_num]
298 | point_gap = max_point_num - len(point_list)
299 | # MAX_POINT_NUM, 2
300 | point_list = torch.Tensor(point_list)
301 | if point_gap > 0:
302 | point_list = torch.cat((point_list,
303 | torch.zeros(size=(point_gap, 2),
304 | dtype=torch.float32)), dim=0)
305 | lane_list.append(point_list)
306 | lane_gap = max_lane_num - len(lane_list)
307 | lane_list = torch.stack(lane_list, dim=0)
308 | if lane_gap > 0:
309 | lane_list = torch.cat(
310 | (lane_list, torch.zeros(size=(lane_gap, max_point_num, 2),
311 | dtype=torch.float32)), dim=0
312 | )
313 | return lane_list
314 |
315 | @classmethod
316 | def transform_data_to_input(cls, scenario: scenario_pb2.Scenario,
317 | config_data: LoadConfigResultDate) -> Dict[str, Any]:
318 | data_dict = cls.load_scenario_data(scenario)
319 | if len(data_dict) == 0:
320 | return dict()
321 | his_step = config_data.train_model_config.his_step
322 | pkl_data = DataUtil.split_pkl_data(data_dict, his_step)
323 | all_obs_index = set([i for i in range(pkl_data[1].shape[0])])
324 | predicted_index = torch.Tensor(pkl_data[0]).to(torch.long).view(-1, 1, 1)
325 | all_obs_traj = torch.Tensor(pkl_data[1])
326 | all_obs_feature = torch.Tensor(pkl_data[2])
327 | all_obs_his_traj = all_obs_traj[:, :his_step]
328 | all_obs_future_traj = all_obs_traj[:, his_step:]
329 | other_obs_index = list(all_obs_index - set(pkl_data[0]))
330 |         # Features of the surrounding obstacles
331 | other_his_traj, other_feature, other_traj_mask = cls.get_obs_feature(
332 | config_data, other_obs_index, all_obs_his_traj, all_obs_feature
333 | )
334 |         # History trajectories of the obstacles to be predicted
335 | (predicted_future_traj, predicted_his_traj,
336 | predicted_feature, predicted_traj_mask) = cls.get_pred_feature(
337 | config_data, predicted_index, all_obs_his_traj,
338 | all_obs_future_traj, all_obs_feature
339 | )
340 |         # Traffic light information
341 | traffic_light = torch.Tensor(pkl_data[3])
342 | traffic_light_pos = torch.Tensor(pkl_data[4])
343 | traffic_light, traffic_light_pos, traffic_mask = cls.get_traffic_light(
344 | config_data, traffic_light, traffic_light_pos
345 | )
346 |         # Lane line information
347 | lane_list = cls.get_lane_feature(config_data, pkl_data[-1])
348 | map_json = json.dumps(pkl_data[-1])
349 | other_his_traj_delt = other_his_traj[:, 1:] - other_his_traj[:, :-1]
350 | other_his_pos = other_his_traj[:, -1, :2]
351 | predicted_his_traj_delt = predicted_his_traj[:, 1:] - predicted_his_traj[:, :-1]
352 | predicted_his_pos = predicted_his_traj[:, -1, :2]
353 | result = {
354 | "other_his_traj": other_his_traj,
355 | "other_feature": other_feature,
356 | "other_traj_mask": other_traj_mask,
357 | "other_his_traj_delt": other_his_traj_delt,
358 | "other_his_pos": other_his_pos,
359 | "predicted_future_traj": predicted_future_traj,
360 | "predicted_his_traj": predicted_his_traj,
361 | "predicted_traj_mask": predicted_traj_mask,
362 | "predicted_feature": predicted_feature,
363 | "predicted_his_traj_delt": predicted_his_traj_delt,
364 | "predicted_his_pos": predicted_his_pos,
365 | "traffic_light": traffic_light,
366 | "traffic_light_pos": traffic_light_pos,
367 | "traffic_mask": traffic_mask,
368 | "lane_list": lane_list,
369 | "map_json": map_json,
370 | "curr_loc": data_dict['curr_loc']
371 | }
372 | return result
373 |
--------------------------------------------------------------------------------
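
A recurring pattern in DataUtil above is padding a variable number of agents, traffic
lights, or lane polylines up to a fixed maximum and returning a 0/1 validity mask next to
the padded tensor, so downstream modules can tell real entries from padding. A minimal
sketch of that pattern in isolation (the function name `pad_with_mask` and the shapes are
illustrative only, not repository code):

    import torch

    def pad_with_mask(items: torch.Tensor, max_num: int):
        # items: tensor of shape (n, ...) with a variable leading dimension n.
        items = items[:max_num]                      # truncate if there are too many entries
        gap = max_num - items.shape[0]
        if gap > 0:
            pad = torch.zeros((gap, *items.shape[1:]), dtype=items.dtype)
            items = torch.cat((items, pad), dim=0)
        mask = torch.tensor([1.0] * (max_num - gap) + [0.0] * gap)
        return items, mask

    traj = torch.randn(3, 11, 5)                     # e.g. 3 agents, 11 history steps, 5 features
    padded, mask = pad_with_mask(traj, max_num=8)
    print(padded.shape, mask)                        # torch.Size([8, 11, 5]), three 1s then five 0s
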
/gene_submission.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | import tarfile
5 | from typing import Tuple, Dict, Any, List
6 |
7 | import cv2
8 | import numpy as np
9 | import tensorflow as tf
10 | import torch
11 | import tqdm
12 | from waymo_open_dataset.protos import scenario_pb2
13 | from waymo_open_dataset.utils.sim_agents import submission_specs
14 | from waymo_open_dataset.protos import sim_agents_submission_pb2
15 | from waymo_open_dataset.wdl_limited.sim_agents_metrics import metrics
16 | from waymo_open_dataset.utils import trajectory_utils
17 |
18 | from nets import (GaussianDiffusion, generate_linear_schedule, SimpleViT, TrajDecorder, MapEncoder)
19 |
20 | SHIFT_H = 125
21 | SHIFT_W = 125
22 | PIX_SIZE = 0.8
23 | MAX_LANE_NUM = 50
24 | MAX_POINT_NUM = 250
25 | HIS_STEP = 11
26 | FUTURE_STEP = 80
27 | MAX_PRED_NUM = 8
28 | MAX_OTHER_NUM = 6
29 | MAX_TRAFFIC_LIGHT_NUM = 8
30 | MAX_LANE_NUM = 32
31 | MAX_POINT_NUM = 128
32 |
33 | transform_arr = np.array([SHIFT_W, SHIFT_H])
34 | pix_size = PIX_SIZE
35 | object_type = {
36 | 'TYPE_UNSET': 0,
37 | 'TYPE_VEHICLE': 1,
38 | 'TYPE_PEDESTRIAN': 2,
39 | 'TYPE_CYCLIST': 3,
40 | 'TYPE_OTHER': 4
41 | }
42 |
43 | ObjectType = {0: "TYPE_UNSET", 1: "TYPE_VEHICLE", 2: "TYPE_PEDESTRIAN", 3: "TYPE_CYCLIST", 4: "TYPE_OTHER"}
44 | MapState = {0: "LANE_STATE_UNKNOWN", 1: "LANE_STATE_ARROW_STOP", 2: "LANE_STATE_ARROW_CAUTION",
45 | 3: "LANE_STATE_ARROW_GO", 4: "LANE_STATE_STOP", 5: "LANE_STATE_CAUTION",
46 | 6: "LANE_STATE_GO", 7: "LANE_STATE_FLASHING_STOP", 8: "LANE_STATE_FLASHING_CAUTION"}
47 |
48 | VALID_PATH = r"/mnt/share_disk/yangchen/Scene-Diffusion/valid_set"
49 | TEST_PATH = r"/mnt/share_disk/waymo_dataset/waymodata_1_1/testing"
50 | MODEL_PATH = r"/mnt/share_disk/yangchen/Scene-Diffusion/diffusion_model_last_epoch_weights.pth"
51 |
52 |
53 | def global_to_local(curr_x: float, curr_y: float, curr_heading: float,
54 | point_x: float, point_y: float) -> Tuple[float, float]:
55 | delta_x = point_x - curr_x
56 | delta_y = point_y - curr_y
57 | return delta_x * math.cos(curr_heading) + delta_y * math.sin(curr_heading), \
58 | delta_y * math.cos(curr_heading) - delta_x * math.sin(curr_heading)
59 |
60 |
61 | def theta_global_to_local(curr_heading: float, heading: float) -> float:
62 | """
63 |     Convert a heading from the world frame to the ego (vehicle) frame.
64 |     Args:
65 |         curr_heading: heading of the ego vehicle in the world frame
66 |         heading: the heading to convert
67 |     Returns: the heading in the ego frame
68 | """
69 | return normalize_angle(-curr_heading + heading)
70 |
71 |
72 | def normalize_angle(angle: float) -> float:
73 | """
74 |     Normalize an angle in radians to the range [-pi, pi].
75 |     Args:
76 |         angle: input angle in radians
77 |     Returns: the normalized angle in radians
78 | """
79 | angle = (angle + math.pi) % (2 * math.pi)
80 | if angle < .0:
81 | angle += 2 * math.pi
82 | return angle - math.pi
83 |
84 |
85 | def draw_one_rect(traj_list: List, image: np.ndarray,
86 | box_points: np.ndarray, color: int) -> Tuple[np.ndarray, np.ndarray]:
87 | one_step_color = color / 11
88 | image_list = list()
89 | for i in range(0, HIS_STEP):
90 | yaw = traj_list[i][2]
91 | rotate_matrix = np.array([np.cos(yaw), -np.sin(yaw),
92 | np.sin(yaw), np.cos(yaw)]).reshape(2, 2)
93 |
94 |         # Coordinate transform
95 | box_points_temp = box_points @ -rotate_matrix
96 | coord = np.array([traj_list[i][0], traj_list[i][1]])
97 | box_points_temp = box_points_temp + coord
98 | box_points_temp = box_points_temp / pix_size + transform_arr
99 | box_points_temp = np.round(box_points_temp).astype(int).reshape(1, -1, 2)
100 | color_temp = int(one_step_color * (i + 1))
101 | temp_image = cv2.fillPoly(
102 | np.zeros_like(image),
103 | box_points_temp,
104 | color=[color_temp],
105 | shift=0,
106 | )
107 | image_list.append(temp_image)
108 | image_list.append(image)
109 | image = np.concatenate(image_list, axis=2)
110 | image = np.expand_dims(np.max(image, axis=2), axis=2)
111 | return image, np.array(traj_list[:HIS_STEP])
112 |
113 |
114 | def gene_model_input_step_three(one_pkl_data: Tuple, predicted_obs_id: List) -> List:
115 | result = list()
116 | if len(one_pkl_data[1]) != len(predicted_obs_id):
117 |         raise ValueError("number of preprocessed predicted obstacles != len(predicted_obs_id)")
118 | map_image = torch.Tensor(one_pkl_data[0])
119 | all_obs_index = set([i for i in range(one_pkl_data[2].shape[0])])
120 | predicted_index = torch.Tensor(one_pkl_data[1]).to(torch.long).view(-1, 1, 1)
121 | all_obs_traj = torch.Tensor(one_pkl_data[2])
122 | all_obs_feature = torch.Tensor(one_pkl_data[3])
123 | all_obs_his_traj = all_obs_traj[:, :HIS_STEP]
124 | other_obs_index = list(all_obs_index - set(one_pkl_data[1]))
125 | if len(other_obs_index) > 0:
126 |         # History trajectories of the other obstacles
127 | other_obs_traj_index = torch.Tensor(other_obs_index).to(torch.long).view(-1, 1, 1).repeat((1, HIS_STEP, 5))
128 | other_his_traj = torch.gather(all_obs_his_traj, 0, other_obs_traj_index)
129 | other_obs_feature_index = torch.Tensor(other_obs_index).to(torch.long).view(-1, 1).repeat((1, 7))
130 | # other feature
131 | other_feature = torch.gather(all_obs_feature, 0, other_obs_feature_index)
132 | other_his_traj = other_his_traj[:MAX_OTHER_NUM]
133 | other_feature = other_feature[:MAX_OTHER_NUM]
134 | num_gap = MAX_OTHER_NUM - other_his_traj.shape[0]
135 | if num_gap > 0:
136 | other_his_traj = torch.cat((other_his_traj,
137 | torch.zeros(size=(num_gap, HIS_STEP, 5),
138 | dtype=torch.float32)), dim=0)
139 | other_feature = torch.cat((other_feature,
140 | torch.zeros(size=(num_gap, 7),
141 | dtype=torch.float32)), dim=0)
142 | other_traj_mask = torch.Tensor([1.0] * (other_his_traj.shape[0] - num_gap) + [0.0] * num_gap)
143 | else:
144 | other_his_traj = torch.zeros(size=(MAX_OTHER_NUM, HIS_STEP, 5), dtype=torch.float32)
145 | other_feature = torch.zeros(size=(MAX_OTHER_NUM, 7), dtype=torch.float32)
146 | other_traj_mask = torch.zeros(size=[MAX_OTHER_NUM], dtype=torch.float32)
147 |     # Traffic light information
148 | traffic_light = torch.Tensor(one_pkl_data[4])
149 | traffic_light_pos = torch.Tensor(one_pkl_data[5])
150 | traffic_light = traffic_light[:MAX_TRAFFIC_LIGHT_NUM]
151 | traffic_light_pos = traffic_light_pos[:MAX_TRAFFIC_LIGHT_NUM]
152 | num_gap = MAX_TRAFFIC_LIGHT_NUM - traffic_light.shape[0]
153 | if num_gap > 0:
154 | traffic_light = torch.cat((traffic_light,
155 | torch.zeros(size=(num_gap, 91),
156 | dtype=torch.float32)), dim=0)
157 | traffic_light_pos = torch.cat((traffic_light_pos,
158 | torch.zeros(size=(num_gap, 2),
159 | dtype=torch.float32)), dim=0)
160 | traffic_mask = torch.Tensor([1.0] * (traffic_light_pos.shape[0] - num_gap) + [0.0] * num_gap)
161 |     # Lane line information
162 | lane_list = list()
163 | for lane_index, lane in enumerate(one_pkl_data[6]):
164 | if lane_index >= MAX_LANE_NUM:
165 | break
166 | point_list = lane[:MAX_POINT_NUM]
167 | point_gap = MAX_POINT_NUM - len(point_list)
168 | # MAX_POINT_NUM, 2
169 | point_list = torch.Tensor(point_list)
170 | if point_gap > 0:
171 | point_list = torch.cat((point_list,
172 | torch.zeros(size=(point_gap, 2),
173 | dtype=torch.float32)), dim=0)
174 | lane_list.append(point_list)
175 | lane_gap = MAX_LANE_NUM - len(lane_list)
176 | if len(lane_list) > 0:
177 | lane_list = torch.stack(lane_list, dim=0)
178 | if lane_gap > 0:
179 | lane_list = torch.cat((lane_list,
180 | torch.zeros(size=(lane_gap, MAX_POINT_NUM, 2),
181 | dtype=torch.float32)), dim=0)
182 | else:
183 | lane_list = torch.zeros(size=(MAX_LANE_NUM, MAX_POINT_NUM, 2), dtype=torch.float32)
184 | map_json = json.dumps(one_pkl_data[6])
185 | other_his_traj_delt = other_his_traj[:, 1:, :2] - other_his_traj[:, :-1, :2]
186 | other_his_traj_delt = torch.cat((other_his_traj_delt, other_his_traj[:, 1:, 2:]), dim=-1)
187 | other_his_pos = other_his_traj[:, -1, :2]
188 |     # If more obstacles than MAX_PRED_NUM (8) must be predicted, run inference in batches
189 |     predicted_obs_batch_num = math.ceil(len(one_pkl_data[1]) / MAX_PRED_NUM)
190 | predicted_his_traj = torch.gather(all_obs_his_traj, 0, predicted_index.repeat((1, HIS_STEP, 5)))
191 | predicted_feature = torch.gather(all_obs_feature, 0, predicted_index.view(-1, 1).repeat((1, 7)))
192 | for i in range(predicted_obs_batch_num):
193 | start_index = i * MAX_PRED_NUM
194 |         end_index = start_index + MAX_PRED_NUM  # slice end is exclusive
195 |         # History trajectories of the obstacles to be predicted
196 | predicted_his_traj_batch = predicted_his_traj[start_index:end_index]
197 | # predicted feature
198 | predicted_feature_batch = predicted_feature[start_index:end_index]
199 | # mask
200 | num_gap = MAX_PRED_NUM - predicted_his_traj_batch.shape[0]
201 | predicted_traj_mask = torch.Tensor([1.0] * (predicted_his_traj_batch.shape[0]) + [0.0] * num_gap)
202 | if num_gap > 0:
203 | predicted_his_traj_batch = torch.cat((predicted_his_traj_batch,
204 | torch.zeros(size=(num_gap, HIS_STEP, 5),
205 | dtype=torch.float32)), dim=0)
206 | predicted_feature_batch = torch.cat((predicted_feature_batch,
207 | torch.zeros(size=(num_gap, 7),
208 | dtype=torch.float32)), dim=0)
209 | predicted_his_traj_delt = predicted_his_traj_batch[:, 1:, :2] - predicted_his_traj_batch[:, :-1, :2]
210 | predicted_his_traj_delt = torch.cat((predicted_his_traj_delt, predicted_his_traj_batch[:, 1:, 2:]), dim=-1)
211 | predicted_his_pos = predicted_his_traj_batch[:, -1, :2]
212 | result_dict = {
213 | "map_image": map_image,
214 | "other_his_traj": other_his_traj,
215 | "other_feature": other_feature,
216 | "other_traj_mask": other_traj_mask,
217 | "other_his_traj_delt": other_his_traj_delt,
218 | "other_his_pos": other_his_pos,
219 | "predicted_traj_mask": predicted_traj_mask,
220 | "predicted_his_traj": predicted_his_traj_batch,
221 | "predicted_feature": predicted_feature_batch,
222 | "predicted_his_traj_delt": predicted_his_traj_delt,
223 | "predicted_his_pos": predicted_his_pos,
224 | "traffic_light": traffic_light,
225 | "traffic_light_pos": traffic_light_pos,
226 | "traffic_mask": traffic_mask,
227 | "lane_list": lane_list,
228 | "map_json": map_json
229 | }
230 | for key, value in result_dict.items():
231 | if not isinstance(value, str):
232 | result_dict[key] = value.unsqueeze(dim=0)
233 |
234 | result.append(result_dict)
235 | return result
236 |
237 |
238 | def gene_model_input_step_two(data_dict: Dict[str, Any]) -> Tuple:
239 |     # Rasterize the lane lines into an image
240 | image_roadmap = np.zeros((256, 256, 1), dtype=np.uint8)
241 | image_vru_traj = np.zeros((256, 256, 1), dtype=np.uint8)
242 | image_car_traj = np.zeros((256, 256, 1), dtype=np.uint8)
243 | for road_info in data_dict['map_features']:
244 | if len(road_info['polygon_points']) <= 0:
245 | continue
246 | lane_points = np.array(road_info['polygon_points'])
247 | lane_points = lane_points / pix_size + transform_arr
248 | lane_points = np.round(lane_points).astype(int)
249 | image_roadmap = cv2.polylines(
250 | image_roadmap,
251 | [lane_points],
252 | False,
253 | [255],
254 | shift=0,
255 | )
256 |     # Obstacle trajectories: rasterized into images or kept as vectors
257 | traj_list = list()
258 | obs_feature_list = list()
259 | valid_index = list()
260 | index = 0
261 | predicted_obs_ids = data_dict['predicted_obs_ids']
262 | predicted_obs_index = list()
263 | predicted_obs_real_index = list()
264 | for obs_index, one_obs_info in enumerate(data_dict['obs_tracks']):
265 | if len(one_obs_info['obs_traj']) < 11 or None in one_obs_info['obs_traj']:
266 | continue
267 | obs_feature = list()
268 | obs_feature.append(one_obs_info['width'])
269 | obs_feature.append(one_obs_info['length'])
270 | type_onehot = [0] * len(object_type)
271 | type_onehot[object_type[one_obs_info['object_type']]] = 1
272 | obs_feature += type_onehot
273 | obs_feature_list.append(obs_feature)
274 | if one_obs_info['obs_id'] in predicted_obs_ids:
275 | predicted_obs_index.append(index)
276 | predicted_obs_real_index.append(one_obs_info['obs_id'])
277 | obs_half_width = 0.5 * one_obs_info['width']
278 | obs_half_length = 0.5 * one_obs_info['length']
279 | box_points = np.array([-obs_half_length, -obs_half_width,
280 | obs_half_length, -obs_half_width,
281 | obs_half_length, obs_half_width,
282 | -obs_half_length, obs_half_width]).reshape(4, 2).astype(np.float32)
283 |
284 | if one_obs_info['object_type'] in ('TYPE_PEDESTRIAN', 'TYPE_CYCLIST'):
285 | image_vru_traj, traj = draw_one_rect(one_obs_info['obs_traj'],
286 | image_vru_traj,
287 | box_points, 255)
288 |
289 | elif one_obs_info['object_type'] == 'TYPE_VEHICLE':
290 | valid_index.append(index)
291 | image_car_traj, traj = draw_one_rect(one_obs_info['obs_traj'],
292 | image_car_traj,
293 | box_points, 255)
294 |
295 | else:
296 | image_car_traj, traj = draw_one_rect(one_obs_info['obs_traj'],
297 | image_car_traj,
298 | box_points, 125)
299 | traj_list.append(traj)
300 | index += 1
301 |     # Dynamic map (traffic light) information
302 | dynamic_states = list()
303 | dynamic_pos = list()
304 | for key, value in data_dict['dynamic_states'].items():
305 | dynamic_pos.append(value[0])
306 | dynamic_state = value[1:]
307 | if len(dynamic_state) < 91:
308 | dynamic_state = dynamic_state + [0] * (91 - len(dynamic_state))
309 | dynamic_states.append(dynamic_state)
310 | if len(predicted_obs_index) < 1:
311 | raise ValueError("predicted_obs_index < 1")
312 | image = image_roadmap
313 | traj_arr = np.stack(traj_list, axis=0)
314 | obs_feature_list = np.stack(obs_feature_list, axis=0)
315 |
316 | dynamic_states = np.array(dynamic_states)
317 | dynamic_pos = np.array(dynamic_pos)
318 | map_feature_list = [feature['polygon_points'] for feature in data_dict['map_features']]
319 | one_pkl_data = (image.transpose(2, 1, 0), predicted_obs_index, traj_arr,
320 | obs_feature_list, dynamic_states,
321 | dynamic_pos, map_feature_list, predicted_obs_real_index)
322 | return one_pkl_data
323 |
324 |
325 | def gene_model_input_step_one(scenario: scenario_pb2.Scenario) -> Dict[str, Any]:
326 | data_dict = dict()
327 | sdc_track_index = scenario.sdc_track_index
328 | current_time_index = scenario.current_time_index
329 | obs_tracks = list()
330 |     # Get the ego pose at the current step; all trajectories and lanes are converted into the ego frame
331 | curr_state = scenario.tracks[scenario.sdc_track_index].states[current_time_index]
332 | ego_curr_x, ego_curr_y = curr_state.center_x, curr_state.center_y
333 | ego_curr_heading = curr_state.heading
334 | predicted_obs_ids = list()
335 | for predicted_obs in scenario.tracks_to_predict:
336 | predicted_obs_ids.append(scenario.tracks[predicted_obs.track_index].id)
337 | predicted_obs_ids.append(scenario.tracks[scenario.sdc_track_index].id)
338 |     # State of one obstacle (id, type, trajectory)
339 | for track in scenario.tracks:
340 | one_obs_track = dict()
341 | obs_id = track.id
342 | if hasattr(track, "object_type"):
343 | object_type = ObjectType[track.object_type]
344 | else:
345 | object_type = ObjectType[4]
346 | obs_traj = list()
347 |         # Iterate over the obstacle's trajectory states
348 | for state in track.states:
349 | if not state.valid and obs_id not in predicted_obs_ids:
350 | continue
351 | if not state.valid and obs_id in predicted_obs_ids and len(obs_traj) > 0:
352 | obs_traj.append(obs_traj[-1])
353 | continue
354 | if 'height' not in one_obs_track and 'length' not in one_obs_track \
355 | and 'width' not in one_obs_track:
356 | one_obs_track['height'] = state.height
357 | one_obs_track['length'] = state.length
358 | one_obs_track['width'] = state.width
359 | center_x, center_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
360 | state.center_x, state.center_y)
361 | center_heading = theta_global_to_local(ego_curr_heading, state.heading)
362 |             # The velocity can be transformed in two ways:
363 |             # 1. project it directly onto the new coordinate frame, or
364 |             # 2. compute the heading in the new frame and multiply the speed by cos/sin
365 | curr_v = math.sqrt(math.pow(state.velocity_x, 2) +
366 | math.pow(state.velocity_y, 2))
367 | curr_v_heading = math.atan2(state.velocity_y, state.velocity_x)
368 | curr_v_heading = theta_global_to_local(ego_curr_heading, curr_v_heading)
369 | obs_traj.append(
370 | (center_x, center_y, center_heading,
371 | curr_v * math.cos(curr_v_heading), curr_v * math.sin(curr_v_heading))
372 | )
373 | one_obs_track['obs_id'] = obs_id
374 | one_obs_track['object_type'] = object_type
375 | one_obs_track['obs_traj'] = obs_traj
376 |         # Obstacles whose trajectory is too short are dropped
377 | if len(obs_traj) >= 11:
378 | obs_tracks.append(one_obs_track)
379 |     # Global (static) map information
380 | map_features = list()
381 | for map_feature in scenario.map_features:
382 | one_map_dict = dict()
383 | map_id = map_feature.id
384 | polygon_points = list()
385 | if hasattr(map_feature, "road_edge") and map_feature.road_edge.polyline:
386 | map_type = "road_edge"
387 | polygon_list = map_feature.road_edge.polyline
388 | elif hasattr(map_feature, "road_line") and map_feature.road_line.polyline:
389 | map_type = "road_line"
390 | polygon_list = map_feature.road_line.polyline
391 | else:
392 | continue
393 | for polygon_point in polygon_list:
394 | polygon_point_x, polygon_point_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
395 | polygon_point.x, polygon_point.y)
396 | polygon_points.append((polygon_point_x, polygon_point_y))
397 | one_map_dict['map_id'] = map_id
398 | one_map_dict['map_type'] = map_type
399 | one_map_dict['polygon_points'] = polygon_points
400 | map_features.append(one_map_dict)
401 |     # Dynamic map (traffic light) information
402 | dynamic_states = dict()
403 | for dynamic_state in scenario.dynamic_map_states:
404 | for lane_state in dynamic_state.lane_states:
405 | lane_id = lane_state.lane
406 | lane_x, lane_y = global_to_local(ego_curr_x, ego_curr_y, ego_curr_heading,
407 | lane_state.stop_point.x,
408 | lane_state.stop_point.y)
409 | if lane_id not in dynamic_states:
410 | dynamic_states[lane_id] = list()
411 | dynamic_states[lane_id].append((lane_x, lane_y))
412 | state = lane_state.state
413 | dynamic_states[lane_id].append(state)
414 | data_dict['predicted_obs_ids'] = predicted_obs_ids
415 | data_dict['obs_tracks'] = obs_tracks
416 | data_dict['map_features'] = map_features
417 | data_dict['dynamic_states'] = dynamic_states
418 | data_dict['curr_loc'] = (ego_curr_x, ego_curr_y, ego_curr_heading, sdc_track_index)
419 | return data_dict
420 |
421 |
422 | def inference(input_batch: List) -> np.ndarray:
423 | input_shape = (256, 256)
424 | num_timesteps = 100
425 | schedule_low = 1e-4
426 | schedule_high = 0.008
427 | betas = generate_linear_schedule(
428 | num_timesteps,
429 | schedule_low * 1000 / num_timesteps,
430 | schedule_high * 1000 / num_timesteps,
431 | )
432 | diffusion_model = GaussianDiffusion(SimpleViT(), MapEncoder(), TrajDecorder(),
433 | input_shape, 3, betas=betas)
434 |     if torch.cuda.is_available() and torch.cuda.device_count() > 1 and False:  # the trailing "and False" forces CPU inference
435 | device = torch.device('cuda')
436 | else:
437 | device = torch.device('cpu')
438 | diffusion_model = diffusion_model.to(device)
439 | model_dict = diffusion_model.state_dict()
440 | pretrained_dict = torch.load(MODEL_PATH, map_location=device)
441 |     # Copy the pretrained parameters, stripping a possible DataParallel "module." prefix
442 | new_model_dict = dict()
443 | for key in model_dict.keys():
444 | if ("module." + key) in pretrained_dict:
445 | new_model_dict[key] = pretrained_dict["module." + key]
446 | elif key in pretrained_dict:
447 | new_model_dict[key] = pretrained_dict[key]
448 | else:
449 | print("key: ", key, ", not in pretrained")
450 | diffusion_model.load_state_dict(new_model_dict)
451 |     print("loaded pretrained parameters successfully")
452 | results = list()
453 | for one_batch_input in input_batch:
454 | result = diffusion_model.sample(one_batch_input)[0]
455 | results.append(result)
456 | results = torch.cat(results, dim=0)
457 | return results.numpy()
458 |
459 |
460 | def local_to_global(ego_heading: float, position_x: np.ndarray, position_y: np.ndarray, ego_local_x: float,
461 | ego_local_y: float) -> Tuple[np.ndarray, np.ndarray]:
462 | """
463 |     Convert x/y coordinates from the ego (local) frame to the world (global) frame.
464 |     Args:
465 |         ego_heading: heading of the ego vehicle in the world frame
466 |         position_x: x coordinates to convert (in the ego frame)
467 |         position_y: y coordinates to convert (in the ego frame)
468 |         ego_local_x: x position of the ego vehicle in the world frame
469 |         ego_local_y: y position of the ego vehicle in the world frame
470 |     Returns: the x coordinates and y coordinates in the world frame
471 | """
472 | yaw = ego_heading
473 | # global_x = [(ego_local_x + x * math.cos(yaw) - y * math.sin(yaw)) for x, y in zip(position_x.tolist(), position_y.tolist())]
474 | # global_y = [(ego_local_y + x * math.sin(yaw) + y * math.cos(yaw)) for x, y in zip(position_x.tolist(), position_y.tolist())]
475 | # return np.array(global_x), np.array(global_y)
476 | global_x = ego_local_x + position_x * math.cos(yaw) - position_y * math.sin(yaw)
477 | global_y = ego_local_y + position_x * math.sin(yaw) + position_y * math.cos(yaw)
478 | return global_x, global_y
479 |
480 |
481 | def theta_local_to_global(ego_heading: float, heading: np.ndarray) -> np.ndarray:
482 | """
483 |     Convert headings from the ego (local) frame to the world (global) frame.
484 |     Args:
485 |         ego_heading: heading of the ego vehicle in the world frame
486 |         heading: the headings to convert
487 |     Returns: the headings in the world frame
488 | """
489 | heading = heading.tolist()
490 | heading_list = []
491 | for one_heading in heading:
492 | heading_list.append(normalize_angle(ego_heading + one_heading))
493 | return np.array(heading_list)
494 |
495 |
496 | def simulate_with_extrapolation_new(
497 | scenario: scenario_pb2.Scenario,
498 |     print_verbose_comments: bool = True) -> Tuple[trajectory_utils.ObjectTrajectories, tf.Tensor]:
499 | vprint = print if print_verbose_comments else lambda arg: None
500 |
501 | # To load the data, we create a simple tensorized version of the object tracks.
502 | logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario)
503 | # Using `ObjectTrajectories` we can select just the objects that we need to
504 | # simulate and remove the "future" part of the Scenario.
505 | vprint(f'Original shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)')
506 | logged_trajectories = logged_trajectories.gather_objects_by_id(
507 | tf.convert_to_tensor(submission_specs.get_sim_agent_ids(scenario)))
508 | logged_trajectories = logged_trajectories.slice_time(
509 | start_index=0, end_index=submission_specs.CURRENT_TIME_INDEX + 1)
510 | vprint(f'Modified shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)')
511 | # We can verify that all of these objects are valid at the last step.
512 | vprint(f'Are all agents valid: {tf.reduce_all(logged_trajectories.valid[:, -1]).numpy()}')
513 |     # Data preprocessing and model inference
514 | all_logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario)
515 | all_logged_trajectories = all_logged_trajectories.slice_time(
516 | start_index=0, end_index=submission_specs.N_FULL_SCENARIO_STEPS + 1)
517 | logged_pred = all_logged_trajectories.gather_objects_by_id(
518 | tf.convert_to_tensor(submission_specs.get_evaluation_sim_agent_ids(scenario)))
519 | # if not tf.reduce_all(logged_pred.valid):
520 | # print("logged_pred include invalid state")
521 | predicted_obs_id = submission_specs.get_evaluation_sim_agent_ids(scenario)
522 | data_dict = gene_model_input_step_one(scenario)
523 | one_pkl_data = gene_model_input_step_two(data_dict)
524 | input_batch = gene_model_input_step_three(one_pkl_data, predicted_obs_id)
525 | predicted_obs_traj = inference(input_batch)
526 | predicted_obs_id_in_pkl = one_pkl_data[7]
527 | predicted_obs_traj = predicted_obs_traj[:len(predicted_obs_id_in_pkl)]
528 |     # Match each predicted trajectory to its obstacle id
529 | predicted_obs_id_traj = {obs_id: predicted_obs_traj[index] for index, obs_id in enumerate(predicted_obs_id_in_pkl)}
530 |     # Ego pose at the current time step
531 | curr_loc = data_dict['curr_loc']
532 |
533 | simulated_states = list()
534 | for index, obs_id in enumerate(submission_specs.get_sim_agent_ids(scenario)):
535 | if obs_id not in predicted_obs_id_traj.keys():
536 | simulated_states.append(np.zeros(shape=(80, 4)))
537 | else:
538 | one_predicted_obs_traj = predicted_obs_id_traj[obs_id]
539 | one_predicted_obs_x = one_predicted_obs_traj[:, 0]
540 | one_predicted_obs_y = one_predicted_obs_traj[:, 1]
541 | one_predicted_obs_z = np.array([float(logged_trajectories.z[:, -1][index])] * 80)
542 | one_predicted_obs_x, one_predicted_obs_y = local_to_global(curr_loc[2], one_predicted_obs_x,
543 | one_predicted_obs_y, curr_loc[0], curr_loc[1])
544 | one_predicted_obs_heading = theta_local_to_global(curr_loc[2], one_predicted_obs_traj[:, 2])
545 | one_simulated_state = np.stack((one_predicted_obs_x, one_predicted_obs_y,
546 | one_predicted_obs_z, one_predicted_obs_heading), axis=-1)
547 | simulated_states.append(one_simulated_state)
548 | simulated_states = np.stack(simulated_states, axis=0)
549 | simulated_states = np.stack([simulated_states] * submission_specs.N_ROLLOUTS, axis=0)
550 | simulated_states = tf.convert_to_tensor(simulated_states)
551 | return logged_trajectories, simulated_states
552 |
553 |
554 | def simulate_with_extrapolation(
555 | scenario: scenario_pb2.Scenario,
556 |     print_verbose_comments: bool = True) -> Tuple[trajectory_utils.ObjectTrajectories, tf.Tensor]:
557 | vprint = print if print_verbose_comments else lambda arg: None
558 |
559 | # To load the data, we create a simple tensorized version of the object tracks.
560 | logged_trajectories = trajectory_utils.ObjectTrajectories.from_scenario(scenario)
561 | # Using `ObjectTrajectories` we can select just the objects that we need to
562 | # simulate and remove the "future" part of the Scenario.
563 | vprint(f'Original shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)')
564 | logged_trajectories = logged_trajectories.gather_objects_by_id(
565 | tf.convert_to_tensor(submission_specs.get_sim_agent_ids(scenario)))
566 | logged_trajectories = logged_trajectories.slice_time(
567 | start_index=0, end_index=submission_specs.CURRENT_TIME_INDEX + 1)
568 | vprint(f'Modified shape of tensors inside trajectories: {logged_trajectories.valid.shape} (n_objects, n_steps)')
569 |
570 | # We can verify that all of these objects are valid at the last step.
571 | vprint(f'Are all agents valid: {tf.reduce_all(logged_trajectories.valid[:, -1]).numpy()}')
572 |
573 | # We extract the speed of the sim agents (in the x/y/z components) ready for
574 | # extrapolation (this will be our policy).
575 | states = tf.stack([logged_trajectories.x, logged_trajectories.y,
576 | logged_trajectories.z, logged_trajectories.heading],
577 | axis=-1)
578 | n_objects, n_steps, _ = states.shape
579 | last_velocities = states[:, -1, :3] - states[:, -2, :3]
580 | # We also make the heading constant, so concatenate 0. as angular speed.
581 | last_velocities = tf.concat(
582 | [last_velocities, tf.zeros((n_objects, 1))], axis=-1)
583 | # It can happen that the second to last state of these sim agents might be
584 | # invalid, so we will set a zero speed for them.
585 | vprint(f'Is any 2nd to last state invalid: {tf.reduce_any(tf.logical_not(logged_trajectories.valid[:, -2]))}')
586 | vprint(f'This will result in either min or max speed to be really large: {tf.reduce_max(tf.abs(last_velocities))}')
587 | valid_diff = tf.logical_and(logged_trajectories.valid[:, -1],
588 | logged_trajectories.valid[:, -2])
589 | # `last_velocities` shape: (n_objects, 4).
590 | last_velocities = tf.where(valid_diff[:, tf.newaxis],
591 | last_velocities,
592 | tf.zeros_like(last_velocities))
593 | vprint(f'Now this should be back to a normal value: {tf.reduce_max(tf.abs(last_velocities))}')
594 |
595 | # Now we carry over a simulation. As we discussed, we actually want 32 parallel
596 | # simulations, so we make this batched from the very beginning. We add some
597 | # random noise on top of our actions to make sure the behaviours are different.
598 | # To properly scale the noise, we get the max velocities (average over all
599 | # objects, corresponding to axis 0) in each of the dimensions (x/y/z/heading).
600 | NOISE_SCALE = 0.01
601 | # `max_action` shape: (4,).
602 | max_action = tf.reduce_max(last_velocities, axis=0)
603 | # We create `simulated_states` with shape (n_rollouts, n_objects, n_steps, 4).
604 | simulated_states = tf.tile(states[tf.newaxis, :, -1:, :], [submission_specs.N_ROLLOUTS, 1, 1, 1])
605 | vprint(f'Shape: {simulated_states.shape}')
606 |
607 | for step in range(submission_specs.N_SIMULATION_STEPS):
608 | current_state = simulated_states[:, :, -1, :]
609 | # Random actions, take a normal and normalize by min/max actions
610 | action_noise = tf.random.normal(
611 | current_state.shape, mean=0.0, stddev=NOISE_SCALE)
612 | actions_with_noise = last_velocities[None, :, :] + (action_noise * max_action)
613 | next_state = current_state + actions_with_noise
614 | simulated_states = tf.concat(
615 | [simulated_states, next_state[:, :, None, :]], axis=2)
616 |
617 | # We also need to remove the first time step from `simulated_states` (it was
618 | # still history).
619 | # `simulated_states` shape before: (n_rollouts, n_objects, 81, 4).
620 | # `simulated_states` shape after: (n_rollouts, n_objects, 80, 4).
621 | simulated_states = simulated_states[:, :, 1:, :]
622 | vprint(f'Final simulated states shape: {simulated_states.shape}')
623 |
624 | return logged_trajectories, simulated_states
625 |
626 |
627 | def joint_scene_from_states(
628 | states: tf.Tensor, object_ids: tf.Tensor
629 | ) -> sim_agents_submission_pb2.JointScene:
630 | # States shape: (num_objects, num_steps, 4).
631 | # Objects IDs shape: (num_objects,).
632 | states = states.numpy()
633 | simulated_trajectories = []
634 | for i_object in range(len(object_ids)):
635 | simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(
636 | center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],
637 | center_z=states[i_object, :, 2], heading=states[i_object, :, 3],
638 | object_id=object_ids[i_object]
639 | ))
640 | return sim_agents_submission_pb2.JointScene(
641 | simulated_trajectories=simulated_trajectories)
642 |
643 |
644 | # Now we can replicate this strategy to export all the parallel simulations.
645 | def scenario_rollouts_from_states(
646 | scenario: scenario_pb2.Scenario,
647 | states: tf.Tensor, object_ids: tf.Tensor
648 | ) -> sim_agents_submission_pb2.ScenarioRollouts:
649 | # States shape: (num_rollouts, num_objects, num_steps, 4).
650 | # Objects IDs shape: (num_objects,).
651 | joint_scenes = []
652 | for i_rollout in range(states.shape[0]):
653 | joint_scenes.append(joint_scene_from_states(states[i_rollout], object_ids))
654 | return sim_agents_submission_pb2.ScenarioRollouts(
655 | # Note: remember to include the Scenario ID in the proto message.
656 | joint_scenes=joint_scenes, scenario_id=scenario.scenario_id)
657 |
658 |
659 | def inference_valid_set():
660 | file_names = os.path.join(VALID_PATH, "validation*")
661 | match_filenames = tf.io.matching_files(file_names)
662 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data")
663 | dataset_iterator = dataset.as_numpy_iterator()
664 | for data in dataset_iterator:
665 | scenario = scenario_pb2.Scenario.FromString(data)
666 | valid_list = [index for index, track in enumerate(scenario.tracks) if track.states[10].valid]
667 | predicted_list = [obs.track_index for obs in scenario.tracks_to_predict]
668 | print(valid_list)
669 | print(predicted_list)
670 | # logged_trajectories, simulated_states = simulate_with_extrapolation(
671 | # scenario, print_verbose_comments=True)
672 |
673 | logged_trajectories, simulated_states = simulate_with_extrapolation_new(
674 | scenario, print_verbose_comments=True)
675 |         # Package the first simulation into a `JointScene`
676 | joint_scene = joint_scene_from_states(simulated_states[0, :, :, :],
677 | logged_trajectories.object_id)
678 | # Validate the joint scene. Should raise an exception if it's invalid.
679 | submission_specs.validate_joint_scene(joint_scene, scenario)
680 | scenario_rollouts = scenario_rollouts_from_states(
681 | scenario, simulated_states, logged_trajectories.object_id)
682 |         # As before, we can validate the message we just generated.
683 | submission_specs.validate_scenario_rollouts(scenario_rollouts, scenario)
684 | # Compute the features for a single JointScene.
685 | # single_scene_features = metric_features.compute_metric_features(
686 | # scenario, joint_scene)
687 | config = metrics.load_metrics_config()
688 | scenario_metrics = metrics.compute_scenario_metrics_for_bundle(
689 | config, scenario, scenario_rollouts)
690 | print(scenario_metrics)
691 |
692 |
693 | def inference_test_set():
694 | OUTPUT_ROOT_DIRECTORY = r'waymo_output'
695 | os.makedirs(OUTPUT_ROOT_DIRECTORY, exist_ok=True)
696 | output_filenames = []
697 | file_names = os.path.join(TEST_PATH, "testing*")
698 | match_filenames = tf.io.matching_files(file_names)
699 | for shard_filename in match_filenames:
700 |         print(f"{shard_filename} starting inference")
701 | # extract the suffix.
702 | shard_suffix = shard_filename.numpy().decode('utf8')[-len('-00000-of-00150'):]
703 | shard_dataset = tf.data.TFRecordDataset([shard_filename])
704 | shard_iterator = shard_dataset.as_numpy_iterator()
705 | scenario_rollouts = []
706 | for scenario_bytes in tqdm.tqdm(shard_iterator):
707 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes)
708 | logged_trajectories, simulated_states = simulate_with_extrapolation_new(
709 | scenario, print_verbose_comments=False)
710 | sr = scenario_rollouts_from_states(
711 | scenario, simulated_states, logged_trajectories.object_id)
712 | submission_specs.validate_scenario_rollouts(sr, scenario)
713 | scenario_rollouts.append(sr)
714 | shard_submission = sim_agents_submission_pb2.SimAgentsChallengeSubmission(
715 | scenario_rollouts=scenario_rollouts,
716 | submission_type=sim_agents_submission_pb2.SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION,
717 | account_name='your_account@test.com',
718 | unique_method_name='sim_agents_tutorial',
719 | authors=['test'],
720 | affiliation='waymo',
721 | description='Submission from the Sim Agents tutorial',
722 | method_link='https://waymo.com/open/'
723 | )
724 | # Now we can export this message to a binproto, saved to local storage.
725 | output_filename = f'submission.binproto{shard_suffix}'
726 | with open(os.path.join(OUTPUT_ROOT_DIRECTORY, output_filename), 'wb') as f:
727 | f.write(shard_submission.SerializeToString())
728 | output_filenames.append(output_filename)
729 |
730 | # Once we have created all the shards, we can package them directly into a
731 | # tar.gz archive, ready for submission.
732 | with tarfile.open(
733 | os.path.join(OUTPUT_ROOT_DIRECTORY, 'submission.tar.gz'), 'w:gz') as tar:
734 | for output_filename in output_filenames:
735 | tar.add(os.path.join(OUTPUT_ROOT_DIRECTORY, output_filename),
736 | arcname=output_filename)
737 |
738 |
739 | def cal_dynamic_map_states(is_test: bool = True):
740 | if is_test:
741 | file_names = os.path.join(TEST_PATH, "testing*")
742 | match_filenames = tf.io.matching_files(file_names)
743 | else:
744 | file_names = os.path.join(VALID_PATH, "validation*")
745 | match_filenames = tf.io.matching_files(file_names)
746 | dataset = tf.data.TFRecordDataset(match_filenames, name="train_data")
747 | dataset_iterator = dataset.as_numpy_iterator()
748 | for scenario_bytes in tqdm.tqdm(dataset_iterator):
749 | scenario = scenario_pb2.Scenario.FromString(scenario_bytes)
750 | print("dynamic_map_states list")
751 | print([len(state.lane_states) for state in scenario.dynamic_map_states])
752 |
753 |
754 | if __name__ == "__main__":
755 | # inference_valid_set()
756 | # inference_test_set()
757 | cal_dynamic_map_states(is_test=False)
758 |
--------------------------------------------------------------------------------
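
For reference, global_to_local and local_to_global above (together with the matching
heading transforms) are mutual inverses: the first expresses a world-frame point in the
ego frame before preprocessing, the second maps ego-frame model output back to world
coordinates before the submission protos are built. A small standalone round-trip check
(illustrative only; the two functions are restated so the snippet runs on its own):

    import math
    import numpy as np

    def global_to_local(cx, cy, ch, px, py):
        dx, dy = px - cx, py - cy
        return (dx * math.cos(ch) + dy * math.sin(ch),
                dy * math.cos(ch) - dx * math.sin(ch))

    def local_to_global(ch, lx, ly, cx, cy):
        return (cx + lx * math.cos(ch) - ly * math.sin(ch),
                cy + lx * math.sin(ch) + ly * math.cos(ch))

    ego = (12.0, -3.0, 0.7)                 # ego x, y and heading in the world frame
    point = (20.0, 4.0)                     # an arbitrary world-frame point
    lx, ly = global_to_local(*ego, *point)
    gx, gy = local_to_global(ego[2], lx, ly, ego[0], ego[1])
    assert np.allclose((gx, gy), point)     # the round trip recovers the original point
    print((lx, ly), (gx, gy))
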