├── src
│   └── dwm
│       ├── __init__.py
│       ├── fs
│       │   ├── __init__.py
│       │   ├── README.md
│       │   ├── ctar.py
│       │   ├── czip.py
│       │   ├── s3fs.py
│       │   └── dirfs.py
│       ├── datasets
│       │   ├── __init__.py
│       │   ├── README.md
│       │   └── waymo_common.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── adapters.py
│       │   ├── voxelizer.py
│       │   └── maskgit_base.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── lidar.py
│       │   ├── carla_actor_state_machines.py
│       │   ├── make_carla_cameras.py
│       │   ├── make_blank_code.py
│       │   ├── preview.py
│       │   ├── carla_simulation.py
│       │   ├── carla_control.py
│       │   └── sampler.py
│       ├── pipelines
│       │   └── __init__.py
│       ├── tools
│       │   ├── transcode_video.json
│       │   ├── tar2zip.py
│       │   ├── transcode_video.py
│       │   ├── prepare_opendv.py
│       │   ├── fs_make_info_json.py
│       │   ├── export_nusc_2_preview_format.py
│       │   └── dataset_make_info_json.py
│       ├── metrics
│       │   ├── general_metrics.py
│       │   ├── voxel_metrics.py
│       │   ├── fvd.py
│       │   └── pc_metrics.py
│       ├── distributed.py
│       ├── evaluate.py
│       ├── streaming.py
│       ├── export_generation_result_as_nuscenes_data.py
│       ├── preview.py
│       ├── common.py
│       ├── functional.py
│       └── schedulers
│           └── temporal_independent.py
├── configs
│   ├── fs
│   │   ├── local.json
│   │   ├── local_nuscenes.json
│   │   ├── s3_cn_sh.json
│   │   ├── s3_st_sh.json
│   │   ├── local_czip_nuscenes.json
│   │   └── s3_czip_nuscenes.json
│   ├── README.md
│   └── experimental
│       └── simulation
│           └── make_carla_cameras_from_nuscenes.json
├── .gitignore
├── requirements.txt
├── .gitmodules
├── LICENSE
├── README_intro_zh.md
├── docs
│   ├── CtsdPipelineFaqs.md
│   ├── InteractiveGeneration.md
│   ├── Datasets.md
│   └── LiDAR_Generation.md
└── examples
    └── ctsd_generation_example.py
/src/dwm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dwm/fs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dwm/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dwm/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dwm/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/dwm/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/configs/fs/local.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "fsspec.implementations.local.LocalFileSystem"
3 | }
--------------------------------------------------------------------------------
/configs/fs/local_nuscenes.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "dwm.fs.dirfs.DirFileSystem",
3 | "path": "/mnt/storage/user/wuzehuan/Downloads/data/nuscenes"
4 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .vscode
3 | build
4 | dist
5 | *.egg-info
6 | output
7 | /externals/chamferdist/chamferdist/*.so
8 | work_dirs
9 | wandb
10 | taming
11 |
--------------------------------------------------------------------------------
/configs/fs/s3_cn_sh.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
3 | "endpoint_url": "http://aoss-internal.cn-sh-01.sensecoreapi-oss.cn",
4 | "aws_access_key_id": "AE939C3A07AE4E6D93908AA603B9F3A9",
5 | "aws_secret_access_key": ""
6 | }
--------------------------------------------------------------------------------
/configs/fs/s3_st_sh.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
3 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn",
4 | "aws_access_key_id": "4853705CDE93446E8D902F70291C2C92",
5 | "aws_secret_access_key": ""
6 | }
--------------------------------------------------------------------------------
/src/dwm/tools/transcode_video.json:
--------------------------------------------------------------------------------
1 | {
2 | "ffmpeg_args": [
3 | "-r",
4 | "10",
5 | "-g",
6 | "20",
7 | "-vf",
8 | "scale=-1:720",
9 | "-row-mt",
10 | "1",
11 | "-b:v",
12 | "2M",
13 | "-crf",
14 | "10",
15 | "-pix_fmt",
16 | "yuv420p"
17 | ]
18 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Pillow
2 | accelerate==1.4.0
3 | av==14.0.0
4 | bitsandbytes
5 | botocore
6 | diffusers==0.31.0
7 | einops
8 | easydict
9 | fsspec
10 | numpy
11 | pandas
12 | protobuf==4.21.6
13 | pyarrow
14 | safetensors
15 | sentencepiece
16 | scipy
17 | tensorboard
18 | torch-fidelity
19 | torchmetrics==1.6.0
20 | tqdm
21 | transformers==4.50.0
22 | transforms3d
23 | Ninja
24 | chamferdist
25 | timm
26 | rotary-embedding-torch
27 | opencv-python
28 | wandb
29 | git+https://github.com/autonomousvision/kitti360Scripts.git
30 |
--------------------------------------------------------------------------------
/src/dwm/metrics/general_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | from torchmetrics import MeanMetric
4 | import torch.distributed
5 |
6 |
7 | class CustomMeanMetrics(MeanMetric):
8 | """
9 | Description:
10 |         Compute the mean of the accumulated values; `compute` also records the accumulated weight in `num_samples`.
11 | """
12 |
13 | def __init__(
14 | self, **kwargs
15 | ):
16 | super().__init__(**kwargs)
17 |
18 | def compute(self, **kwargs):
19 | self.num_samples = int(self.weight)
20 | return super().compute(**kwargs)
21 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "externals/TATS"]
2 | path = externals/TATS
3 | url = https://github.com/songweige/TATS.git
4 | [submodule "externals/taming-transformers"]
5 | path = externals/taming-transformers
6 | url = https://github.com/CompVis/taming-transformers.git
7 | [submodule "externals/waymo-open-dataset"]
8 | path = externals/waymo-open-dataset
9 | url = https://github.com/waymo-research/waymo-open-dataset.git
10 | [submodule "externals/dvgo_cuda"]
11 | path = externals/dvgo_cuda
12 | url = https://github.com/sunset1995/DirectVoxGO.git
13 |
--------------------------------------------------------------------------------
/configs/fs/local_czip_nuscenes.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "dwm.fs.czip.CombinedZipFileSystem",
3 | "fs": {
4 | "_class_name": "dwm.fs.dirfs.DirFileSystem",
5 | "path": "/mnt/storage/user/wuzehuan"
6 | },
7 | "paths": [
8 | "Downloads/data/nuscenes/v1.0-trainval_meta.zip",
9 | "Downloads/data/nuscenes/v1.0-trainval01_blobs.zip",
10 | "Downloads/data/nuscenes/v1.0-trainval02_blobs.zip",
11 | "Downloads/data/nuscenes/v1.0-trainval03_blobs.zip",
12 | "Downloads/data/nuscenes/v1.0-trainval04_blobs.zip",
13 | "Downloads/data/nuscenes/v1.0-trainval05_blobs.zip",
14 | "Downloads/data/nuscenes/v1.0-trainval06_blobs.zip",
15 | "Downloads/data/nuscenes/v1.0-trainval07_blobs.zip",
16 | "Downloads/data/nuscenes/v1.0-trainval08_blobs.zip",
17 | "Downloads/data/nuscenes/v1.0-trainval09_blobs.zip",
18 | "Downloads/data/nuscenes/v1.0-trainval10_blobs.zip"
19 | ]
20 | }
--------------------------------------------------------------------------------
/configs/fs/s3_czip_nuscenes.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "dwm.fs.czip.CombinedZipFileSystem",
3 | "fs": {
4 | "_class_name": "dwm.fs.dirfs.DirFileSystem",
5 | "path": "users/wuzehuan",
6 | "fs": {
7 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
8 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn",
9 | "aws_access_key_id": "4853705CDE93446E8D902F70291C2C92",
10 | "aws_secret_access_key": ""
11 | }
12 | },
13 | "paths": [
14 | "data/nuscenes/v1.0-trainval_meta.zip",
15 | "data/nuscenes/v1.0-trainval01_blobs.zip",
16 | "data/nuscenes/v1.0-trainval02_blobs.zip",
17 | "data/nuscenes/v1.0-trainval03_blobs.zip",
18 | "data/nuscenes/v1.0-trainval04_blobs.zip",
19 | "data/nuscenes/v1.0-trainval05_blobs.zip",
20 | "data/nuscenes/v1.0-trainval06_blobs.zip",
21 | "data/nuscenes/v1.0-trainval07_blobs.zip",
22 | "data/nuscenes/v1.0-trainval08_blobs.zip",
23 | "data/nuscenes/v1.0-trainval09_blobs.zip",
24 | "data/nuscenes/v1.0-trainval10_blobs.zip"
25 | ]
26 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SenseTime Research.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/dwm/utils/lidar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dwm.functional
3 |
4 |
5 | def preprocess_points(batch, device):
6 | mhv = dwm.functional.make_homogeneous_vector
7 | return [
8 | [
9 | (mhv(p_j.to(device)) @ t_j.permute(1, 0))[:, :3]
10 | for p_j, t_j in zip(p_i, t_i.flatten(0, 1))
11 | ]
12 | for p_i, t_i in zip(
13 | batch["lidar_points"], batch["lidar_transforms"].to(device))
14 | ]
15 |
16 |
17 | def postprocess_points(batch, ego_space_points):
18 | return [
19 | [
20 | (
21 | dwm.functional.make_homogeneous_vector(p_j.cpu()) @
22 | torch.linalg.inv(t_j).permute(1, 0)
23 | )[:, :3]
24 | for p_j, t_j in zip(p_i, t_i.flatten(0, 1))
25 | ]
26 | for p_i, t_i in zip(
27 | ego_space_points, batch["lidar_transforms"])
28 | ]
29 |
30 |
31 | def voxels2points(grid_size, voxels):
32 | interval = torch.tensor([grid_size["interval"]])
33 | min = torch.tensor([grid_size["min"]])
34 | return [
35 | [
36 | torch.nonzero(v_j).flip(-1).cpu() * interval + min
37 | for v_j in v_i
38 | ]
39 | for v_i in voxels
40 | ]
41 |
--------------------------------------------------------------------------------
/src/dwm/utils/carla_actor_state_machines.py:
--------------------------------------------------------------------------------
1 | import carla
2 |
3 |
4 | class ClassicPedestrian:
5 |
6 | def __init__(self, controller: carla.Actor):
7 | self.controller = controller
8 | self.state = "idle"
9 |
10 | def update(self):
11 | if self.state == "idle":
12 | world = self.controller.get_world()
13 | self.destination = world.get_random_location_from_navigation()
14 |
15 |             # should be started one tick after the actor is created
16 | self.controller.start()
17 | self.controller.go_to_location(self.destination)
18 | self.controller.set_max_speed(
19 | float(self.controller.parent.attributes["speed"]))
20 |
21 | self.state = "acting"
22 |
23 | elif self.state == "acting":
24 | # TODO: stop and transfer to idle when arrival
25 | pass
26 |
27 |
28 | class BevSpectator:
29 |
30 | def __init__(self, actor: carla.Actor):
31 | self.spectator = None
32 | self.hero = actor
33 | world = self.hero.get_world()
34 | self.spectator = world.get_spectator()
35 |
36 | def update(self):
37 | vehicle_transform = self.hero.get_transform()
38 | spectator_transform = carla.Transform(
39 | vehicle_transform.location + carla.Location(x=0, y=0, z=50),
40 | carla.Rotation(pitch=-90, yaw=0, roll=0)
41 | )
42 | self.spectator.set_transform(spectator_transform)
43 |
--------------------------------------------------------------------------------
/README_intro_zh.md:
--------------------------------------------------------------------------------
1 | # Open Driving World Models (OpenDWM)
2 |
3 | [[English README](README.md)]
4 |
5 | https://github.com/user-attachments/assets/649d3b81-3b1f-44f9-9f51-4d1ed7756476
6 |
7 | [Video link](https://youtu.be/j9RRj-xzOA4)
8 |
9 | Welcome to the OpenDWM project! This is an open-source project focused on autonomous driving video generation. Our mission is to provide a high-quality, controllable tool for generating autonomous driving videos with the latest technology. We aim to build a codebase that is both user-friendly and highly reusable, and we hope to keep improving it by gathering the wisdom of the community.
10 |
11 | The driving world models generate multi-view images or videos of autonomous driving scenes from text and road environment layout conditions. Whether it is the environment, the weather, the vehicle type, or the driving path, you can adjust them to your needs.
12 |
13 | The highlights are:
14 |
15 | 1. **Transparent and reproducible training.** We provide the complete training code and configurations, so that everyone can reproduce the experiments, fine-tune on their own data, and develop custom features as needed.
16 |
17 | 2. **Significantly improved environment diversity.** By training on multiple datasets, the generalization ability of the models is improved to an unprecedented level. Taking the layout-conditioned generation task as an example, snowy city streets or a lakeside highway with snow-capped mountains in the distance are impossible for generative models trained on a single dataset.
18 |
19 | 3. **Greatly improved generation quality.** Support for popular model architectures (SD 2.1, 3.5) makes it easier to leverage the advanced pre-trained generation capabilities from the community. Various training techniques, including multi-task and self-supervised training, let the models use the information in video data more effectively.
20 |
21 | 4. **Convenient evaluation.** The evaluation follows the popular `torchmetrics` framework, which is easy to configure, develop with, and integrate into existing pipelines. Some public configurations (e.g. FID and FVD on the nuScenes validation set) are provided to align with other research works.
22 |
23 | In addition, the code modules are designed with a considerable degree of reusability in mind, so that they can be applied in other projects.
24 |
25 | So far, this project implements the techniques from the following papers:
26 |
27 | > [UniMLVG: Unified Framework for Multi-view Long Video Generation with Comprehensive Control Capabilities for Autonomous Driving](https://sensetime-fvg.github.io/UniMLVG)
28 | > Rui Chen<sup>1,2</sup>, Zehuan Wu<sup>2</sup>, Yichen Liu<sup>2</sup>, Yuxin Guo<sup>2</sup>, Jingcheng Ni<sup>2</sup>, Haifeng Xia<sup>1</sup>, Siyu Xia<sup>1</sup>
29 | > <sup>1</sup>Southeast University, <sup>2</sup>SenseTime Research
30 |
31 | > [MaskGWM: A Generalizable Driving World Model with Video Mask Reconstruction](https://sensetime-fvg.github.io/MaskGWM)
32 | > Jingcheng Ni, Yuxin Guo, Yichen Liu, Rui Chen, Lewei Lu, Zehuan Wu
33 | > SenseTime Research
34 |
35 | ## Setup and run
36 |
37 | Please refer to the steps in the [README](README.md#setup).
38 |
--------------------------------------------------------------------------------
/src/dwm/tools/tar2zip.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 | import tarfile
5 | import zipfile
6 |
7 |
8 | def create_parser():
9 | parser = argparse.ArgumentParser(
10 | description="The tool to convert the TAR file to the ZIP file for the "
11 | "capability of random access in package.")
12 | parser.add_argument(
13 | "-i", "--input-path", type=str, required=True,
14 | help="The input path of TAR file.")
15 | parser.add_argument(
16 | "-o", "--output-path", type=str, required=True,
17 | help="The output path to save the converted ZIP file.")
18 | parser.add_argument(
19 | "-e", "--extensions-to-compress", default=".bin|.json|.pcd|.txt",
20 |         type=str, help="The extension list to compress, split by '|'.")
21 | return parser
22 |
23 |
24 | def convert_content(tar_file_obj, zip_file_obj, extensions_to_compress: list):
25 | i = tarfile.TarInfo.fromtarfile(tar_file_obj)
26 | while i is not None:
27 | if i.isfile():
28 | data = tar_file_obj.extractfile(i).read()
29 | modified = datetime.datetime.fromtimestamp(i.mtime)
30 | zi = zipfile.ZipInfo(i.name, modified.timetuple()[:6])
31 | ext = os.path.splitext(i.name)[-1]
32 | compress_type = zipfile.ZIP_DEFLATED \
33 | if ext in extensions_to_compress else zipfile.ZIP_STORED
34 | zip_file_obj.writestr(zi, data, compress_type)
35 |
36 | i = tar_file_obj.next()
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = create_parser()
41 | args = parser.parse_args()
42 |
43 | etc = args.extensions_to_compress.split("|")
44 | with tarfile.open(args.input_path) as f_in:
45 | with zipfile.ZipFile(args.output_path, "w") as f_out:
46 | convert_content(f_in, f_out, etc)
47 |
--------------------------------------------------------------------------------
/src/dwm/tools/transcode_video.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import subprocess
5 |
6 |
7 | def create_parser():
8 | parser = argparse.ArgumentParser(
9 | description="The script to convert videos for better performance.")
10 | parser.add_argument(
11 | "-c", "--config-path", type=str, required=True,
12 | help="The config path.")
13 | parser.add_argument(
14 | "-i", "--input-path", type=str, required=True,
15 | help="The input path of videos.")
16 | parser.add_argument(
17 | "-o", "--output-path", type=str, required=True,
18 |         help="The output path to save the transcoded videos.")
19 | parser.add_argument("-f", "--range-from", type=int, default=0)
20 | parser.add_argument("-t", "--range-to", type=int, default=-1)
21 | return parser
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = create_parser()
26 | args = parser.parse_args()
27 |
28 | with open(args.config_path, "r", encoding="utf-8") as f:
29 | config = json.load(f)
30 |
31 | files = os.listdir(args.input_path)
32 |     range_to = len(files) if args.range_to == -1 else args.range_to
33 | print(
34 | "Dataset count: {}, processing range {} - {}".format(
35 | len(files), args.range_from, range_to))
36 | actual_files = files[args.range_from:range_to]
37 | print(actual_files)
38 | os.makedirs(args.output_path, exist_ok=True)
39 | for i in actual_files:
40 | if not os.path.isfile(os.path.join(args.input_path, i)):
41 | continue
42 |
43 | if "reformat" in config:
44 | output_file = os.path.splitext(i)[0] + config["reformat"]
45 | else:
46 | output_file = i
47 |
48 | subprocess.run([
49 | "ffmpeg",
50 | "-hide_banner",
51 | "-y",
52 | "-i",
53 | os.path.join(args.input_path, i),
54 | *config["ffmpeg_args"],
55 | os.path.join(args.output_path, output_file)
56 | ], stdout=subprocess.PIPE, check=True)
57 |
--------------------------------------------------------------------------------
/src/dwm/tools/prepare_opendv.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from datasets import load_dataset
4 | import argparse
5 |
6 | def create_parser():
7 | parser = argparse.ArgumentParser(
8 |         description="The script to make information JSON(s) for the "
9 |         "official OpenDV annotations.")
10 | parser.add_argument(
11 | "--meta-path", type=str, required=True,
12 | help="The path to official video metas of OpenDV-2K.")
13 | parser.add_argument(
14 | "-o", "--output-path", type=str, required=True,
15 | help="The path to save the information JSON file(s) on the local file "
16 | "system.")
17 | return parser
18 |
19 | if __name__ == "__main__":
20 | parser = create_parser()
21 | args = parser.parse_args()
22 |
23 | ds = load_dataset("OpenDriveLab/OpenDV-YouTube-Language")
24 |     split = None
25 |     ignore_list = []
26 | with open(args.meta_path, "r", encoding="utf-8") as f:
27 | meta_dict = {
28 | i["videoid"]: i
29 | for i in json.load(f)
30 | if split is None or i["split"] == split and
31 | i["videoid"] not in ignore_list
32 | }
33 |
34 | new_image_descriptions = dict()
35 | for sp in ["train", "validation"]:
36 | with tqdm(
37 | ds[sp],
38 | desc=f"Prepare {sp}",
39 | ) as pbar:
40 | for frame_annos in pbar:
41 | k1 = frame_annos['folder'].split('/')[-1]
42 | k2 = int(frame_annos['first_frame'].split('.')[0])
43 | v = (frame_annos['blip'], frame_annos['cmd'])
44 | if k1 not in meta_dict:
45 | continue
46 | start_discard = meta_dict[k1]["start_discard"]
47 | default_fps, time_base = 10, 0.001
48 | t = int((k2/default_fps+start_discard)/time_base)
49 | new_image_descriptions[f"{k1}.{t}"] = {
50 | "image_description": v[0],
51 | "action": v[1]
52 | }
53 | with open(
54 | args.output_path, "w", encoding="utf-8"
55 | ) as f:
56 | json.dump(new_image_descriptions, f)
--------------------------------------------------------------------------------
/src/dwm/models/adapters.py:
--------------------------------------------------------------------------------
1 | import diffusers.models.adapter
2 | import torch
3 | from typing import Optional
4 |
5 |
6 | class ImageAdapter(torch.nn.Module):
7 | def __init__(
8 | self, in_channels: int = 3,
9 | channels: list = [320, 320, 640, 1280, 1280],
10 | is_downblocks: list = [False, True, True, True, False],
11 | num_res_blocks: int = 2, downscale_factor: int = 8,
12 | use_zero_convs: bool = False, zero_gate_coef: Optional[float] = None,
13 | gradient_checkpointing: bool = True
14 | ):
15 | super().__init__()
16 |
17 | in_channels = in_channels * downscale_factor ** 2
18 | self.unshuffle = torch.nn.PixelUnshuffle(downscale_factor)
19 | self.body = torch.nn.ModuleList([
20 | diffusers.models.adapter.AdapterBlock(
21 | in_channels if i == 0 else channels[i - 1], channels[i],
22 | num_res_blocks, down=is_downblocks[i])
23 | for i in range(len(channels))
24 | ])
25 | self.gradient_checkpointing = gradient_checkpointing
26 |
27 | self.zero_convs = torch.nn.ModuleList([
28 | torch.nn.Conv2d(channel, channel, 1)
29 | for channel in channels
30 | ]) if use_zero_convs else [None for _ in channels]
31 | for i in self.zero_convs:
32 | if i is not None:
33 | torch.nn.init.zeros_(i.weight)
34 | torch.nn.init.zeros_(i.bias)
35 |
36 | self.zero_gate_coef = zero_gate_coef
37 | self.zero_gates = torch.nn.Parameter(torch.zeros(len(channels))) \
38 | if zero_gate_coef else None
39 |
40 | def forward(self, x: torch.Tensor, return_features: bool = False):
41 | base_shape = x.shape[:-3]
42 | x = self.unshuffle(x.flatten(0, -4))
43 | features = []
44 | for i, (block, zero_conv) in enumerate(zip(self.body, self.zero_convs)):
45 | if self.training and self.gradient_checkpointing:
46 | x = torch.utils.checkpoint.checkpoint(
47 | block, x, use_reentrant=False)
48 | else:
49 | x = block(x)
50 |
51 | x_out = x
52 | if zero_conv is not None:
53 | x_out = zero_conv(x_out)
54 |
55 | if self.zero_gates is not None:
56 | x_out = x_out * torch.tanh(
57 | self.zero_gate_coef * self.zero_gates[i])
58 |
59 | features.append(x_out.view(*base_shape, *x_out.shape[1:]))
60 | return features if not return_features else features[-1]
61 |
--------------------------------------------------------------------------------
/src/dwm/utils/make_carla_cameras.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import numpy as np
5 | import transforms3d
6 |
7 |
8 | def create_parser():
9 | parser = argparse.ArgumentParser(
10 | description="Make Carla camera parameters from intrinsic matrices and "
11 | "transform matrices.")
12 | parser.add_argument(
13 | "-i", "--input-path", type=str, required=True,
14 | help="The path of input camera parameter file.")
15 | parser.add_argument(
16 | "-o", "--output-path", type=str, required=True,
17 | help="The path of output Carla parameter file.")
18 | return parser
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = create_parser()
23 | args = parser.parse_args()
24 |
25 | rear_ego_to_center_ego = [-1.5, 0, 0]
26 | lh_from_rh = rh_from_lh = np.diag([1, -1, 1, 1])
27 |
28 |     # The z-frontal camera (OpenCV style) is x-right, y-down, z-front.
29 |     # The x-frontal camera (Y-flipped Carla style) is x-front, y-left, z-up.
30 | z_frontal_camera_from_x_frontal_camera = np.array([
31 | [0, -1, 0, 0],
32 | [0, 0, -1, 0],
33 | [1, 0, 0, 0],
34 | [0, 0, 0, 1]
35 | ], np.float32)
36 |
37 | with open(args.input_path, "r", encoding="utf-8") as f:
38 | config = json.load(f)
39 |
40 | result = {}
41 | for k, v in config.items():
42 | carla_transform = lh_from_rh @ np.array(v["transform"]) @ \
43 | z_frontal_camera_from_x_frontal_camera @ rh_from_lh
44 | euler_rotation = transforms3d.euler.mat2euler(
45 | carla_transform[:3, :3], "szyx")
46 | result[k] = {
47 | "attributes": {
48 | "fov": str(
49 | math.degrees(
50 | math.atan(v["intrinsic"][0][2] / v["intrinsic"][0][0])
51 | + math.atan(
52 | (v["image_size"][0] - v["intrinsic"][0][2]) /
53 | v["intrinsic"][0][0]))
54 | ),
55 | "role_name": k
56 | },
57 | "spawn_transform": {
58 | "location": [
59 | (carla_transform[i][3] + rear_ego_to_center_ego[i]).item()
60 | for i in range(3)
61 | ],
62 | "rotation": [
63 | math.degrees(-euler_rotation[1]),
64 | math.degrees(euler_rotation[0]),
65 | math.degrees(-euler_rotation[2])
66 | ]
67 | }
68 | }
69 |
70 | with open(args.output_path, "w", encoding="utf-8") as f:
71 | json.dump(result, f, indent=4)
72 |
--------------------------------------------------------------------------------
/src/dwm/tools/fs_make_info_json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import fsspec.implementations.local
3 | import json
4 | import os
5 | import tarfile
6 | import zipfile
7 |
8 |
9 | def create_parser():
10 | parser = argparse.ArgumentParser(
11 | description="The script to make information JSON for TAR or ZIP files "
12 | "to accelerate file system initialization.")
13 | parser.add_argument(
14 | "-i", "--input-path", type=str, required=True,
15 | help="The path of the TAR or ZIP file.")
16 | parser.add_argument(
17 | "-o", "--output-path", type=str, required=True,
18 | help="The path to save the information JSON file on the local file "
19 | "system.")
20 | parser.add_argument(
21 | "-fs", "--fs-config-path", default=None, type=str,
22 | help="The path of file system JSON config to open the input file.")
23 | return parser
24 |
25 |
26 | def enum_tar_members(tf):
27 |     tf._check()
28 |     while True:
29 |         tarinfo = tf.next()
30 | if tarinfo is None:
31 | break
32 |
33 | yield tarinfo
34 |
35 |
36 | def make_info_dict(ext, fobj, enable_tqdm: bool = True):
37 | if ext == ".zip":
38 | with zipfile.ZipFile(fobj) as zf:
39 | return {
40 | i.filename: [i.header_offset, i.file_size, i.is_dir()]
41 | for i in zf.infolist()
42 | }
43 | elif ext == ".tar":
44 | with tarfile.TarFile(fileobj=fobj) as tf:
45 | tar_member_generator = enum_tar_members(tf)
46 | if enable_tqdm:
47 | # Since TAR files do not have centralized index information,
48 | # the scanning process will take a longer time.
49 | import tqdm
50 | tar_member_generator = tqdm.tqdm(tar_member_generator)
51 |
52 | return {
53 | i.name: [i.offset_data, i.size, not i.isfile()]
54 | for i in tar_member_generator
55 | }
56 | else:
57 | raise Exception("Unknown format {}.".format(ext))
58 |
59 |
60 | if __name__ == "__main__":
61 | parser = create_parser()
62 | args = parser.parse_args()
63 |
64 | if args.fs_config_path is None:
65 | fs = fsspec.implementations.local.LocalFileSystem()
66 | else:
67 | import dwm.common
68 | with open(args.fs_config_path, "r", encoding="utf-8") as f:
69 | fs = dwm.common.create_instance_from_config(json.load(f))
70 |
71 | with fs.open(args.input_path, "rb") as f:
72 | items = make_info_dict(os.path.splitext(args.input_path)[-1], f)
73 |
74 | with open(args.output_path, "w", encoding="utf-8") as f:
75 | json.dump(items, f)
76 |
--------------------------------------------------------------------------------
/src/dwm/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | The dataset classes of this project read multi-view video clips and point cloud sequence clips, together with the corresponding per-frame text descriptions, camera parameters, ego-vehicle pose transforms, projected 3D box conditions, and projected HD map conditions.
4 | We have adapted the nuScenes, Waymo Perception, and Argoverse 2 Sensor datasets.
5 |
6 | ## Terms
7 |
8 | **Sequence length** is the frame count of the temporal sequence, denoted as `t`.
9 |
10 | **View count** is the camera view count, denoted as `v`.
11 |
12 | **Sensor count** is the sum of camera view count and the LiDAR count. The LiDAR count can only be 0 or 1 in this version.
13 |
14 | **Image coordinate system** is x-right, y-down, with the origin at the top left corner of the image.
15 |
16 | **Camera coordinate system** is x-right, y-down, z-forward, with the origin at the camera's optical center.
17 |
18 | **Ego vehicle coordinate system** is x-forward, y-left, z-up, with the origin at the midpoint of the rear vehicle axle.
19 |
20 | ## Data items
21 |
22 | ### Basic information
23 |
24 | `fps`. A float32 tensor in the shape `(1,)`.
25 |
26 | `pts`. A float32 tensor in the shape `(t, sensor_count)`, for the timestamp of each sensor frame.
27 |
28 | `images`. A multi-level list of PIL images in the shape `(t, v)`, with the original image size.
29 |
30 | `lidar_points`. A list in the shape `(t,)` of float32 tensors in the shape `(point_count, 3)`. The point count in each frame of a sequence is not fixed.
31 |
32 | ### Transforms
33 |
34 | `camera_transforms`. A float32 tensor in the shape `(t, v, 4, 4)`, representing the 4x4 transformation matrix from the camera coordinate system to the ego vehicle coordinate system.
35 |
36 | `camera_intrinsics`. A float32 tensor in the shape `(t, v, 3, 3)`, representing the 3x3 transformation matrix from the camera coordinate system to the image coordinate system.
37 |
38 | `image_size`. A float32 tensor in the shape `(t, v, 2)` for the pixel width and height of each image.
39 |
40 | `lidar_transforms`. A float32 tensor in the shape `(t, 1, 4, 4)`, representing the 4x4 transformation matrix from the LiDAR coordinate system to the ego vehicle coordinate system.
41 |
42 | `ego_transforms`. A float32 tensor in the shape `(t, sensor_count, 4, 4)`, representing the 4x4 transformation matrix from the ego vehicle coordinate system to the world coordinate system.
43 |
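As a worked example of how these transforms compose (a minimal sketch; the helper below is illustrative rather than part of the dataset code, and `item` stands for one deserialized sample containing the tensors described above), a point in the ego vehicle coordinate system can be projected to pixel coordinates:

```Python
import torch

def project_ego_point_to_image(item, point_ego, t, v):
    # camera_transforms[t, v] maps camera -> ego, so invert it for ego -> camera.
    camera_from_ego = torch.linalg.inv(item["camera_transforms"][t, v])  # (4, 4)
    intrinsics = item["camera_intrinsics"][t, v]                         # (3, 3)

    point_h = torch.cat([point_ego, point_ego.new_ones(1)])  # homogeneous (4,)
    point_camera = (camera_from_ego @ point_h)[:3]           # x-right, y-down, z-forward
    uv = (intrinsics @ point_camera)[:2] / point_camera[2]   # pixel (x, y)
    return uv                                                # valid if point_camera[2] > 0
```
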
44 | ### Conditions
45 |
46 | `3dbox_images`. A multi-level list of PIL images in the shape `(t, v)`, with the same size as the original images, containing the projected 3D boxes of the object annotations.
47 |
48 | `hdmap_images`. A multi-level list of PIL images in the shape `(t, v)`, with the same size as the original images, containing the projected HD map line annotations.
49 |
50 | `image_description`. A multi-level list of strings in the shape `(t, v)`, for the text descriptions per image.
51 |
--------------------------------------------------------------------------------
/src/dwm/metrics/voxel_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | import torch.distributed
4 |
5 |
6 | class VoxelIoU(torchmetrics.Metric):
7 | def __init__(
8 | self, **kwargs
9 | ):
10 | super().__init__(**kwargs)
11 | self.iou_list = []
12 |
13 | def update(self,
14 | gt_voxel: torch.tensor,
15 | pred_voxel: torch.tensor):
16 | if len(gt_voxel.shape) == 3:
17 | self.iou_list.append(
18 | (pred_voxel & gt_voxel).sum() / (pred_voxel | gt_voxel).sum())
19 | else:
20 | for i, j in zip(gt_voxel, pred_voxel):
21 | self.iou_list.append(((i & j).sum() / (i | j).sum()).float())
22 |
23 | def compute(self):
24 | iou_list = torch.stack(self.iou_list, dim=0)
25 | world_size = torch.distributed.get_world_size() \
26 | if torch.distributed.is_initialized() else 1
27 | if world_size > 1:
28 | all_iou = iou_list.new_zeros(
29 | (len(iou_list)*world_size, ) + iou_list.shape[1:])
30 | torch.distributed.all_gather_into_tensor(
31 | all_iou, iou_list)
32 | iou_list = all_iou
33 | num_samples = (~torch.isnan(iou_list) & ~torch.isinf(iou_list)).sum()
34 | iou_list = torch.nan_to_num(iou_list, nan=0.0, posinf=0.0, neginf=0.0)
35 | self.num_samples = num_samples
36 | return iou_list.sum() / num_samples
37 |
38 | def reset(self):
39 | self.iou_list.clear()
40 | super().reset()
41 |
42 |
43 | class VoxelDiff(torchmetrics.Metric):
44 | def __init__(
45 | self, **kwargs
46 | ):
47 | super().__init__(**kwargs)
48 | self.diff_list = []
49 |
50 | def update(self,
51 | gt_voxel: torch.tensor, pred_voxel: torch.tensor):
52 | if len(gt_voxel.shape) == 3:
53 | self.diff_list.append(torch.logical_xor(
54 | pred_voxel, gt_voxel).sum().to(torch.float32))
55 | else:
56 | for i, j in zip(gt_voxel, pred_voxel):
57 | self.diff_list.append(torch.logical_xor(
58 | i, j).sum().float())
59 |
60 | def compute(self):
61 | diff_list = torch.stack(self.diff_list, dim=0)
62 | world_size = torch.distributed.get_world_size() \
63 | if torch.distributed.is_initialized() else 1
64 | if world_size > 1:
65 | all_diff = diff_list.new_zeros(
66 | (len(diff_list)*world_size, ) + diff_list.shape[1:])
67 | torch.distributed.all_gather_into_tensor(
68 | all_diff, diff_list)
69 | diff_list = all_diff
70 | self.num_samples = len(diff_list)
71 | return diff_list.mean()
72 |
73 | def reset(self):
74 | self.diff_list.clear()
75 | super().reset()
76 |
--------------------------------------------------------------------------------
/src/dwm/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.distributed
3 | import torch.distributed.checkpoint.state_dict
4 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
5 |
6 |
7 | def distributed_save_optimizer_state(model, optimizer, folder, filename):
8 | if torch.distributed.is_initialized():
9 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
10 | # DDP: rank 0 save
11 | if torch.distributed.get_rank() == 0:
12 | state_dict = optimizer.state_dict()
13 | torch.save(
14 | state_dict,
15 | os.path.join(folder, "{}.pth".format(filename)))
16 |
17 | elif isinstance(model, FSDP):
18 | # Zero2, Zero3: all ranks save
19 |             # Hybrid Zero2, Zero3: ranks in the 1st group save
20 | state_dict = torch.distributed.checkpoint.state_dict\
21 | .get_optimizer_state_dict(model, optimizer)
22 | if model._device_mesh is None or (
23 | torch.distributed.get_rank() in
24 | model._device_mesh.mesh[0].tolist()
25 | ):
26 | torch.distributed.checkpoint.state_dict_saver.save(
27 | state_dict, checkpoint_id=os.path.join(folder, filename),
28 | process_group=(
29 | None if model._device_mesh is None
30 | else model._device_mesh.get_group(mesh_dim=-1)
31 | ))
32 |
33 | else:
34 | raise Exception(
35 | "Unsupported distribution framework to save the optimizer "
36 | "state.")
37 |
38 | else:
39 | state_dict = optimizer.state_dict()
40 | torch.save(state_dict, os.path.join(folder, "{}.pth".format(filename)))
41 |
42 |
43 | def distributed_load_optimizer_state(model, optimizer, folder, filename):
44 | if torch.distributed.is_initialized():
45 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
46 | # DDP: all ranks load the same
47 | state_dict = torch.load(
48 | os.path.join(folder, "{}.pth".format(filename)),
49 | map_location="cpu", weights_only=True)
50 | optimizer.load_state_dict(state_dict)
51 | elif isinstance(model, FSDP):
52 | # Zero2, Zero3: all ranks load
53 |             # Hybrid Zero2, Zero3: ranks in the 1st group load
54 | state_dict = torch.distributed.checkpoint.state_dict\
55 | .get_optimizer_state_dict(model, optimizer)
56 | torch.distributed.checkpoint.state_dict_loader.load(
57 | state_dict, checkpoint_id=os.path.join(folder, filename),
58 | planner=torch.distributed.checkpoint.DefaultLoadPlanner(
59 | allow_partial_load=True))
60 |
61 | else:
62 | state_dict = torch.load(
63 | os.path.join(folder, "{}.pth".format(filename)),
64 | map_location="cpu", weights_only=True)
65 | optimizer.load_state_dict(state_dict)
66 |
--------------------------------------------------------------------------------
/configs/README.md:
--------------------------------------------------------------------------------
1 | # Configurations
2 |
3 | The configuration files are in the JSON format. They include settings for the models, datasets, pipelines, or any arguments for the program.
4 |
5 | ## Introduction
6 |
7 | In our code, we mainly use JSON objects in three ways:
8 |
9 | 1. As a dictionary
10 | 2. As a function's parameter list
11 | 3. As a constructor and parameter for objects
12 |
13 | ### As a dictionary
14 |
15 | The most common way for the config, for example:
16 |
17 | ```JSON
18 | {
19 | "guidance_scale": 4,
20 | "inference_steps": 40,
21 | "preview_image_size": [
22 | 448,
23 | 252
24 | ]
25 | }
26 | ```
27 |
28 | The pipeline finds the corresponding value variable in the dictionary through the key, which determines the behavior at runtime.
29 |
30 | ### As a function's parameter list
31 |
32 | The content of a JSON object is passed into a function, for example:
33 |
34 | ```JSON
35 | {
36 | "num_workers": 3,
37 | "prefetch_factor": 3,
38 | "persistent_workers": true
39 | }
40 | ```
41 |
42 | The PyTorch data loader will accept all the arguments by
43 |
44 | ```Python
45 | data_loader = torch.utils.data.DataLoader(
46 | dataset, **deserialized_json_object)
47 | ```
48 |
49 | In this case, you can fill in the required parameters according to the reference documentation of the function (such as the [data loader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) here).
50 |
51 | ### As a constructor and parameter for objects
52 |
53 | The JSON object declares the name of the object to be created, as well as the parameters, for example:
54 |
55 | ```JSON
56 | {
57 | "_class_name": "torch.optim.AdamW",
58 | "lr": 6e-5,
59 | "betas": [
60 | 0.9,
61 | 0.975
62 | ]
63 | }
64 | ```
65 |
66 | The "_class_name" is in the format of `{name_space}.{class_or_function_name}`, and other key-value pairs are used as parameters for the class constructor (e.g. [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW) here) or the function.
67 |
68 | In the code, this type of object is parsed by the `dwm.common.create_instance_from_config()` function.
69 |
70 | With this design, the configuration, framework, and components are **loosely coupled**. For example, users can easily switch to a third-party optimizer such as `bitsandbytes.optim.Adam8bit` without editing the code, and developers can provide any component class (e.g. a dataset or data transform) without registering it with a specific framework.
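The actual factory lives in `dwm.common` (not included in this excerpt). As a rough sketch of the idea only, and not the project's exact implementation, such a factory could look like this:

```Python
import importlib

def create_instance_from_config_sketch(config: dict):
    # Illustrative only; the real dwm.common.create_instance_from_config may
    # differ, e.g. by resolving nested "_class_name" objects recursively.
    kwargs = dict(config)
    module_name, _, class_name = kwargs.pop("_class_name").rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**kwargs)

# For example, this builds the local file system declared in configs/fs/local.json.
fs = create_instance_from_config_sketch(
    {"_class_name": "fsspec.implementations.local.LocalFileSystem"})
```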
71 |
72 | ## Development
73 |
74 | ### Naming convention
75 |
76 | The configs in this folder mainly describe the pipelines and are consumed by `src/dwm/train.py`, so they are named in the format `{pipeline_name}_{model_config}_{condition_config}_{data_config}.json`.
77 |
78 | * Pipeline name: the Python script name in `src/dwm/pipelines`.
79 | * Model config: the most discriminative model arguments, such as `spatial`, `crossview`, `temporal` for the SD models.
80 | * Condition config: the additional input for the model, such as `ts` for the "text description per scene", `ti` for the "text description per image", `b` for the box condition, `m` for the map condition.
81 | * Data config: `mini` for debugging purposes, or a combination of `nuscenes`, `argoverse`, `waymo`, `opendv` (or their initial letters) for the data components.
82 |
--------------------------------------------------------------------------------
/docs/CtsdPipelineFaqs.md:
--------------------------------------------------------------------------------
1 | # CTSD pipeline FAQs
2 |
3 | CTSD is short for cross-view temporal Stable Diffusion. This pipeline extends the pre-trained Stable Diffusion model to autonomous driving multi-view and video sequence tasks.
4 |
5 | Here are some frequently asked questions.
6 |
7 | * [Single GPU training](#single-gpu-training)
8 | * [Remove conditions](#remove-conditions)
9 |
10 | ## Single GPU training
11 |
12 | The default training configurations use the [Hybrid FSDP](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) to reduce the GPU memory usage for distributed training, which does not reduce memory usage on single GPU systems.
13 |
14 | To reduce memory on a single-GPU system, edit the config to enable quantization and CUDA AMP.
15 |
16 | 1. Set `"quantization_config"` in the `"text_encoder_load_args"` for those models implementing the [HfQuantizer](https://huggingface.co/docs/transformers/v4.48.2/en/main_classes/quantization#transformers.quantizers.HfQuantizer) (make sure the [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/main/en/index) package is installed for the following config).
17 |
18 | ```JSON
19 | {
20 | "pipeline": {
21 | "common_config": {
22 | "text_encoder_load_args": {
23 | "variant": "fp16",
24 | "torch_dtype": {
25 | "_class_name": "get_class",
26 | "class_name": "torch.float16"
27 | },
28 | "quantization_config": {
29 | "_class_name": "diffusers.quantizers.quantization_config.BitsAndBytesConfig",
30 | "load_in_4bit": true,
31 | "bnb_4bit_quant_type": "nf4",
32 | "bnb_4bit_compute_dtype": {
33 | "_class_name": "get_class",
34 | "class_name": "torch.float16"
35 | }
36 | }
37 | }
38 | }
39 | }
40 | }
41 | ```
42 |
43 | 2. Switch to quantized optimizer.
44 |
45 | ```JSON
46 | {
47 | "optimizer": {
48 | "_class_name": "bitsandbytes.optim.Adam8bit",
49 | "lr": 5e-5
50 | }
51 | }
52 | ```
53 |
54 | 3. Enable the CUDA AMP instead of FSDP mixed precision.
55 |
56 | ```JSON
57 | {
58 | "pipeline": {
59 | "common_config": {
60 | "autocast": {
61 | "device_type": "cuda"
62 | }
63 | }
64 | }
65 | }
66 | ```
67 |
68 | 4. Remove all `"device_mesh"` related config items.
69 |
70 | ## Remove conditions
71 |
72 | Remove layout conditions (3D boxes, HD map):
73 |
74 | * Model: if only one condition is removed, adjust the input channel number `"pipeline.model.condition_image_adapter_config.in_channels"` to 3, as shown in the fragment below; if both the 3D box and HD map conditions are removed, delete the `"pipeline.model.condition_image_adapter_config"` section completely.
75 |
76 | * Dataset: remove `"_3dbox_image_settings"` and `"hdmap_image_settings"`.
77 |
78 | * Dataset adapter: remove `torchvision.transforms` items of `"3dbox_images"`, `"hdmap_images"`.
79 |
80 | These are dataset-related settings; note that the training set and validation set must be modified consistently.
81 |
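For reference, the dotted path in the model bullet above expands to the following JSON fragment (only the affected key is shown; the surrounding keys follow the concrete pipeline config):

```JSON
{
    "pipeline": {
        "model": {
            "condition_image_adapter_config": {
                "in_channels": 3
            }
        }
    }
}
```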
82 | Remove components of text prompt:
83 |
84 | * Refer to the [text condition processing function](../src/dwm/datasets/common.py#L316): in the dataset settings `"image_description_settings"`, add `"selected_keys": ["time", "weather", "environment"]` to exclude the other text fields, as shown below.
85 |
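For example, the fragment below shows where `"selected_keys"` goes inside the dataset settings (a minimal sketch; the other keys of `"image_description_settings"` are omitted):

```JSON
{
    "image_description_settings": {
        "selected_keys": ["time", "weather", "environment"]
    }
}
```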
--------------------------------------------------------------------------------
/src/dwm/utils/make_blank_code.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.common
3 | import dwm.pipelines.lidar_vqvae
4 | import json
5 | import pickle
6 | import torch
7 | import torch.utils.data
8 |
9 |
10 | def create_parser():
11 | parser = argparse.ArgumentParser(
12 | description="Make blank code file of the LiDAR codebook checkpoint.")
13 | parser.add_argument(
14 | "-c", "--config-path", type=str, required=True,
15 | help="The path of training config file.")
16 | parser.add_argument(
17 | "-i", "--input-path", type=str, required=True,
18 | help="The path of input checkpoint file.")
19 | parser.add_argument(
20 | "-o", "--output-path", type=str, required=True,
21 | help="The path of output blank code file.")
22 | parser.add_argument(
23 | "-it", "--iteration", default=100, type=int,
24 | help="The iteration count from the validation set for blank code.")
25 | parser.add_argument(
26 | "-s", "--sample-count", default=16, type=int,
27 | help="The count of blank code to sample.")
28 | return parser
29 |
30 |
31 | def count_code(indices):
32 | unique_elements, counts = torch.unique(indices, return_counts=True)
33 | sorted_indices = torch.argsort(counts, descending=True)
34 | sorted_elements = unique_elements[sorted_indices]
35 | sorted_counts = counts[sorted_indices]
36 | return sorted_elements, sorted_counts
37 |
38 |
39 | if __name__ == "__main__":
40 | parser = create_parser()
41 | args = parser.parse_args()
42 |
43 | with open(args.config_path, "r", encoding="utf-8") as f:
44 | config = json.load(f)
45 |
46 | device = torch.device(config.get("device", "cpu"))
47 | dataset = dwm.common.create_instance_from_config(
48 | config["validation_dataset"])
49 |
50 | vq_point_cloud = dwm.common.create_instance_from_config(
51 | config["pipeline"]["vq_point_cloud"])
52 | vq_point_cloud.to(device)
53 | vq_point_cloud.load_state_dict(
54 | dwm.pipelines.lidar_vqvae.LidarCodebook.load_state(args.input_path),
55 | strict=False)
56 |
57 | dataloader = torch.utils.data.DataLoader(
58 | dataset, shuffle=True,
59 | **dwm.common.instantiate_config(config["validation_dataloader"]))
60 |
61 | iteration = 0
62 | code_dict = {}
63 | vq_point_cloud.eval()
64 | for batch in dataloader:
65 | with torch.no_grad():
66 | points = dwm.pipelines.lidar_vqvae.LidarCodebook.get_points(
67 | batch, config["pipeline"]["common_config"], device)
68 |
69 | voxels = vq_point_cloud.voxelizer(points)
70 | lidar_feats = vq_point_cloud.lidar_encoder(voxels)
71 | _, _, code_indices = vq_point_cloud.vector_quantizer(
72 | lidar_feats, vq_point_cloud.code_age,
73 | vq_point_cloud.code_usage)
74 |
75 | codes, counts = count_code(code_indices)
76 | for code, count in zip(codes.tolist(), counts.tolist()):
77 | if code in code_dict:
78 | code_dict[code] += count
79 | else:
80 | code_dict[code] = count
81 |
82 | iteration += 1
83 | if iteration >= args.iteration:
84 | break
85 |
86 | blank_code = [
87 | i[0] for i in sorted(
88 | code_dict.items(),
89 | key=lambda i: i[1], reverse=True)[:args.sample_count]
90 | ]
91 | with open(args.output_path, "wb") as f:
92 | pickle.dump(blank_code, f)
93 |
--------------------------------------------------------------------------------
/src/dwm/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.common
3 | import json
4 | import os
5 | import torch
6 |
7 |
8 | def create_parser():
9 | parser = argparse.ArgumentParser(
10 |         description="The script to evaluate the pipeline on the driving "
11 |         "dataset.")
12 | parser.add_argument(
13 | "-c", "--config-path", type=str, required=True,
14 | help="The config to load the train model and dataset.")
15 | parser.add_argument(
16 | "-o", "--output-path", type=str, required=True,
17 |         help="The path to save the evaluation outputs.")
18 | parser.add_argument(
19 | "--resume-from", default=None, type=int,
20 | help="The step to resume from")
21 | return parser
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = create_parser()
26 | args = parser.parse_args()
27 |
28 | with open(args.config_path, "r", encoding="utf-8") as f:
29 | config = json.load(f)
30 |
31 | # set distributed training (if enabled), log, random number generator, and
32 | # load the checkpoint (if required).
33 | ddp = "LOCAL_RANK" in os.environ
34 | if ddp:
35 | local_rank = int(os.environ["LOCAL_RANK"])
36 | device = torch.device(config["device"], local_rank)
37 | if config["device"] == "cuda":
38 | torch.cuda.set_device(local_rank)
39 |
40 | torch.distributed.init_process_group(backend=config["ddp_backend"])
41 | else:
42 | device = torch.device(config["device"])
43 |
44 | # setup the global state
45 | if "global_state" in config:
46 | for key, value in config["global_state"].items():
47 | dwm.common.global_state[key] = \
48 | dwm.common.create_instance_from_config(value)
49 |
50 | should_log = (ddp and local_rank == 0) or not ddp
51 | should_save = not torch.distributed.is_initialized() or \
52 | torch.distributed.get_rank() == 0
53 |
54 | # load the pipeline including the models
55 | pipeline = dwm.common.create_instance_from_config(
56 | config["pipeline"], output_path=args.output_path, config=config,
57 | device=device, resume_from=args.resume_from)
58 | if should_log:
59 | print("The pipeline is loaded.")
60 |
61 | validation_dataset = dwm.common.create_instance_from_config(
62 | config["validation_dataset"])
63 | if ddp:
64 | # make equal sample count for each process to simplify the result
65 | # gathering
66 | total_batch_size = int(os.environ["WORLD_SIZE"]) * \
67 | config["validation_dataloader"]["batch_size"]
68 | dataset_length = len(validation_dataset) // \
69 | total_batch_size * total_batch_size
70 | validation_dataset = torch.utils.data.Subset(
71 | validation_dataset, range(0, dataset_length))
72 | validation_datasampler = \
73 | torch.utils.data.distributed.DistributedSampler(
74 | validation_dataset)
75 | validation_dataloader = torch.utils.data.DataLoader(
76 | validation_dataset,
77 | **dwm.common.instantiate_config(config["validation_dataloader"]),
78 | sampler=validation_datasampler)
79 | else:
80 | validation_datasampler = None
81 | validation_dataloader = torch.utils.data.DataLoader(
82 | validation_dataset,
83 | **dwm.common.instantiate_config(config["validation_dataloader"]))
84 |
85 | if should_log:
86 | print("The validation dataset is loaded with {} items.".format(
87 | len(validation_dataset)))
88 |
89 | pipeline.evaluate_pipeline(
90 | 0, len(validation_dataset), validation_dataloader,
91 | validation_datasampler)
92 |
93 | if torch.distributed.is_initialized():
94 | torch.distributed.destroy_process_group()
95 |
--------------------------------------------------------------------------------
/docs/InteractiveGeneration.md:
--------------------------------------------------------------------------------
1 | # Interactive Generation
2 |
3 | 4x accelerated playing speed:
4 |
5 | https://github.com/user-attachments/assets/933b84d3-496a-41bd-b6ab-3022a0137062
6 |
7 | We have implemented interactive generation with Carla (0.9.15). The main components are:
8 |
9 | * Server-side [Carla](https://carla.org/) maintains the simulation state.
10 | * Server-side [simulation script](../src/dwm/utils/carla_simulation.py) configures the environment, ego car, sensors, and traffic manager.
11 | * Server-side [streaming generation](../src/dwm/streaming.py) reads condition data from Carla and writes generated frames to the video streaming server.
12 | * Server-side [video streaming server](https://github.com/bluenviron/mediamtx) to publish the video stream for the client video player.
13 | * Client-side [Carla control](../src/dwm/utils/carla_control.py) to control the ego car in the simulation world with the keyboard.
14 | * Client-side [video player](https://ffmpeg.org/) to receive the generated result.
15 |
16 | The dataflow is:
17 |
18 | 1. Carla control
19 | 2. Carla (configured by the simulation script)
20 | 3. Streaming generation
21 | 4. Video streaming server
22 | 5. Video player
23 |
24 | ## Requirements
25 |
26 | The server requires:
27 |
28 | 1. A GPU (NVIDIA A100 is recommended)
29 | 2. Network accessibility
30 | 3. Python 3.9 or 3.10
31 | 4. Carla == 0.9.15
32 | 5. mediamtx
33 |
34 | The client requires:
35 | 1. Windows or Ubuntu (the supported platforms for the Carla Python API)
36 | 2. Python 3.9 or 3.10
37 | 3. ffmpeg
38 |
39 | ## Models
40 |
41 | The interactive generative model is trained from scratch on autonomous driving data, using a reduced specification (model size, view count, resolution) of CTSD 3.5 in order to reduce the model inference overhead.
42 |
43 | | Base Model | Temporal Training Style | Prediction Style | Configs | Checkpoint Download |
44 | | :-: | :-: | :-: | :-: | :-: |
45 | | [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Diffusion forcing transformer](https://arxiv.org/abs/2502.06764) | [FIFO diffusion](https://arxiv.org/abs/2405.11473) | [Config](../configs/experimental/multi_datasets/ctsd_35_xs_df6v3_tirda_bm_nwao.json) | [Checkpoint](https://huggingface.co/wzhgba/opendwm-models/resolve/main/ctsd_35_xs_df6v3_tirda_bm_nwao_60k.pth?download=true) |
46 |
47 | ## Inference
48 |
49 | ### Server-side Setup
50 |
51 | 1. Download the base model (for VAE and text encoders) and model checkpoint, then edit the [config](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L168).
52 |
53 | 2. Launch the video streaming server following the [official guide](https://github.com/bluenviron/mediamtx?tab=readme-ov-file#installation).
54 |
55 | 3. Launch the Carla: `{CARLA_ROOT}/CarlaUE4.sh -RenderOffScreen -quality-level=Low`
56 |
57 | 4. Configure the Carla by editing the [config template](../configs/experimental/simulation/carla_simulation_town10_nusc_3views.json) and run: `PYTHONPATH=src python src/dwm/utils/carla_simulation.py -c configs/experimental/simulation/carla_simulation_town10_nusc_3views.json --client-timeout 150`
58 |
59 | 5. Edit the generation config template (e.g. [Carla endpoint](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L7), [video streaming options](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L268)) and run: `PYTHONPATH=src python src/dwm/streaming.py -c configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json -l output/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming -s rtsp://{VIDEO_STREAMING_ENDPOINT}/live --fps 2`
60 |
61 | ### Client-side Setup
62 |
63 | 1. Launch the video player after the server-side streaming begins: `ffplay -fflags nobuffer -rtsp_transport tcp rtsp://{VIDEO_STREAMING_ENDPOINT}/live`
64 |
65 | 2. Launch the Carla control after the server-side streaming begins: `python src\dwm\utils\carla_control.py --host {CARLA_SERVER_ADDRESS} -p {CARLA_SERVER_PORT}`
66 |
67 | ## Known issues
68 |
69 | 1. Generation speed.
70 | 2. Latency due to the denoising queue.
71 |
--------------------------------------------------------------------------------
/src/dwm/fs/README.md:
--------------------------------------------------------------------------------
1 | # File System
2 |
3 | This project uses [fsspec](https://github.com/fsspec/filesystem_spec) to access data from various sources, including the local file system, S3, and ZIP content.
4 |
5 |
6 | ## Common usage
7 |
8 | You can open, read, and seek files with the file system objects. Check out the [usage](https://filesystem-spec.readthedocs.io/en/latest/usage.html) document for a quick start.
9 |
10 | ### Create the file system
11 |
12 | ``` Python
13 | import fsspec.implementations.local
14 | fs = fsspec.implementations.local.LocalFileSystem()
15 | ```
16 |
17 | ### Open and read the file
18 |
19 | ``` Python
20 | import json
21 | config_path = "configs/fs/local.json"
22 | with fs.open(config_path, "r", encoding="utf-8") as f:
23 | config = json.load(f)
24 |
25 | from PIL import Image
26 | image_path = "samples/CAM_FRONT/n008-2018-08-01-15-16-36-0400__CAM_FRONT__1533151603512404.jpg"
27 | with fs.open(image_path, "rb") as f:
28 | image = Image.open(f)
29 | ```
30 |
31 |
32 | ## API reference
33 |
34 | ### dwm.fs.ctar.CombinedTarFileSystem
35 |
36 | This file system opens several TAR blobs from a given file system and provides access to the files inside them. Please note that only the uncompressed TAR format is supported, because the TAR.GZ format does not allow random access to individual files within the blob. This file system is forkable and compatible with the multi-worker data loader of PyTorch.
37 |
38 | ### dwm.fs.czip.CombinedZipFileSystem
39 |
40 | This file system opens several ZIP blobs from a given file system and provides access to the files inside them. It is forkable and compatible with the multi-worker data loader of PyTorch.
41 |
42 | ### dwm.fs.s3fs.ForkableS3FileSystem
43 |
44 | This file system connects to an S3 service and provides access to the files on that service. It is forkable and compatible with the multi-worker data loader of PyTorch.
45 |
46 |
47 | ## Configuration samples
48 | The file system object can be initialized by `dwm.common.create_instance_from_config()` with JSON configurations like the following.
49 |
50 | ### Local file system
51 |
52 | ``` JSON
53 | {
54 | "_class_name": "fsspec.implementations.local.LocalFileSystem"
55 | }
56 | ```
57 |
58 | **Relative directory on local file system**
59 | ``` JSON
60 | {
61 | "_class_name": "fsspec.implementations.dirfs.DirFileSystem",
62 | "path": "/mnt/storage/user/wuzehuan/Downloads/data/nuscenes",
63 | "fs": {
64 | "_class_name": "fsspec.implementations.local.LocalFileSystem"
65 | }
66 | }
67 | ```
68 |
69 | ### S3 file system
70 |
71 | The parameters follow the [Botocore configuration](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
72 |
73 | ``` JSON
74 | {
75 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
76 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn",
77 | "aws_access_key_id": "YOUR_ACCESS_KEY",
78 | "aws_secret_access_key": "YOUR_SECRET_KEY"
79 | }
80 | ```
81 |
82 | **Relative directory on S3 file system**
83 |
84 | ``` JSON
85 | {
86 | "_class_name": "fsspec.implementations.dirfs.DirFileSystem",
87 | "path": "users/wuzehuan/data/nuscenes",
88 | "fs": {
89 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
90 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn",
91 | "aws_access_key_id": "YOUR_ACCESS_KEY",
92 | "aws_secret_access_key": "YOUR_SECRET_KEY"
93 | }
94 | }
95 | ```
96 |
97 | **Retry options on S3 file system**
98 |
99 | ``` JSON
100 | {
101 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem",
102 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn",
103 | "aws_access_key_id": "YOUR_ACCESS_KEY",
104 | "aws_secret_access_key": "YOUR_SECRET_KEY",
105 | "config": {
106 | "_class_name": "botocore.config.Config",
107 | "retries": {
108 | "max_attempts": 8
109 | }
110 | }
111 | }
112 | ```
113 |
--------------------------------------------------------------------------------
/src/dwm/models/voxelizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Voxelizer(torch.nn.Module):
5 | """Voxelizer for converting Lidar point cloud to image"""
6 |
7 | def __init__(self, x_min, x_max, y_min, y_max, step, z_min, z_max, z_step):
8 | super().__init__()
9 |
10 | self.x_min = x_min
11 | self.x_max = x_max
12 | self.y_min = y_min
13 | self.y_max = y_max
14 | self.step = step
15 | self.z_min = z_min
16 | self.z_max = z_max
17 | self.z_step = z_step
18 |
19 | self.width = round((self.x_max - self.x_min) / self.step)
20 | self.height = round((self.y_max - self.y_min) / self.step)
21 | self.z_depth = round((self.z_max - self.z_min) / self.z_step)
22 | self.depth = self.z_depth
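        # e.g. x/y in [-50, 50] with step 0.5 and z in [-3, 5] with z_step
        # 0.5 give a 16 x 200 x 200 (D x H x W) occupancy grid per sweep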
23 |
24 | def voxelize_single(self, lidar, bev):
25 | """Voxelize a single lidar sweep into image frame
26 | Image frame:
27 | 1. Increasing depth indices corresponds to increasing real world z
28 | values.
29 | 2. Increasing height indices corresponds to decreasing real world y
30 | values.
31 | 3. Increasing width indices corresponds to increasing real world x
32 | values.
33 | Args:
34 | lidar (torch.Tensor N x 4 or N x 5) x, y, z, intensity, height_to_ground (optional)
35 | bev (torch.Tensor D x H x W) D = depth, the bird's eye view
36 | raster to populate
37 | """
38 |
39 | # 1 & 2. Convert points to tensor index location. Clamp z indices to
40 | # valid range.
41 | indices_h = torch.floor((lidar[:, 1] - self.y_min) / self.step).long()
42 | indices_w = torch.floor((lidar[:, 0] - self.x_min) / self.step).long()
43 | indices_d = torch.floor(
44 | (lidar[:, 2] - self.z_min) / self.z_step).long()
45 |
46 | # 3. Remove points out of bound
47 | valid_mask = ~torch.any(
48 | torch.stack(
49 | [
50 | indices_h < 0,
51 | indices_h >= self.height,
52 | indices_w < 0,
53 | indices_w >= self.width,
54 | indices_d < 0,
55 | indices_d >= self.z_depth,
56 | ]
57 | ),
58 | dim=0,
59 | )
60 | indices_h = indices_h[valid_mask]
61 | indices_w = indices_w[valid_mask]
62 | indices_d = indices_d[valid_mask]
63 | # 4. Assign indices to 1
64 | bev[indices_d, indices_h, indices_w] = 1.0
65 |
66 | def forward(self, lidars):
67 | """Voxelize multiple sweeps in the current vehicle frame into voxels
68 | in image frame
69 | Args:
70 |             lidars (list(list(torch.Tensor))): B * T * tensor[N x 4],
71 | where B = batch_size, T = 5, N is variable,
72 | 4 = [x, y, z, intensity]
73 | Returns:
74 | tensor: [B x D x H x W], B = batch_size, D = T * depth, H = height,
75 | W = width
76 | """
77 | batch_size = len(lidars)
78 | assert batch_size > 0 and len(lidars[0]) > 0
79 | num_sweep = len(lidars[0])
80 |
81 | bev = torch.zeros(
82 | (batch_size, num_sweep, self.depth, self.height, self.width),
83 | dtype=torch.float,
84 | device=lidars[0][0][0].device,
85 | )
86 | for b in range(batch_size):
87 | assert len(lidars[b]) == num_sweep
88 | for i in range(num_sweep):
89 | self.voxelize_single(lidars[b][i], bev[b][i])
90 |
91 | return bev
92 |
93 |     def get_voxel_coordinates(self, downsample_scale=1):
94 |         x_coord = torch.arange(self.x_min, self.x_max, ((self.x_max - self.x_min) / self.width) / downsample_scale) + self.step / 2
95 |         y_coord = torch.arange(self.y_min, self.y_max, ((self.y_max - self.y_min) / self.height) / downsample_scale) + self.step / 2
96 |         z_coord = torch.arange(self.z_min, self.z_max, ((self.z_max - self.z_min) / self.z_depth) / downsample_scale) + self.z_step / 2
97 |         # meshgrid is called in (z, y, x) order, so the grid is laid out as D x H x W
98 |         x_grid, y_grid, z_grid = torch.meshgrid(z_coord, y_coord, x_coord, indexing="ij")
99 |         return torch.stack([x_grid, y_grid, z_grid], dim=-1)
100 |
--------------------------------------------------------------------------------
/examples/ctsd_generation_example.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.distributed
3 | import dwm.common
4 | import dwm.utils.preview
5 | import json
6 | import os
7 | import torch
8 | import torchvision
9 |
10 |
11 | def create_parser():
12 | parser = argparse.ArgumentParser(
13 |         description="The example script to run the CTSD pipeline and "
14 |         "generate image or video previews from a config file.")
15 | parser.add_argument(
16 | "-c", "--config-path", type=str, required=True,
17 | help="The config to load the train model and dataset.")
18 | parser.add_argument(
19 | "-o", "--output-path", type=str, required=True,
20 |         help="The path to save the generated image and video results.")
21 | return parser
22 |
23 |
24 | if __name__ == "__main__":
25 | parser = create_parser()
26 | args = parser.parse_args()
27 |
28 | with open(args.config_path, "r", encoding="utf-8") as f:
29 | config = json.load(f)
30 |
31 | ddp = "LOCAL_RANK" in os.environ
32 | if ddp:
33 | local_rank = int(os.environ["LOCAL_RANK"])
34 | device = torch.device(config["device"], local_rank)
35 | if config["device"] == "cuda":
36 | torch.cuda.set_device(local_rank)
37 |
38 | torch.distributed.init_process_group(backend=config["ddp_backend"])
39 | else:
40 | device = torch.device(config["device"])
41 |
42 | pipeline = dwm.common.create_instance_from_config(
43 | config["pipeline"], output_path=None, config=config, device=device)
44 | print("The pipeline is loaded.")
45 |
46 | os.makedirs(args.output_path, exist_ok=True)
47 | for i_id, i in enumerate(config["inputs"]):
48 | i["batch"] = {
49 | k: torch.tensor(v) if k != "clip_text" else v
50 | for k, v in i["batch"].items()
51 | }
52 | with torch.no_grad():
53 | if (
54 | "sequence_length_per_iteration" in
55 | config["pipeline"]["inference_config"]
56 | ):
57 | latent_shape = tuple(
58 | i["latent_shape"][:1] + [
59 | config["pipeline"]["inference_config"]
60 | ["sequence_length_per_iteration"],
61 | ] + i["latent_shape"][2:]
62 | )
63 | pipeline_output = pipeline.autoregressive_inference_pipeline(
64 | **{
65 | k: latent_shape if k == "latent_shape" else v
66 | for k, v in i.items()
67 | })
68 | else:
69 | pipeline_output = pipeline.inference_pipeline(**i)
70 |
71 | output_images = pipeline_output["images"]
72 | collected_images = [
73 | output_images.cpu().unflatten(0, i["latent_shape"][:3])
74 | ]
75 |
76 | stacked_images = torch.stack(collected_images)
77 | resized_images = torch.nn.functional.interpolate(
78 | stacked_images.flatten(0, 3),
79 | tuple(pipeline.inference_config["preview_image_size"][::-1])
80 | )
81 | resized_images = resized_images.view(
82 | *stacked_images.shape[:4], -1, *resized_images.shape[-2:])
83 |
84 | if not ddp or torch.distributed.get_rank() == 0:
85 | if i["latent_shape"][1] == 1:
86 | # [C, B * T * S * H, V * W]
87 | preview_tensor = resized_images.permute(4, 1, 2, 0, 5, 3, 6)\
88 | .flatten(-2).flatten(1, 4)
89 | image_output_path = os.path.join(
90 | args.output_path, "{}.png".format(i_id))
91 | torchvision.transforms.functional.to_pil_image(preview_tensor)\
92 | .save(image_output_path)
93 | else:
94 | # [T, C, B * S * H, V * W]
95 | preview_tensor = resized_images.permute(2, 4, 1, 0, 5, 3, 6)\
96 | .flatten(-2).flatten(2, 4)
97 | video_output_path = os.path.join(
98 | args.output_path, "{}.mp4".format(i_id))
99 | dwm.utils.preview.save_tensor_to_video(
100 | video_output_path, "libx264", i["batch"]["fps"][0].item(),
101 | preview_tensor)
102 |
103 | print("{} done".format(i_id))
104 |
--------------------------------------------------------------------------------
/src/dwm/fs/ctar.py:
--------------------------------------------------------------------------------
1 | import dwm.common
2 | import fsspec.archive
3 | import fsspec.implementations.local
3 | import json
4 | import os
5 | import re
6 | import tarfile
7 |
8 |
9 | class CombinedTarFileSystem(fsspec.archive.AbstractArchiveFileSystem):
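    """This file system provides read access to files stored inside one or
    more uncompressed TAR blobs (TAR.GZ is not supported because it does not
    allow random access to its members), and it is compatible with the
    process fork under the multi-worker situation of the PyTorch data loader.

    Args:
        paths (list): The path list of TAR files to open.
        fs (fsspec.AbstractFileSystem or None): The file system to open the
            TAR blobs, or local file system as default.
        enable_cached_info (bool): Load the byte offset of TAR entries from
            cached info file with the extension `.info.json` to accelerate
            opening large TAR files.
    """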
10 |
11 | root_marker = ""
12 | protocol = "ctar"
13 | cachable = False
14 |
15 | def __init__(
16 | self, paths: list, fs=None,
17 | enable_cached_info: bool = False, **kwargs
18 | ):
19 | super().__init__(**kwargs)
20 |
21 | self.fs = fsspec.implementations.local.LocalFileSystem() \
22 | if fs is None else fs
23 | self.paths = paths
24 |
25 | _belongs_to = {}
26 | self._info = {}
27 | pattern = re.compile(r"\.tar$")
28 | for path in paths:
29 | cached_info_path = re.sub(pattern, ".info.json", path)
30 | if enable_cached_info:
31 | with fs.open(cached_info_path, "r", encoding="utf-8") as f:
32 | infodict = json.load(f)
33 |
34 | for filename, i in infodict.items():
35 | _belongs_to[filename] = path
36 |
37 | else:
38 | infodict = {}
39 | with fs.open(path) as f:
40 | with tarfile.TarFile(fileobj=f) as tf:
41 | for i in tf.getmembers():
42 | _belongs_to[i.name] = path
43 | infodict[i.name] = [
44 | i.offset_data, i.size, not i.isfile()
45 | ]
46 |
47 | self._info[path] = dwm.common.SerializedReadonlyDict(infodict)
48 |
49 | self._belongs_to = dwm.common.SerializedReadonlyDict(_belongs_to)
50 |
51 | self.dir_cache = None
52 | self.fp_cache = {}
53 |
54 | def close(self):
55 |         # the file pointers are not forkable
56 | current_pid = os.getpid()
57 | if self._pid != current_pid:
58 | self.fp_cache.clear()
59 | self._pid = current_pid
60 |
61 | for fp in self.fp_cache.values():
62 | fp.close()
63 |
64 | self.fp_cache.clear()
65 |
66 | def _get_dirs(self):
67 | if self.dir_cache is None:
68 | self.dir_cache = {
69 | dirname.rstrip("/"): {
70 | "name": dirname.rstrip("/"),
71 | "size": 0,
72 | "type": "directory",
73 | }
74 | for dirname in self._all_dirnames(self._belongs_to.keys())
75 | }
76 |
77 | for file_name in self._belongs_to.keys():
78 | blob_path = self._belongs_to[file_name]
79 | _, file_size, is_dir = self._info[blob_path][file_name]
80 | f = {
81 | "name": file_name.rstrip("/"),
82 | "size": file_size,
83 | "type": "directory" if is_dir else "file",
84 | }
85 | self.dir_cache[f["name"]] = f
86 |
87 | def exists(self, path, **kwargs):
88 | return path in self._belongs_to
89 |
90 | def _open(
91 | self, path, mode="rb", block_size=None, autocommit=True,
92 | cache_options=None, **kwargs
93 | ):
94 | if "w" in mode:
95 | raise OSError("Combined TAR file system is read only.")
96 |
97 | if "b" not in mode:
98 | raise OSError(
99 | "Combined TAR file system only support the mode of binary.")
100 |
101 | if not self.exists(path):
102 | raise FileNotFoundError(path)
103 |
104 | blob_path = self._belongs_to[path]
105 | offset, size, is_dir = self._info[blob_path][path]
106 | if is_dir:
107 | raise FileNotFoundError(path)
108 |
109 |         # the file pointers are not forkable
110 | current_pid = os.getpid()
111 | if self._pid != current_pid:
112 | self.fp_cache.clear()
113 | self._pid = current_pid
114 |
115 | if blob_path in self.fp_cache:
116 | f = self.fp_cache[blob_path]
117 | else:
118 | f = self.fp_cache[blob_path] = \
119 | self.fs.open(blob_path, mode, block_size, cache_options)
120 |
121 | return dwm.common.PartialReadableRawIO(f, offset, offset + size)
122 |
--------------------------------------------------------------------------------
/src/dwm/metrics/fvd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed
3 | import torchmetrics
4 |
5 | # PYTHONPATH includes ${workspaceFolder}/externals/TATS/tats/fvd
6 | import pytorch_i3d
7 |
8 |
9 | def _compute_fid(
10 | mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor,
11 | sigma2: torch.Tensor
12 | ):
13 | # The same implementation as torchmetrics.image.fid._compute_fid()
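    #
    # It evaluates the squared Frechet distance between two Gaussians,
    # d^2 = |mu1 - mu2|^2 + tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)),
    # where the trace of the matrix square root is obtained from the
    # eigenvalues of sigma1 @ sigma2 below.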
14 |
15 | a = (mu1 - mu2).square().sum(dim=-1)
16 | b = sigma1.trace() + sigma2.trace()
17 | c = torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1)
18 |
19 | return a + b - 2 * c
20 |
21 |
22 | class FrechetVideoDistance(torchmetrics.Metric):
23 |
24 | target_resolution = (224, 224)
25 | i3d_min = 10
26 |
27 | def __init__(
28 | self, inception_3d_checkpoint_path: str, sequence_count: int = -1,
29 | num_classes: int = 400, **kwargs
30 | ):
31 | super().__init__(**kwargs)
32 |
33 | self.inception = pytorch_i3d.InceptionI3d(num_classes)
34 | self.inception.eval()
35 | state_dict = torch.load(
36 | inception_3d_checkpoint_path, map_location="cpu",
37 | weights_only=True)
38 | self.inception.load_state_dict(state_dict)
39 | self.sequence_count = sequence_count
40 |
41 | mx_num_feats = (num_classes, num_classes)
42 | self.add_state(
43 | "real_features_sum", torch.zeros(num_classes).double(),
44 | dist_reduce_fx="sum")
45 | self.add_state(
46 | "real_features_cov_sum", torch.zeros(mx_num_feats).double(),
47 | dist_reduce_fx="sum")
48 | self.add_state(
49 | "real_features_num_samples", torch.tensor(0).long(),
50 | dist_reduce_fx="sum")
51 | self.add_state(
52 | "fake_features_sum", torch.zeros(num_classes).double(),
53 | dist_reduce_fx="sum")
54 | self.add_state(
55 | "fake_features_cov_sum", torch.zeros(mx_num_feats).double(),
56 | dist_reduce_fx="sum")
57 | self.add_state(
58 | "fake_features_num_samples", torch.tensor(0).long(),
59 | dist_reduce_fx="sum")
60 |
61 | def update(self, frames, real=True):
62 | """Update the state with extracted features.
63 |
64 | Args:
65 | frames: The video frame tensor to evaluate in the shape of
66 | `(batch_size, sequence_length, channels, height, width)`.
67 | real: Whether given frames are real or fake.
68 | """
69 |
70 | if self.sequence_count >= 0:
71 | frames = frames[:, :self.sequence_count]
72 |
73 | assert frames.shape[1] >= FrechetVideoDistance.i3d_min
74 |
75 | # normalize from [0, 1] to [-1, 1]
76 | frames = frames * 2 - 1
77 |
78 | frames = torch.nn.functional.interpolate(
79 | frames.flatten(0, 1), size=FrechetVideoDistance.target_resolution,
80 | mode="bilinear"
81 | ).unflatten(0, frames.shape[:2])
82 | features = self.inception(frames.transpose(1, 2))
83 | self.orig_dtype = features.dtype
84 | features = features.double()
85 |
86 | if real:
87 | self.real_features_sum += features.sum(dim=0)
88 | self.real_features_cov_sum += features.t().mm(features)
89 | self.real_features_num_samples += features.shape[0]
90 | else:
91 | self.fake_features_sum += features.sum(dim=0)
92 | self.fake_features_cov_sum += features.t().mm(features)
93 | self.fake_features_num_samples += features.shape[0]
94 |
95 | def compute(self):
96 | if (
97 | self.real_features_num_samples < 2 or
98 | self.fake_features_num_samples < 2
99 | ):
100 | raise RuntimeError(
101 | "More than one sample is required for both the real and fake "
102 | "distributed to compute FVD")
103 |
104 | mean_real = (
105 | self.real_features_sum / self.real_features_num_samples
106 | ).unsqueeze(0)
107 | mean_fake = (
108 | self.fake_features_sum / self.fake_features_num_samples
109 | ).unsqueeze(0)
110 |
111 | cov_real_num = self.real_features_cov_sum - \
112 | self.real_features_num_samples * mean_real.t().mm(mean_real)
113 | cov_real = cov_real_num / (self.real_features_num_samples - 1)
114 | cov_fake_num = self.fake_features_cov_sum - \
115 | self.fake_features_num_samples * mean_fake.t().mm(mean_fake)
116 | cov_fake = cov_fake_num / (self.fake_features_num_samples - 1)
117 | return _compute_fid(
118 | mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake
119 | ).to(self.orig_dtype)
120 |
--------------------------------------------------------------------------------
/src/dwm/tools/export_nusc_2_preview_format.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.common
3 | import json
4 | import os
5 | import copy
6 | import torch
7 | import torchvision
8 |
9 |
10 | def create_parser():
11 | parser = argparse.ArgumentParser(
12 | description="The script is designed to convert the nuscenes "
13 | "dataset into independent data packets, suitable for data "
14 | "loading in preview.py.")
15 | parser.add_argument(
16 | "--reference-frame-count", type=int, required=True,
17 | help="Save the nums of reference frame.")
18 | parser.add_argument(
19 | "-c", "--config-path", type=str, required=True,
20 | help="The config to load the dataset.")
21 | parser.add_argument(
22 | "-o", "--output-path", type=str, required=True,
23 | help="The path to save data packets.")
24 |
25 | return parser
26 |
27 | if __name__ == "__main__":
28 | parser = create_parser()
29 | args = parser.parse_args()
30 |
31 | with open(args.config_path, "r", encoding="utf-8") as f:
32 | config = json.load(f)
33 |
34 | validation_dataset = dwm.common.create_instance_from_config(
35 | config["validation_dataset"])
36 | validation_dataloader = torch.utils.data.DataLoader(
37 | validation_dataset,
38 | **dwm.common.instantiate_config(config["validation_dataloader"]))
39 | print("The validation dataset is loaded with {} items.".format(
40 | len(validation_dataset)))
41 |
42 | sensor_channels = [
43 | "CAM_FRONT_LEFT",
44 | "CAM_FRONT",
45 | "CAM_FRONT_RIGHT",
46 | "CAM_BACK_RIGHT",
47 | "CAM_BACK",
48 | "CAM_BACK_LEFT"
49 | ]
50 |
51 | for batch in validation_dataloader:
52 | scene_name = batch["scene"]["name"][0]
53 | output_path = os.path.join(args.output_path, scene_name)
54 | os.makedirs(output_path, exist_ok=True)
55 | timestamp = 0
56 |
57 | json_path = os.path.join(output_path, "data.json")
58 |         with open(json_path, "w") as json_file: pass  # truncate any existing data.json
59 |
60 | data_package = dict()
61 | frame_data = dict()
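        # each frame entry in data.json stores the per-channel "camera_infos"
        # (extrin, intrin, image_description, and the relative rgb / 3dbox /
        # hdmap image paths), the "timestamp", and the "ego_pose" matrix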
62 | for frame in range(batch["vae_images"].shape[1]):
63 | frame_data["camera_infos"] = dict()
64 | for sensor_channel in sensor_channels:
65 | frame_data["camera_infos"][sensor_channel] = dict()
66 |
67 | for view in range(len(sensor_channels)):
68 | frame_data["camera_infos"][sensor_channels[view]]["extrin"] = \
69 | batch["camera_transforms"][0, frame, view].tolist()
70 | frame_data["camera_infos"][sensor_channels[view]]["intrin"] = \
71 | batch["camera_intrinsics"][0, frame, view].tolist()
72 | frame_data["camera_infos"][sensor_channels[view]]["image_description"] = \
73 | batch["clip_text"][0][frame][view]
74 |
75 | if frame % batch["vae_images"].shape[1] < args.reference_frame_count:
76 | image_output_path = os.path.join(
77 | output_path, sensor_channels[view], "rgb", f"{timestamp}.png")
78 | os.makedirs(os.path.dirname(image_output_path), exist_ok=True)
79 | torchvision.transforms.functional.to_pil_image(
80 | batch["vae_images"][0, frame, view]).save(image_output_path)
81 | frame_data["camera_infos"][sensor_channels[view]]["rgb"] = \
82 | os.path.relpath(image_output_path, output_path)
83 | else:
84 | frame_data["camera_infos"][sensor_channels[view]]["rgb"] = None
85 |
86 | _3dbox_output_path = os.path.join(
87 | output_path, sensor_channels[view], "3dbox", f"{timestamp}.png")
88 | os.makedirs(os.path.dirname(_3dbox_output_path), exist_ok=True)
89 | torchvision.transforms.functional.to_pil_image(
90 | batch["3dbox_images"][0, frame, view]).save(_3dbox_output_path)
91 | frame_data["camera_infos"][sensor_channels[view]]["3dbox"] = \
92 | os.path.relpath(_3dbox_output_path, output_path)
93 |
94 | hdmap_output_path = os.path.join(
95 | output_path, sensor_channels[view], "hdmap", f"{timestamp}.png")
96 | os.makedirs(os.path.dirname(hdmap_output_path), exist_ok=True)
97 | torchvision.transforms.functional.to_pil_image(
98 | batch["hdmap_images"][0, frame, view]).save(hdmap_output_path)
99 | frame_data["camera_infos"][sensor_channels[view]]["hdmap"] = \
100 | os.path.relpath(hdmap_output_path, output_path)
101 |
102 | frame_data["timestamp"] = timestamp
103 | timestamp += 1/int(batch["fps"])
104 | timestamp = round(timestamp, 4)
105 | frame_data["ego_pose"] = batch["ego_transforms"][0, frame, 0].tolist()
106 | data_package[frame] = copy.deepcopy(frame_data)
107 |
108 | with open(json_path, "a") as json_file:
109 | json.dump(data_package, json_file, indent=4)
110 | json_file.write("\n")
111 |
--------------------------------------------------------------------------------
/src/dwm/utils/preview.py:
--------------------------------------------------------------------------------
1 | import av
2 | import torch
3 | import torchvision
4 |
5 |
6 | def make_ctsd_preview_tensor(output_images, batch, inference_config):
7 |
8 |     # The output image sequence length may be shorter than the input due to the
9 | # autoregressive inference, so use the output sequence length to clip batch
10 | # data.
11 | batch_size, _, view_count = batch["vae_images"].shape[:3]
12 | output_images = output_images\
13 | .cpu().unflatten(0, (batch_size, -1, view_count))
14 | sequence_length = output_images.shape[1]
15 |
16 | collected_images = [batch["vae_images"][:, :sequence_length]]
17 | if "3dbox_images" in batch:
18 | collected_images.append(
19 | batch["3dbox_images"][:, :sequence_length])
20 |
21 | if "hdmap_images" in batch:
22 | collected_images.append(
23 | batch["hdmap_images"][:, :sequence_length])
24 |
25 | collected_images.append(output_images)
26 |
27 | stacked_images = torch.stack(collected_images)
28 | resized_images = torch.nn.functional.interpolate(
29 | stacked_images.flatten(0, 3),
30 | tuple(inference_config["preview_image_size"][::-1])
31 | )
32 | resized_images = resized_images.view(
33 | *stacked_images.shape[:4], -1, *resized_images.shape[-2:])
34 | if sequence_length == 1:
35 | # image preview with shape [C, B * T * S * H, V * W]
36 | preview_tensor = resized_images.permute(4, 1, 2, 0, 5, 3, 6)\
37 | .flatten(-2).flatten(1, 4)
38 | else:
39 | # video preview with shape [T, C, B * S * H, V * W]
40 | preview_tensor = resized_images.permute(2, 4, 1, 0, 5, 3, 6)\
41 | .flatten(-2).flatten(2, 4)
42 |
43 | return preview_tensor
44 |
45 |
46 | def make_lidar_preview_tensor(
47 | ground_truth_volumn, generated_volumn, batch, inference_config
48 | ):
49 | collected_images = [
50 | ground_truth_volumn.amax(-3, keepdim=True).repeat_interleave(3, -3)
51 | .cpu()
52 | ]
53 | if "3dbox_bev_images_denorm" in batch:
54 | collected_images.append(batch["3dbox_bev_images_denorm"])
55 |
56 | if "hdmap_bev_images_denorm" in batch:
57 | collected_images.append(batch["hdmap_bev_images_denorm"])
58 |
59 | if isinstance(generated_volumn, list):
60 | for gv in generated_volumn:
61 | collected_images.append(
62 | gv.amax(-3, keepdim=True).repeat_interleave(3, -3).cpu())
63 | else:
64 | collected_images.append(
65 | generated_volumn.amax(-3, keepdim=True).repeat_interleave(3, -3).cpu())
66 |
67 | # assume all BEV images have the same size
68 | stacked_images = torch.stack(collected_images)
69 | if ground_truth_volumn.shape[1] == 1:
70 | # BEV image preview with shape [C, B * T * H, S * W]
71 | preview_tensor = stacked_images.permute(3, 1, 2, 4, 0, 5).flatten(-2)\
72 | .flatten(1, 3)
73 | else:
74 | # BEV video preview with shape [T, C, B * H, S * W]
75 | preview_tensor = stacked_images.permute(2, 3, 1, 4, 0, 5).flatten(-2)\
76 | .flatten(2, 3)
77 |
78 | return preview_tensor
79 |
80 |
81 | def save_tensor_to_video(
82 | path: str, video_encoder: str, fps, tensor_list, pix_fmt: str = "yuv420p",
83 | stream_options: dict = {"crf": "16"}
84 | ):
85 | tensor_shape = tensor_list[0].shape
86 | with av.open(path, mode="w") as container:
87 | stream = container.add_stream(video_encoder, int(fps))
88 | stream.width = tensor_shape[-1]
89 | stream.height = tensor_shape[-2]
90 | stream.pix_fmt = pix_fmt
91 | stream.options = stream_options
92 | for i in tensor_list:
93 | frame = av.VideoFrame.from_image(
94 | torchvision.transforms.functional.to_pil_image(i))
95 | for p in stream.encode(frame):
96 | container.mux(p)
97 |
98 | for p in stream.encode():
99 | container.mux(p)
100 |
101 |
102 | def gray_to_colormap(img, cmap='rainbow', max_val=None):
103 | """
104 | Transfer gray map to matplotlib colormap
105 | """
106 | assert img.ndim == 2
107 | import matplotlib
108 | import matplotlib.cm
109 |
110 | img[img<0] = 0
111 | mask_invalid = img < 1e-10
112 | if max_val is None:
113 | img = img / (img.max() + 1e-8)
114 | else:
115 | img = img / (max_val + 1e-8)
116 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
117 |     cmap_m = matplotlib.cm.get_cmap(cmap)
118 |     mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
119 |     colormap = mappable.to_rgba(img)[:, :, :3]
120 | colormap[mask_invalid] = 0
121 | return colormap
122 |
123 | def depths_to_colors(depths, concat="width", colormap="rainbow", max_val=None):
124 | colors = []
125 | if isinstance(depths, list) or len(depths.shape) == 4:
126 | for depth in depths:
127 | color = gray_to_colormap(depth.detach().cpu().numpy(), cmap=colormap, max_val=max_val)
128 |             colors.append(torch.from_numpy(color).permute(2, 0, 1))
129 | if concat == "width":
130 | colors = torch.cat(colors, dim=2)
131 | else:
132 | colors = torch.stack(colors)
133 | else:
134 | colors = gray_to_colormap(depths.detach().cpu().numpy(), cmap=colormap, max_val=max_val)
135 | colors = torch.from_numpy(colors).permute(2, 0, 1)
136 | return colors
137 |
--------------------------------------------------------------------------------
/src/dwm/streaming.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import av
3 | import dwm.common
4 | import einops
5 | import json
6 | import numpy as np
7 | from PIL import Image
8 | import queue
9 | import time
10 | import torch
11 |
12 |
13 | def create_parser():
14 | parser = argparse.ArgumentParser(
15 |         description="The script to run the streaming generation and "
16 |         "publish the result as a live video stream.")
17 | parser.add_argument(
18 | "-c", "--config-path", type=str, required=True,
19 | help="The config to load the train model and dataset.")
20 | parser.add_argument(
21 | "-l", "--log-path", type=str, required=True,
22 | help="The path to save log files.")
23 | parser.add_argument(
24 | "-s", "--streaming-path", type=str, required=True,
25 | help="The path to upload the video stream.")
26 | parser.add_argument(
27 | "-f", "--format", default="rtsp", type=str,
28 | help="The streaming format.")
29 | parser.add_argument(
30 | "--fps", default=2, type=int,
31 | help="The streaming FPS.")
32 | parser.add_argument(
33 | "-vcodec", "--video-encoder", default="libx264", type=str,
34 | help="The video encoder type.")
35 | parser.add_argument(
36 | "--pix-fmt", default="yuv420p", type=str,
37 | help="The pixel format.")
38 | return parser
39 |
40 |
41 | def merge_multiview_images(pipeline_frame, data_condition=None):
42 | image_data = np.concatenate([np.asarray(i) for i in pipeline_frame], 1)
43 | if data_condition is not None:
44 | _3dbox_data = torch.nn.functional.interpolate(
45 | einops.rearrange(
46 | data_condition["3dbox_images"],
47 | "b t v c h w -> b c h (t v w)"),
48 | image_data.shape[:2]
49 | )[0].permute(1, 2, 0).numpy()
50 | hdmap_data = torch.nn.functional.interpolate(
51 | einops.rearrange(
52 | data_condition["hdmap_images"],
53 | "b t v c h w -> b c h (t v w)"),
54 | image_data.shape[:2]
55 | )[0].permute(1, 2, 0).numpy()
56 | condition_data = np.maximum(_3dbox_data, hdmap_data)
57 |         condition_alpha = np.max(condition_data, -1, keepdims=True) * 0.6
58 |         image_data = (
59 |             condition_data * 255 * condition_alpha +
60 |             image_data * (1 - condition_alpha)
61 |         ).astype(np.uint8)
62 |
63 | return Image.fromarray(image_data)
64 |
65 |
66 | if __name__ == "__main__":
67 | parser = create_parser()
68 | args = parser.parse_args()
69 |
70 | with open(args.config_path, "r", encoding="utf-8") as f:
71 | config = json.load(f)
72 |
73 | # setup the global state
74 | if "global_state" in config:
75 | for key, value in config["global_state"].items():
76 | dwm.common.global_state[key] = \
77 | dwm.common.create_instance_from_config(value)
78 |
79 | # load the pipeline including the models
80 | pipeline = dwm.common.create_instance_from_config(
81 | config["pipeline"], output_path=args.log_path, config=config,
82 | device=torch.device(config["device"]))
83 | print("The pipeline is loaded.")
84 |
85 | data_adapter = dwm.common.create_instance_from_config(
86 | config["data_adapter"])
87 |
88 | size = pipeline.inference_config["preview_image_size"]
89 | latent_shape = (
90 | 1, pipeline.inference_config["sequence_length_per_iteration"],
91 | len(data_adapter.sensor_channels), pipeline.vae.config.latent_channels,
92 | config["latent_size"][0], config["latent_size"][1]
93 | )
94 | pipeline.reset_streaming(latent_shape, "pil")
95 |
96 | streaming_state = {}
97 | data_queue = queue.Queue()
98 | with av.open(
99 | args.streaming_path, mode="w", format=args.format,
100 | container_options=config.get("container_options", {})
101 | ) as container:
102 | stream = container.add_stream(args.video_encoder, args.fps)
103 | stream.pix_fmt = args.pix_fmt
104 | stream.options = config.get("stream_options", {})
105 | while True:
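            # feed the latest condition data to the pipeline;
            # receive_frame() returns None until a generated frame leaves
            # the denoising queue, so conditions keep queuing in the meantime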
106 | data = data_adapter.query_data()
107 | data_queue.put_nowait(data)
108 | pipeline.send_frame_condition(data)
109 | pipeline_frame = pipeline.receive_frame()
110 | if pipeline_frame is None:
111 | continue
112 |
113 | matched_data = data_queue.get_nowait()
114 | image = merge_multiview_images(
115 | pipeline_frame,
116 | (
117 | matched_data
118 | if config.get("preview_condition", False)
119 | else None
120 | ))
121 | if not streaming_state.get("is_frame_size_set", False):
122 | stream.width = image.width
123 | stream.height = image.height
124 | streaming_state["is_frame_size_set"] = True
125 |
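            # pace the output so frames are muxed at the requested FPS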
126 | while (
127 | "expected_time" in streaming_state and
128 | time.time() < streaming_state["expected_time"]
129 | ):
130 | time.sleep(0.01)
131 |
132 | frame = av.VideoFrame.from_image(image)
133 | for p in stream.encode(frame):
134 | container.mux(p)
135 |
136 | streaming_state["expected_time"] = (
137 | time.time()
138 | if "expected_time" not in streaming_state
139 | else streaming_state["expected_time"]
140 | ) + 1 / args.fps
141 | print("{:.1f}".format(streaming_state["expected_time"]))
142 |
--------------------------------------------------------------------------------
/src/dwm/metrics/pc_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | import torch.distributed
4 | from torchmetrics import Metric
5 |
6 | from dwm.utils.metrics_copilot4d import (
7 | compute_chamfer_distance,
8 | compute_chamfer_distance_inner,
9 | compute_mmd,
10 | jsd_2d,
11 | gaussian,
12 | point_cloud_to_histogram
13 | )
14 |
15 | class PointCloudChamfer(Metric):
16 | def __init__(self, inner_dist=None, **kwargs):
17 | super().__init__(**kwargs)
18 | self.inner_dist = inner_dist
19 |
20 | self.cd_func = compute_chamfer_distance if self.inner_dist is None else compute_chamfer_distance_inner
21 | self.chamfer_list = []
22 |
23 | def update(self, pred_pcd, gt_pcd, device=None):
24 | for pred, gt in zip(pred_pcd, gt_pcd):
25 | for p, g in zip(pred, gt):
26 | if self.inner_dist is None:
27 | cd = self.cd_func(p.to(torch.float32), g.to(torch.float32), device=device)
28 | else:
29 | cd = self.cd_func(p.to(torch.float32), g.to(torch.float32), device=device, pc_range=[
30 | -self.inner_dist, -self.inner_dist, -3, self.inner_dist, self.inner_dist, 5])
31 | if not isinstance(cd, torch.Tensor):
32 | cd = torch.tensor(cd).to(device)
33 | self.chamfer_list.append(cd.float())
34 |
35 | def compute(self):
36 | chamfer_list = torch.stack(self.chamfer_list, dim=0)
37 | world_size = torch.distributed.get_world_size() \
38 | if torch.distributed.is_initialized() else 1
39 | if world_size > 1:
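            # all ranks are assumed to hold the same number of samples, so
            # the gathered tensor can be preallocated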
40 | all_chamfer = chamfer_list.new_zeros(
41 | (len(chamfer_list)*world_size, ) + chamfer_list.shape[1:])
42 | torch.distributed.all_gather_into_tensor(
43 | all_chamfer, chamfer_list)
44 | chamfer_list = all_chamfer
45 | num_samples = (~torch.isnan(chamfer_list) & ~torch.isinf(chamfer_list)).sum()
46 | chamfer_list = torch.nan_to_num(chamfer_list, nan=0.0, posinf=0.0, neginf=0.0)
47 | return chamfer_list.sum() / num_samples
48 |
49 | def reset(self):
50 | self.chamfer_list.clear()
51 | super().reset()
52 |
53 |
54 | class PointCloudMMD(Metric):
55 | """
56 | Compute the Maximum Mean Discrepancy (MMD) between two point clouds.
57 | """
58 | def __init__(self, field_size=160, bins=100, **kwargs):
59 | super().__init__(**kwargs)
60 | self.field_size = field_size
61 | self.bins = bins
62 | self.mmd_list = []
63 | def update(self, pred_pcd, gt_pcd, device=None):
64 | pred_hist = []
65 | gt_hist = []
66 | for pred, gt in zip(pred_pcd, gt_pcd):
67 | for p, g in zip(pred, gt):
68 | p = point_cloud_to_histogram(self.field_size, self.bins, p.to(torch.float32))[0]
69 | g = point_cloud_to_histogram(self.field_size, self.bins, g.to(torch.float32))[0]
70 | pred_hist.append(p)
71 | gt_hist.append(g)
72 | mmd = compute_mmd(pred_hist, gt_hist, kernel=gaussian, is_parallel=True)
73 | if not isinstance(mmd, torch.Tensor):
74 | mmd = torch.tensor(mmd).to(device)
75 | self.mmd_list.append(mmd.float())
76 |
77 | def compute(self):
78 | mmd_list = torch.stack(self.mmd_list, dim=0)
79 | world_size = torch.distributed.get_world_size() \
80 | if torch.distributed.is_initialized() else 1
81 | if world_size > 1:
82 | all_mmd = mmd_list.new_zeros(
83 | (len(mmd_list)*world_size, ) + mmd_list.shape[1:])
84 | torch.distributed.all_gather_into_tensor(
85 | all_mmd, mmd_list)
86 | mmd_list = all_mmd
87 |         num_samples = (~torch.isnan(mmd_list) & ~torch.isinf(mmd_list)).sum()
88 | mmd_list = torch.nan_to_num(mmd_list, nan=0.0, posinf=0.0, neginf=0.0)
89 | return mmd_list.sum() / num_samples
90 |
91 | def reset(self):
92 | self.mmd_list.clear()
93 | super().reset()
94 |
95 | class PointCloudJSD(Metric):
96 | def __init__(self, field_size=160, bins=100, **kwargs):
97 | super().__init__(**kwargs)
98 | self.field_size = field_size
99 | self.bins = bins
100 | self.jsd_list = []
101 |
102 | def update(self, pred_pcd, gt_pcd, device=None):
103 | for pred, gt in zip(pred_pcd, gt_pcd):
104 | for p, g in zip(pred, gt):
105 | p = point_cloud_to_histogram(self.field_size, self.bins, p)[0]
106 | g = point_cloud_to_histogram(self.field_size, self.bins, g)[0]
107 | jsd = jsd_2d(p, g)
108 | if not isinstance(jsd, torch.Tensor):
109 | jsd = torch.tensor(jsd).to(device)
110 | self.jsd_list.append(jsd.float())
111 |
112 | def compute(self):
113 | jsd_list = torch.stack(self.jsd_list, dim=0)
114 | world_size = torch.distributed.get_world_size() \
115 | if torch.distributed.is_initialized() else 1
116 | if world_size > 1:
117 | all_jsd = jsd_list.new_zeros(
118 | (len(jsd_list)*world_size, ) + jsd_list.shape[1:])
119 | torch.distributed.all_gather_into_tensor(
120 | all_jsd, jsd_list)
121 | jsd_list = all_jsd
122 | num_samples = (~torch.isnan(jsd_list) & ~torch.isinf(jsd_list)).sum()
123 | jsd_list = torch.nan_to_num(jsd_list, nan=0.0, posinf=0.0, neginf=0.0)
124 | return jsd_list.sum() / num_samples
125 |
126 | def reset(self):
127 | self.jsd_list.clear()
128 | super().reset()
129 |
--------------------------------------------------------------------------------
/src/dwm/tools/dataset_make_info_json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.tools.fs_make_info_json
3 | import fsspec.implementations.local
4 | import json
5 | import os
6 | import struct
7 | import re
8 |
9 |
10 | def create_parser():
11 | parser = argparse.ArgumentParser(
12 | description="The script to make information JSON(s) for dataset to "
13 | "accelerate initialization.")
14 | parser.add_argument(
15 | "-dt", "--dataset-type", type=str,
16 | choices=["nuscenes", "waymo", "argoverse"], required=True,
17 | help="The dataset type.")
18 | parser.add_argument(
19 | "-s", "--split", default=None, type=str,
20 | help="The split, optional depending on the dataset type.")
21 | parser.add_argument(
22 | "-i", "--input-path", type=str, required=True,
23 | help="The path of the dataset root.")
24 | parser.add_argument(
25 | "-o", "--output-path", type=str, required=True,
26 | help="The path to save the information JSON file(s) on the local file "
27 | "system.")
28 | parser.add_argument(
29 | "-fs", "--fs-config-path", default=None, type=str,
30 | help="The path of file system JSON config to open the dataset.")
31 | return parser
32 |
33 |
34 | if __name__ == "__main__":
35 | import tqdm
36 |
37 | parser = create_parser()
38 | args = parser.parse_args()
39 |
40 | if args.fs_config_path is None:
41 | fs = fsspec.implementations.local.LocalFileSystem()
42 | else:
43 | import dwm.common
44 | with open(args.fs_config_path, "r", encoding="utf-8") as f:
45 | fs = dwm.common.create_instance_from_config(json.load(f))
46 |
47 | if args.dataset_type == "nuscenes":
48 | files = [
49 | os.path.relpath(i, args.input_path)
50 | for i in fs.ls(args.input_path, detail=False)
51 | ]
52 | filtered_files = [
53 | i for i in files
54 | if (
55 | (args.split is None or i.startswith(args.split)) and
56 | i.endswith(".zip")
57 | )
58 | ]
59 | assert len(filtered_files) > 0, (
60 | "No files detected, please check the split (one of \"v1.0-mini\", "
61 | "\"v1.0-trainval\", \"v1.0-test\") is correct, and ensure the "
62 | "blob files are already converted to the ZIP format."
63 | )
64 |
65 | os.makedirs(args.output_path, exist_ok=True)
66 | for i in tqdm.tqdm(filtered_files):
67 | with fs.open("{}/{}".format(args.input_path, i)) as f:
68 | items = dwm.tools.fs_make_info_json.make_info_dict(
69 | os.path.splitext(i)[-1], f)
70 |
71 | with open(
72 | os.path.join(
73 | args.output_path, i.replace(".zip", ".info.json")),
74 | "w", encoding="utf-8"
75 | ) as f:
76 | json.dump(items, f)
77 |
78 | elif args.dataset_type == "waymo":
79 | import waymo_open_dataset.dataset_pb2 as waymo_pb
80 |
81 | files = [
82 | os.path.relpath(i, args.input_path)
83 | for i in fs.ls(args.input_path, detail=False)
84 | if i.endswith(".tfrecord")
85 | ]
86 | assert len(files) > 0, "No files detected."
87 |
88 |         pattern = re.compile(
89 |             "^segment-(?P<scene>.*)_with_camera_labels.tfrecord$")
90 | info_dict = {}
91 | for i in tqdm.tqdm(files):
92 | match = re.match(pattern, i)
93 | scene = match.group("scene")
94 | pt = 0
95 | info_list = []
96 | with fs.open("{}/{}".format(args.input_path, i)) as f:
97 | while True:
98 | start = f.read(8)
99 | if len(start) == 0:
100 | break
101 |
102 |                     size, = struct.unpack("<Q", start)
103 |
104 |                     # each record is the 8-byte little-endian length,
105 |                     # a 4-byte length CRC, the payload, and a 4-byte
106 |                     # payload CRC
107 |                     info_list.append([pt + 12, size])
108 |                     f.seek(size + 8, 1)
109 |                     pt += size + 16
110 |
111 |             info_dict[scene] = info_list
112 |
113 |         os.makedirs(args.output_path, exist_ok=True)
114 |         with open(
115 |             os.path.join(args.output_path, "info.json"), "w",
116 |             encoding="utf-8"
117 |         ) as f:
118 |             json.dump(info_dict, f)
119 |
120 |     elif args.dataset_type == "argoverse":
121 |         files = [
122 |             os.path.relpath(i, args.input_path)
123 |             for i in fs.ls(args.input_path, detail=False)
124 |             if (
125 |                 (args.split is None or i.startswith(args.split)) and
126 |                 i.endswith(".tar")
127 |             )
128 |         ]
129 |         assert len(files) > 0, (
130 | "No files detected, please check the split (one of \"train\", "
131 | "\"val\", \"test\") is correct."
132 | )
133 |
134 | os.makedirs(args.output_path, exist_ok=True)
135 | for i in tqdm.tqdm(files):
136 | with fs.open("{}/{}".format(args.input_path, i)) as f:
137 | items = dwm.tools.fs_make_info_json.make_info_dict(
138 | os.path.splitext(i)[-1], f, enable_tqdm=False)
139 |
140 | with open(
141 | os.path.join(
142 | args.output_path, i.replace(".tar", ".info.json")),
143 | "w", encoding="utf-8"
144 | ) as f:
145 | json.dump(items, f)
146 |
147 | else:
148 | raise Exception("Unknown dataset type {}.".format(args.dataset_type))
149 |
--------------------------------------------------------------------------------
/src/dwm/fs/czip.py:
--------------------------------------------------------------------------------
1 | import dwm.common
2 | import fsspec
3 | import fsspec.archive
4 | import fsspec.implementations.local
4 | import io
5 | import json
6 | import os
7 | import re
8 | import struct
9 | import zipfile
10 | import zlib
11 |
12 |
13 | class CombinedZipFileSystem(fsspec.archive.AbstractArchiveFileSystem):
14 |
15 | """This file system can be used to access files inside ZIP files, and it is
16 | also compatible with the process fork under the multi-worker situation of
17 | the PyTorch data loader.
18 |
19 | Args:
20 | paths (list): The path list of ZIP files to open.
21 | fs (fsspec.AbstractFileSystem or None): The file system to open the ZIP
22 | blobs, or local file system as default.
23 | enable_cached_info (bool): Load the byte offset of ZIP entries from
24 | cached info file with the extension `.info.json` to accelerate
25 | opening large ZIP files.
26 | """
27 |
28 | root_marker = ""
29 | protocol = "czip"
30 | cachable = False
31 |
32 | def __init__(
33 | self, paths: list, fs=None,
34 | enable_cached_info: bool = False, **kwargs
35 | ):
36 | super().__init__(**kwargs)
37 |
38 | self.fs = fsspec.implementations.local.LocalFileSystem() \
39 | if fs is None else fs
40 | self.paths = paths
41 |
42 | _belongs_to = {}
43 | self._info = {}
44 | pattern = re.compile(r"\.zip$")
45 | for path in paths:
46 | cached_info_path = re.sub(pattern, ".info.json", path)
47 | if enable_cached_info:
48 | with fs.open(cached_info_path, "r", encoding="utf-8") as f:
49 | infodict = json.load(f)
50 |
51 | for filename, i in infodict.items():
52 | _belongs_to[filename] = path
53 |
54 | else:
55 | with fs.open(path) as f:
56 | with zipfile.ZipFile(f) as zf:
57 | infodict = {}
58 | for i in zf.infolist():
59 | _belongs_to[i.filename] = path
60 | infodict[i.filename] = [
61 | i.header_offset, i.file_size, i.is_dir()
62 | ]
63 |
64 | self._info[path] = dwm.common.SerializedReadonlyDict(infodict)
65 |
66 | self._belongs_to = dwm.common.SerializedReadonlyDict(_belongs_to)
67 |
68 | self.dir_cache = None
69 | self.fp_cache = {}
70 |
71 | @classmethod
72 | def _strip_protocol(cls, path):
73 | # zip file paths are always relative to the archive root
74 | return super()._strip_protocol(path).lstrip("/")
75 |
76 | def close(self):
77 |         # the file pointers are not forkable
78 | current_pid = os.getpid()
79 | if self._pid != current_pid:
80 | self.fp_cache.clear()
81 | self._pid = current_pid
82 |
83 | for fp in self.fp_cache.values():
84 | fp.close()
85 |
86 | self.fp_cache.clear()
87 |
88 | def _get_dirs(self):
89 | if self.dir_cache is None:
90 | self.dir_cache = {
91 | dirname.rstrip("/"): {
92 | "name": dirname.rstrip("/"),
93 | "size": 0,
94 | "type": "directory",
95 | }
96 | for dirname in self._all_dirnames(self._belongs_to.keys())
97 | }
98 |
99 | for file_name in self._belongs_to.keys():
100 | zip_path = self._belongs_to[file_name]
101 | _, file_size, is_dir = self._info[zip_path][file_name]
102 | f = {
103 | "name": file_name.rstrip("/"),
104 | "size": file_size,
105 | "type": "directory" if is_dir else "file",
106 | }
107 | self.dir_cache[f["name"]] = f
108 |
109 | def exists(self, path, **kwargs):
110 | return path in self._belongs_to
111 |
112 | def _open(
113 | self, path, mode="rb", block_size=None, autocommit=True,
114 | cache_options=None, **kwargs,
115 | ):
116 | if "w" in mode:
117 | raise OSError("Combined stateless ZIP file system is read only.")
118 |
119 | if "b" not in mode:
120 | raise OSError(
121 | "Combined stateless ZIP file system only support the mode of "
122 | "binary.")
123 |
124 | if not self.exists(path):
125 | raise FileNotFoundError(path)
126 |
127 | zip_path = self._belongs_to[path]
128 | header_offset, size, is_dir = self._info[zip_path][path]
129 | if is_dir:
130 | raise FileNotFoundError(path)
131 |
132 |         # the file pointers are not forkable
133 | current_pid = os.getpid()
134 | if self._pid != current_pid:
135 | self.fp_cache.clear()
136 | self._pid = current_pid
137 |
138 | if zip_path in self.fp_cache:
139 | f = self.fp_cache[zip_path]
140 | else:
141 | f = self.fp_cache[zip_path] = \
142 | self.fs.open(zip_path, mode, block_size, cache_options)
143 |
144 | f.seek(header_offset)
145 | fh = struct.unpack(zipfile.structFileHeader, f.read(30))
146 | method = fh[zipfile._FH_COMPRESSION_METHOD]
147 | offset = header_offset + 30 + fh[zipfile._FH_FILENAME_LENGTH] + \
148 | fh[zipfile._FH_EXTRA_FIELD_LENGTH]
149 |
150 | if method == zipfile.ZIP_STORED:
151 | return dwm.common.PartialReadableRawIO(f, offset, offset + size)
152 |
153 | elif method == zipfile.ZIP_DEFLATED:
154 | f.seek(offset)
155 | data = f.read(size)
156 | result = io.BytesIO(zlib.decompress(data, -15))
157 | return result
158 |
159 | else:
160 | raise NotImplementedError("The compression method is unsupported")
161 |
--------------------------------------------------------------------------------
/src/dwm/utils/carla_simulation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import carla
3 | import dwm.common
4 | import json
5 | import random
6 | import time
7 |
8 |
9 | def create_parser():
10 | parser = argparse.ArgumentParser(
11 | description="The tool to setup the Carla simulation.")
12 | parser.add_argument(
13 | "-c", "--config-path", type=str, required=True,
14 | help="The path of the simulation config about the environment, ego "
15 | "vehicle, and scenario.")
16 | parser.add_argument(
17 | "--host", default="127.0.0.1", type=str,
18 | help="The host address of the Carla simulator.")
19 | parser.add_argument(
20 | "-p", "--port", default=2000, type=int,
21 | help="The port of the Carla simulator.")
22 | parser.add_argument(
23 | "-tp", "--traffic-port", default=8000, type=int,
24 | help="The port of the Traffic manager.")
25 | parser.add_argument(
26 | "--client-timeout", default=10.0, type=float,
27 | help="The timeout of the Carla client.")
28 | parser.add_argument(
29 | "--step-sleep", default=0.0, type=float,
30 | help="The time to sleep for each step.")
31 | return parser
32 |
33 |
34 | def make_actor(
35 | world: carla.World, blueprint_library: carla.BlueprintLibrary,
36 | spawn_points: list, actor_config: dict, random_state: random.Random,
37 | attach_to=None
38 | ):
39 | # prepare the blueprint
40 | if "pattern" in actor_config:
41 | bp_list = blueprint_library.filter(actor_config["pattern"])
42 | bp = (
43 | bp_list[actor_config["matched_index"]]
44 | if "matched_index" in actor_config
45 | else random_state.choice(bp_list)
46 | )
47 | else:
48 | bp = blueprint_library.find(actor_config["id"])
49 |
50 | for k, v in actor_config.get("attributes", {}).items():
51 | bp.set_attribute(k, v)
52 |
53 | # prepare the spawn location for vehicles, pedestrians, cameras
54 | if "spawn_index" in actor_config:
55 | spawn_transform = spawn_points[
56 | actor_config["spawn_index"] % len(spawn_points)
57 | ]
58 | elif "spawn_from_navigation" in actor_config:
59 | location = world.get_random_location_from_navigation()
60 | spawn_transform = carla.Transform(location, carla.Rotation(0, 0, 0))
61 | else:
62 | spawn_transform = actor_config["spawn_transform"]
63 | spawn_transform = carla.Transform(
64 | carla.Location(*spawn_transform.get("location", [0, 0, 0])),
65 | carla.Rotation(*spawn_transform.get("rotation", [0, 0, 0])))
66 |
67 | # instantiate the actor, set attributes and apply custom setup functions
68 | actor = world.try_spawn_actor(bp, spawn_transform, attach_to)
69 |
70 | if actor is not None:
71 | if actor.attributes.get("role_name") == "autopilot":
72 | actor.set_autopilot(True)
73 |
74 | if actor is None:
75 | print("Warning: failed to spawn {}".format(bp.id))
76 | return None, None, None
77 |
78 | if actor_config.get("report_actor_id", False):
79 | attributes = actor.attributes
80 | report_text = "{}{}: {}".format(
81 | actor_config["id"],
82 | " ({})".format(attributes["role_name"])
83 | if "role_name" in attributes else "", actor.id)
84 | print(report_text)
85 |
86 | if actor_config.get("report_actor_attributes", False):
87 | print("{}: {}".format(actor.type_id, actor.attributes))
88 |
89 | if "state_machine" in actor_config:
90 | _class = dwm.common.get_class(actor_config["state_machine"])
91 | state_machine = _class(
92 | actor, **actor_config.get("state_machine_args", {}))
93 | else:
94 | state_machine = None
95 |
96 | children = [
97 | make_actor(
98 |             world, blueprint_library, spawn_points, i, random_state, actor)
99 | for i in actor_config["child_configs"]
100 | ] if "child_configs" in actor_config else None
101 |
102 | return actor, state_machine, children
103 |
104 |
105 | def update_actor_state(actors: list):
106 | for _, state_machine, children in actors:
107 | if state_machine is not None:
108 | state_machine.update()
109 |
110 | if children is not None:
111 | update_actor_state(children)
112 |
113 |
114 | if __name__ == "__main__":
115 | parser = create_parser()
116 | args = parser.parse_args()
117 |
118 | with open(args.config_path, "r", encoding="utf-8") as f:
119 | config = json.load(f)
120 |
121 | rs = random.Random(config.get("seed", None))
122 | client = carla.Client(args.host, args.port, 1)
123 | client.set_timeout(args.client_timeout)
124 |
125 |     # This script just uses the currently loaded map. To configure the map,
126 |     # please use {CarlaRoot}/PythonAPI/util/config.py
127 | world = client.get_world()
128 | traffic_manager = client.get_trafficmanager(args.traffic_port)
129 |
130 | if config.get("master", False):
131 | traffic_manager.set_synchronous_mode(True)
132 |
133 | if "world_settings" in config:
134 | settings = world.get_settings()
135 | for k, v in config["world_settings"].items():
136 | setattr(settings, k, v)
137 |
138 | world.apply_settings(settings)
139 |
140 | if "traffic_manager_settings" in config:
141 | for k, v in config["traffic_manager_settings"].items():
142 | getattr(traffic_manager, k)(v)
143 |
144 | actors = [
145 | make_actor(
146 | world, world.get_blueprint_library(),
147 | world.get_map().get_spawn_points(), i, rs)
148 | for i in config["actor_configs"]
149 | ]
150 |
151 | step = 0
152 | total_steps = config.get("total_steps", -1)
153 | while total_steps == -1 or step < total_steps:
154 | if args.step_sleep > 0.0:
155 | time.sleep(args.step_sleep)
156 |
157 | if config.get("master", False):
158 | world.tick()
159 | else:
160 | world.wait_for_tick()
161 |
162 | update_actor_state(actors)
163 | step += 1
164 |
--------------------------------------------------------------------------------
/src/dwm/export_generation_result_as_nuscenes_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.common
3 | import json
4 | import numpy as np
5 | import os
6 | import torch
7 |
8 |
9 | def create_parser():
10 | parser = argparse.ArgumentParser(
11 | description="The script to run the diffusion model to generate data for"
12 | "detection evaluation.")
13 | parser.add_argument(
14 | "-c", "--config-path", type=str, required=True,
15 | help="The config to load the train model and dataset.")
16 | parser.add_argument(
17 | "-o", "--output-path", type=str, required=True,
18 | help="The path to save checkpoint files.")
19 | return parser
20 |
21 |
22 | """
23 | This script requires:
24 | 1. The validation dataset should be only nuScenes.
25 | 2. Add `"enable_sample_data": true` to the dataset arguments, so the Dataset
26 |    loads the filename (path) to save the generated images.
27 | 3. Add `"sample_data"` to "validation_dataloader.collate_fn.keys" to pass the
28 | object data directly to the script here, no need to collate to tensors.
29 |
30 | Note:
31 | *. Set the "model_checkpoint_path" with the trained checkpoint, rather than
32 | the pretrained checkpoint.
33 | *. Set the fps_stride to [0, 1] for the image dataset, or
34 | [2, 0.5 * sequence_length] for the origin video dataset, or
35 | [12, 0.1 * sequence_length] for the 12Hz video dataset. Make sure no
36 | sample is missed.
37 | """
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = create_parser()
42 | args = parser.parse_args()
43 |
44 | with open(args.config_path, "r", encoding="utf-8") as f:
45 | config = json.load(f)
46 |
47 | # set distributed training (if enabled), log, random number generator, and
48 | # load the checkpoint (if required).
49 | ddp = "LOCAL_RANK" in os.environ
50 | if ddp:
51 | local_rank = int(os.environ["LOCAL_RANK"])
52 | device = torch.device(config["device"], local_rank)
53 | if config["device"] == "cuda":
54 | torch.cuda.set_device(local_rank)
55 |
56 | torch.distributed.init_process_group(backend=config["ddp_backend"])
57 | else:
58 | device = torch.device(config["device"])
59 |
60 | # setup the global state
61 | if "global_state" in config:
62 | for key, value in config["global_state"].items():
63 | dwm.common.global_state[key] = \
64 | dwm.common.create_instance_from_config(value)
65 |
66 | should_log = (ddp and local_rank == 0) or not ddp
67 |
68 | pipeline = dwm.common.create_instance_from_config(
69 | config["pipeline"], output_path=args.output_path, config=config,
70 | device=device)
71 | if should_log:
72 | print("The pipeline is loaded.")
73 |
74 | # load the dataset
75 | validation_dataset = dwm.common.create_instance_from_config(
76 | config["validation_dataset"])
77 | if ddp:
78 | validation_datasampler = \
79 | torch.utils.data.distributed.DistributedSampler(
80 | validation_dataset)
81 | validation_dataloader = torch.utils.data.DataLoader(
82 | validation_dataset,
83 | **dwm.common.instantiate_config(config["validation_dataloader"]),
84 | sampler=validation_datasampler)
85 | else:
86 | validation_datasampler = None
87 | validation_dataloader = torch.utils.data.DataLoader(
88 | validation_dataset,
89 | **dwm.common.instantiate_config(config["validation_dataloader"]))
90 |
91 | if should_log:
92 | print("The validation dataset is loaded with {} items.".format(
93 | len(validation_dataset)))
94 |
95 | if ddp:
96 | validation_datasampler.set_epoch(0)
97 |
98 | for batch in validation_dataloader:
99 | batch_size, sequence_length, view_count = batch["vae_images"].shape[:3]
100 | latent_height = batch["vae_images"].shape[-2] // \
101 | (2 ** (len(pipeline.vae.config.down_block_types) - 1))
102 | latent_width = batch["vae_images"].shape[-1] // \
103 | (2 ** (len(pipeline.vae.config.down_block_types) - 1))
104 | latent_shape = (
105 | batch_size, sequence_length, view_count,
106 | pipeline.vae.config.latent_channels, latent_height,
107 | latent_width
108 | )
109 |
110 | with torch.no_grad():
111 | pipeline_output = pipeline.inference_pipeline(
112 | latent_shape, batch, "pil")
113 |
114 | if "images" in pipeline_output:
115 | paths = [
116 | os.path.join(args.output_path, k["filename"])
117 | for i in batch["sample_data"]
118 | for j in i
119 | for k in j if not k["filename"].endswith(".bin")
120 | ]
121 | image_results = pipeline_output["images"]
122 | image_sizes = batch["image_size"].flatten(0, 2)
123 | for path, image, image_size in zip(paths, image_results, image_sizes):
124 | dir = os.path.dirname(path)
125 | os.makedirs(dir, exist_ok=True)
126 | image.resize(tuple(image_size.int().tolist()))\
127 | .save(path, quality=95)
128 |
129 | if "raw_points" in pipeline_output:
130 | paths = [
131 | os.path.join(args.output_path, k["filename"])
132 | for i in batch["sample_data"]
133 | for j in i
134 | for k in j if k["filename"].endswith(".bin")
135 | ]
136 | raw_points = [
137 | j
138 | for i in pipeline_output["raw_points"]
139 | for j in i
140 | ]
141 | for path, points in zip(paths, raw_points):
142 | os.makedirs(os.path.dirname(path), exist_ok=True)
143 | points = points.numpy()
144 | padded_points = np.concatenate([
145 | points, np.zeros((points.shape[0], 2), dtype=np.float32)
146 | ], axis=-1)
147 | with open(path, "wb") as f:
148 | f.write(padded_points.tobytes())
149 |
--------------------------------------------------------------------------------
/src/dwm/preview.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dwm.common
3 | import json
4 | import os
5 | import torch
6 |
7 |
8 | def customize_text(clip_text, preview_config):
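    """Rewrite the per-view prompt texts according to preview_config["text"],
    which supports the "add", "replace", and "template" types.
    """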
9 |
10 | # text
11 | if preview_config["text"] is not None:
12 | text_config = preview_config["text"]
13 |
14 | if text_config["type"] == "add":
15 | new_clip_text = \
16 | [
17 | [
18 | [
19 | text_config["prompt"] + k
20 | for k in j
21 | ]
22 | for j in i
23 | ]
24 | for i in clip_text
25 | ]
26 |
27 | elif text_config["type"] == "replace":
28 | new_clip_text = \
29 | [
30 | [
31 | [
32 | text_config["prompt"]
33 | for k in j
34 | ]
35 | for j in i
36 | ]
37 | for i in clip_text
38 | ]
39 |
40 | elif text_config["type"] == "template":
41 | time = text_config["time"]
42 | weather = text_config["weather"]
43 | new_clip_text = \
44 | [
45 | [
46 | [
47 | text_config["template"][time][weather][idx][0]
48 | for idx, k in enumerate(j)
49 | ]
50 | for j in i
51 | ]
52 | for i in clip_text
53 | ]
54 |
55 | else:
56 | raise NotImplementedError(
57 | f"{text_config['type']}has not been implemented yet.")
58 |
59 | return new_clip_text
60 |
61 | else:
62 |
63 | return clip_text
64 |
65 |
66 | def create_parser():
67 | parser = argparse.ArgumentParser(
68 | description="The script to finetune a stable diffusion model to the "
69 | "driving dataset.")
70 | parser.add_argument(
71 | "-c", "--config-path", type=str, required=True,
72 | help="The config to load the train model and dataset.")
73 | parser.add_argument(
74 | "-o", "--output-path", type=str, required=True,
75 |         help="The path to save the preview results.")
76 | parser.add_argument(
77 | "-pc", "--preview-config-path", default=None, type=str,
78 |         help="The config file for the preview settings.")
79 | parser.add_argument(
80 |         "-eic", "--export-item-config", action="store_true",
81 |         help="The flag to export the item config as JSON.")
82 | return parser
83 |
84 |
85 | if __name__ == "__main__":
86 | parser = create_parser()
87 | args = parser.parse_args()
88 |
89 | with open(args.config_path, "r", encoding="utf-8") as f:
90 | config = json.load(f)
91 |
92 | if args.preview_config_path is not None:
93 | with open(args.preview_config_path, "r", encoding="utf-8") as f:
94 | preview_config = json.load(f)
95 | else:
96 | preview_config = None
97 |
98 |     # set up the distributed environment (if enabled) and determine which
99 |     # process should log and save results.
100 | ddp = "LOCAL_RANK" in os.environ
101 | if ddp:
102 | local_rank = int(os.environ["LOCAL_RANK"])
103 | device = torch.device(config["device"], local_rank)
104 | if config["device"] == "cuda":
105 | torch.cuda.set_device(local_rank)
106 |
107 | torch.distributed.init_process_group(backend=config["ddp_backend"])
108 | else:
109 | device = torch.device(config["device"])
110 |
111 | # setup the global state
112 | if "global_state" in config:
113 | for key, value in config["global_state"].items():
114 | dwm.common.global_state[key] = \
115 | dwm.common.create_instance_from_config(value)
116 |
117 | should_log = (ddp and local_rank == 0) or not ddp
118 | should_save = not torch.distributed.is_initialized() or \
119 | torch.distributed.get_rank() == 0
120 |
121 | # load the pipeline including the models
122 | pipeline = dwm.common.create_instance_from_config(
123 | config["pipeline"], output_path=args.output_path, config=config,
124 | device=device)
125 | if should_log:
126 | print("The pipeline is loaded.")
127 |
128 | validation_dataset = dwm.common.create_instance_from_config(
129 | config["validation_dataset"])
130 |
131 | preview_dataloader = torch.utils.data\
132 | .DataLoader(
133 | validation_dataset,
134 | **dwm.common.instantiate_config(config["preview_dataloader"])) if \
135 | "preview_dataloader" in config else None
136 |
137 | if should_log:
138 | print("The validation dataset is loaded with {} items.".format(
139 | len(validation_dataset)))
140 |
141 | export_batch_except = ["vae_images"]
142 | output_path = args.output_path
143 | global_step = 0
144 | for batch in preview_dataloader:
145 | if ddp:
146 | torch.distributed.barrier()
147 |
148 | if preview_config is not None:
149 | new_clip_text = customize_text(batch["clip_text"], preview_config)
150 | batch["clip_text"] = new_clip_text
151 |
152 | pipeline.preview_pipeline(
153 | batch, output_path, global_step)
154 |
155 | if args.export_item_config:
156 | with open(
157 | os.path.join(
158 | output_path, "preview",
159 | "{}.json".format(global_step)),
160 | "w", encoding="utf-8"
161 | ) as f:
162 | json.dump({
163 | k: v.tolist() if isinstance(v, torch.Tensor) else v
164 | for k, v in batch.items()
165 | if k not in export_batch_except
166 | }, f, indent=4)
167 |
168 | global_step += 1
169 | if should_log:
170 | print(f"preview: {global_step}")
171 |
172 | if torch.distributed.is_initialized():
173 | torch.distributed.destroy_process_group()
174 |
--------------------------------------------------------------------------------
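For reference, `customize_text` rewrites the nested `clip_text` structure (batch, frame, view) according to the preview config. A minimal sketch of the `add` mode with made-up prompts, assuming `src` is on `PYTHONPATH`:

```
from dwm.preview import customize_text

# clip_text is nested as [batch][frame][view]; the texts below are made up.
clip_text = [[["a daytime street", "a parking lot"]]]
preview_config = {"text": {"type": "add", "prompt": "rainy night, "}}

new_clip_text = customize_text(clip_text, preview_config)
print(new_clip_text)
# [[['rainy night, a daytime street', 'rainy night, a parking lot']]]
```
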
/src/dwm/common.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import importlib
3 | import io
4 | import numpy as np
5 | import os
6 | import pickle
7 |
8 |
9 | class PartialReadableRawIO(io.RawIOBase):
10 | def __init__(
11 | self, base_io_object: io.RawIOBase, start: int, end: int,
12 | close_with_this_object: bool = False
13 | ):
14 | super().__init__()
15 | self.base_io_object = base_io_object
16 | self.p = self.start = start
17 | self.end = end
18 | self.close_with_this_object = close_with_this_object
19 | self.base_io_object.seek(start)
20 |
21 | def close(self):
22 | if self.close_with_this_object:
23 | self.base_io_object.close()
24 |
25 | @property
26 | def closed(self):
27 | return self.base_io_object.closed if self.close_with_this_object \
28 | else False
29 |
30 | def readable(self):
31 | return self.base_io_object.readable()
32 |
33 | def read(self, size=-1):
34 | read_count = min(size, self.end - self.p) \
35 | if size >= 0 else self.end - self.p
36 | data = self.base_io_object.read(read_count)
37 | self.p += read_count
38 | return data
39 |
40 | def readall(self):
41 | return self.read(-1)
42 |
43 | def seek(self, offset, whence=os.SEEK_SET):
44 | if whence == os.SEEK_SET:
45 | p = max(0, min(self.end - self.start, offset))
46 | elif whence == os.SEEK_CUR:
47 | p = max(
48 | 0, min(self.end - self.start, self.p - self.start + offset))
49 | elif whence == os.SEEK_END:
50 | p = max(
51 | 0, min(self.end - self.start, self.end - self.start + offset))
52 |
53 | self.p = self.base_io_object.seek(self.start + p, os.SEEK_SET)
54 | return self.p
55 |
56 | def seekable(self):
57 | return self.base_io_object.seekable()
58 |
59 | def tell(self):
60 | return self.p - self.start
61 |
62 | def writable(self):
63 | return False
64 |
65 |
66 | class ReadonlyDictIndices:
67 | def __init__(self, base_dict_keys):
68 | sorted_table = sorted(
69 | enumerate(base_dict_keys), key=lambda i: i[1])
70 | self.sorted_keys = [i[1] for i in sorted_table]
71 | self.key_indices = np.array([i[0] for i in sorted_table], np.int64)
72 |
73 | def __len__(self):
74 | return len(self.sorted_keys)
75 |
76 | def __contains__(self, key):
77 | i = bisect.bisect_left(self.sorted_keys, key)
78 | in_range = i >= 0 and i < len(self.sorted_keys)
79 | return in_range and self.sorted_keys[i] == key
80 |
81 | def __getitem__(self, key):
82 | i = bisect.bisect_left(self.sorted_keys, key)
83 | if i < 0 or i >= len(self.sorted_keys) or self.sorted_keys[i] != key:
84 | raise KeyError("{} not found".format(key))
85 |
86 | return self.key_indices[i]
87 |
88 | def get_all_indices(self, key):
89 | i0 = bisect.bisect_left(self.sorted_keys, key)
90 | i1 = bisect.bisect_right(self.sorted_keys, key)
91 | return [self.key_indices[i] for i in range(i0, i1)] if i1 > i0 else []
92 |
93 |
94 | class SerializedReadonlyList:
95 | """A list to prevent memory divergence accessed by forked process.
96 | """
97 |
98 | def __init__(self, items: list):
99 | serialized_items = [pickle.dumps(i) for i in items]
100 | self.offsets = np.cumsum([len(i) for i in serialized_items])
101 | self.blob = b''.join(serialized_items)
102 |
103 | def __len__(self):
104 | return len(self.offsets)
105 |
106 | def __getitem__(self, index: int):
107 | start = 0 if index == 0 else self.offsets[index - 1]
108 | end = self.offsets[index]
109 | return pickle.loads(self.blob[start:end])
110 |
111 |
112 | class SerializedReadonlyDict:
113 | """A dict to prevent memory divergence accessed by forked process.
114 | """
115 |
116 | def __init__(self, base_dict: dict):
117 | self.indices = ReadonlyDictIndices(base_dict.keys())
118 | self.values = SerializedReadonlyList(base_dict.values())
119 |
120 | def __len__(self):
121 | return len(self.indices)
122 |
123 | def __contains__(self, key):
124 | return key in self.indices
125 |
126 | def __getitem__(self, key):
127 | return self.values[self.indices[key]]
128 |
129 | def keys(self):
130 | return self.indices.sorted_keys
131 |
132 |
133 | def get_class(class_name: str):
134 | if "." in class_name:
135 | i = class_name.rfind(".")
136 | module_name = class_name[:i]
137 | class_name = class_name[i+1:]
138 |
139 | module_type = importlib.import_module(module_name, package=None)
140 | class_type = getattr(module_type, class_name)
141 | elif class_name in globals():
142 | class_type = globals()[class_name]
143 | else:
144 | raise RuntimeError("Failed to find the class {}.".format(class_name))
145 |
146 | return class_type
147 |
148 |
149 | def create_instance(class_name: str, **kwargs):
150 | class_type = get_class(class_name)
151 | return class_type(**kwargs)
152 |
153 |
154 | def create_instance_from_config(_config: dict, level: int = 0, **kwargs):
155 | if isinstance(_config, dict):
156 | if "_class_name" in _config:
157 | args = instantiate_config(_config, level)
158 | if level == 0:
159 | args.update(kwargs)
160 |
161 | if _config["_class_name"] == "get_class":
162 | return get_class(**args)
163 | else:
164 | return create_instance(_config["_class_name"], **args)
165 |
166 | else:
167 | return instantiate_config(_config, level)
168 |
169 | elif isinstance(_config, list):
170 | return [create_instance_from_config(i, level + 1) for i in _config]
171 | else:
172 | return _config
173 |
174 |
175 | def instantiate_config(_config: dict, level: int = 0):
176 | return {
177 | k: create_instance_from_config(v, level + 1)
178 | for k, v in _config.items() if k != "_class_name"
179 | }
180 |
181 |
182 | def get_state(key: str):
183 | return global_state[key]
184 |
185 |
186 | global_state = {}
187 |
--------------------------------------------------------------------------------
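The `_class_name` convention above is what the JSON configs of this project rely on: a dict carrying `_class_name` is instantiated after its nested values are resolved, and the special name `get_class` returns the class object instead of an instance. A minimal sketch using only standard-library classes (chosen purely for illustration):

```
import dwm.common

# Any importable dotted path works as _class_name.
config = {
    "_class_name": "datetime.timedelta",
    "days": 1,
    "hours": 2,
}
delta = dwm.common.create_instance_from_config(config)
print(delta)  # 1 day, 2:00:00

# "get_class" returns the class object itself instead of an instance.
cls = dwm.common.create_instance_from_config(
    {"_class_name": "get_class", "class_name": "datetime.timedelta"})
print(cls)  # <class 'datetime.timedelta'>
```
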
/src/dwm/fs/s3fs.py:
--------------------------------------------------------------------------------
1 | import botocore.session
2 | import fsspec
3 | import io
4 | import os
5 | import re
6 |
7 |
8 | class S3File(io.RawIOBase):
9 |
10 | @staticmethod
11 | def find_bucket_key(s3_path):
12 |         """
13 |         This is a helper function that, given an S3 path of the form
14 |         bucket/key, returns the bucket and the key represented by the
15 |         path.
16 |         """
17 |
18 |         bucket_format_list = [
19 |             re.compile(
20 |                 r"^(?P<bucket>arn:(aws).*:s3:[a-z\-0-9]*:[0-9]{12}:accesspoint[:/][^/]+)/?"
21 |                 r"(?P<key>.*)$"
22 |             ),
23 |             re.compile(
24 |                 r"^(?P<bucket>arn:(aws).*:s3-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
25 |                 r"[a-zA-Z0-9\-]{1,63}[/:](bucket|accesspoint)[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
26 |             ),
27 |             re.compile(
28 |                 r"^(?P<bucket>arn:(aws).*:s3-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
29 |                 r"[a-zA-Z0-9\-]{1,63}[/:]bucket[/:]"
30 |                 r"[a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
31 |             ),
32 |             re.compile(
33 |                 r"^(?P<bucket>arn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:"
34 |                 r"accesspoint[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
35 |             ),
36 |         ]
37 | for bucket_format in bucket_format_list:
38 | match = bucket_format.match(s3_path)
39 | if match:
40 | return match.group("bucket"), match.group("key")
41 |
42 | s3_components = s3_path.split("/", 1)
43 | bucket = s3_components[0]
44 | s3_key = ""
45 | if len(s3_components) > 1:
46 | s3_key = s3_components[1]
47 |
48 | return bucket, s3_key
49 |
50 | def __init__(self, client, path):
51 | super().__init__()
52 | self.client = client
53 | self.bucket, self.key = S3File.find_bucket_key(path)
54 | self.p = 0
55 | self.head = client.head_object(Bucket=self.bucket, Key=self.key)
56 |
57 | def readable(self):
58 | return True
59 |
60 | def read(self, size=-1):
61 | read_count = min(size, self.head["ContentLength"] - self.p) \
62 | if size >= 0 else self.head["ContentLength"] - self.p
63 |
64 | if read_count == 0:
65 | return b""
66 |
67 | end = self.p + read_count - 1
68 | response = self.client.get_object(
69 | Bucket=self.bucket, Key=self.key,
70 | Range="bytes={}-{}".format(self.p, end))
71 | data = response["Body"].read()
72 | self.p += read_count
73 | return data
74 |
75 | def readall(self):
76 | return self.read(-1)
77 |
78 | def seek(self, offset, whence=os.SEEK_SET):
79 | if whence == os.SEEK_SET:
80 | self.p = max(0, min(offset, self.head["ContentLength"]))
81 | elif whence == os.SEEK_CUR:
82 | self.p = max(0, min(self.p + offset, self.head["ContentLength"]))
83 | elif whence == os.SEEK_END:
84 | self.p = max(
85 | 0, min(
86 | self.head["ContentLength"] + offset,
87 | self.head["ContentLength"]))
88 |
89 | return self.p
90 |
91 | def seekable(self):
92 | return True
93 |
94 | def tell(self):
95 | return self.p
96 |
97 | def writable(self):
98 | return False
99 |
100 |
101 | class ForkableS3FileSystem(fsspec.AbstractFileSystem):
102 |     """This file system can be used to access files on the S3 service, and
103 |     it is also compatible with the process fork under the multi-worker
104 |     situation of the PyTorch data loader.
105 |
106 |     Args:
107 |         kwargs: The parameters follow the
108 |             [Botocore configuration](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
109 |     """
110 |
111 |     root_marker = ""
112 |     protocol = "s3"
113 |     cachable = False
114 |
115 |
116 | def __init__(self, **kwargs):
117 | # readonly, TODO: support list
118 |
119 | super().__init__()
120 | self.kwargs = kwargs
121 | self.client = botocore.session.get_session()\
122 | .create_client("s3", **kwargs)
123 |
124 | def reinit_if_forked(self):
125 | current_pid = os.getpid()
126 | if self._pid != current_pid:
127 | self.client = botocore.session.get_session()\
128 | .create_client("s3", **self.kwargs)
129 | self._pid = current_pid
130 |
131 | def _open(
132 | self, path, mode="rb", block_size=None, autocommit=True,
133 | cache_options=None, **kwargs,
134 | ):
135 | self.reinit_if_forked()
136 | return S3File(self.client, path)
137 |
138 | def ls(self, path, detail=True, **kwargs):
139 | self.reinit_if_forked()
140 | bucket, key = S3File.find_bucket_key(path)
141 | if len(key) > 0 and not key.endswith("/"):
142 | key = key + "/"
143 |
144 | # NOTE: only files are listed
145 | paths = []
146 | continuation_token = None
147 | while True:
148 | if continuation_token is None:
149 | response = self.client.list_objects(
150 | Bucket=bucket, Delimiter="/", Prefix=key)
151 | else:
152 | response = self.client.list_objects(
153 | Bucket=bucket, Delimiter="/", Prefix=key,
154 | Marker=continuation_token)
155 |
156 | if "Contents" in response:
157 | for i in response["Contents"]:
158 | paths.append({
159 | "name": "{}/{}".format(bucket, i["Key"]),
160 | "size": i["Size"],
161 | "type": "file",
162 | "Owner": i["Owner"],
163 | "ETag": i["ETag"],
164 | "LastModified": i["LastModified"]
165 | })
166 |
167 | if response["IsTruncated"]:
168 | continuation_token = response["NextMarker"]
169 | else:
170 | break
171 |
172 | if detail:
173 | return sorted(paths, key=lambda i: i["name"])
174 | else:
175 | return sorted([i["name"] for i in paths])
176 |
--------------------------------------------------------------------------------
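A rough usage sketch of the forkable S3 file system; the endpoint, credentials, and bucket below are placeholders, and the constructor keyword arguments are forwarded unchanged to `botocore`'s `create_client`:

```
from dwm.fs.s3fs import ForkableS3FileSystem

# All values below are placeholders for a real S3-compatible service.
fs = ForkableS3FileSystem(
    endpoint_url="http://s3.example.com",
    aws_access_key_id="YOUR_KEY",
    aws_secret_access_key="YOUR_SECRET")

# Only files are listed, and reads are issued as range requests on demand.
print(fs.ls("my-bucket/nuscenes/", detail=False))
with fs.open("my-bucket/nuscenes/v1.0-trainval_meta.zip") as f:
    header = f.read(4)
```
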
/src/dwm/fs/dirfs.py:
--------------------------------------------------------------------------------
1 | import fsspec
2 | import fsspec.implementations.local
3 |
4 |
5 | class DirFileSystem(fsspec.AbstractFileSystem):
6 | """Directory prefix filesystem
7 |
8 | The DirFileSystem is a filesystem-wrapper. It assumes every path it is
9 | dealing with is relative to the `path`. After performing the necessary
10 | paths operation it delegates everything to the wrapped filesystem.
11 | """
12 |
13 | protocol = "dir"
14 |
15 | def __init__(self, path=None, fs=None, **kwargs):
16 | super().__init__(**kwargs)
17 | self.path = path
18 | self.fs = fsspec.implementations.local.LocalFileSystem() \
19 | if fs is None else fs
20 |
21 | @property
22 | def sep(self):
23 | return self.fs.sep
24 |
25 | def _join(self, path):
26 | if isinstance(path, str):
27 | if not self.path:
28 | return path
29 |
30 | if not path:
31 | return self.path
32 |
33 | return self.fs.sep.join((self.path, self._strip_protocol(path)))
34 |
35 | return [self._join(_path) for _path in path]
36 |
37 | def _relpath(self, path):
38 | if isinstance(path, str):
39 | if not self.path:
40 | return path
41 |
42 | if path == self.path:
43 | return ""
44 |
45 | prefix = self.path + self.fs.sep
46 | assert path.startswith(prefix)
47 | return path[len(prefix):]
48 |
49 | return [self._relpath(_path) for _path in path]
50 |
51 | def rm_file(self, path, **kwargs):
52 | return self.fs.rm_file(self._join(path), **kwargs)
53 |
54 | def rm(self, path, *args, **kwargs):
55 | return self.fs.rm(self._join(path), *args, **kwargs)
56 |
57 | def cp_file(self, path1, path2, **kwargs):
58 | return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs)
59 |
60 | def copy(self, path1, path2, *args, **kwargs):
61 | return self.fs.copy(
62 | self._join(path1), self._join(path2), *args, **kwargs)
63 |
64 | def pipe(self, path, *args, **kwargs):
65 | return self.fs.pipe(self._join(path), *args, **kwargs)
66 |
67 | def pipe_file(self, path, *args, **kwargs):
68 | return self.fs.pipe_file(self._join(path), *args, **kwargs)
69 |
70 | def cat_file(self, path, *args, **kwargs):
71 | return self.fs.cat_file(self._join(path), *args, **kwargs)
72 |
73 | def cat(self, path, *args, **kwargs):
74 | ret = self.fs.cat(self._join(path), *args, **kwargs)
75 |
76 | if isinstance(ret, dict):
77 | return {self._relpath(key): value for key, value in ret.items()}
78 |
79 | return ret
80 |
81 | def put_file(self, lpath, rpath, **kwargs):
82 | return self.fs.put_file(lpath, self._join(rpath), **kwargs)
83 |
84 | def put(self, lpath, rpath, *args, **kwargs):
85 | return self.fs.put(lpath, self._join(rpath), *args, **kwargs)
86 |
87 | def get_file(self, rpath, lpath, **kwargs):
88 | return self.fs.get_file(self._join(rpath), lpath, **kwargs)
89 |
90 | def get(self, rpath, *args, **kwargs):
91 | return self.fs.get(self._join(rpath), *args, **kwargs)
92 |
93 | def isfile(self, path):
94 | return self.fs.isfile(self._join(path))
95 |
96 | def isdir(self, path):
97 | return self.fs.isdir(self._join(path))
98 |
99 | def size(self, path):
100 | return self.fs.size(self._join(path))
101 |
102 | def exists(self, path):
103 | return self.fs.exists(self._join(path))
104 |
105 | def info(self, path, **kwargs):
106 | return self.fs.info(self._join(path), **kwargs)
107 |
108 | def ls(self, path, detail=True, **kwargs):
109 | ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy()
110 | if detail:
111 | out = []
112 | for entry in ret:
113 | entry = entry.copy()
114 | entry["name"] = self._relpath(entry["name"])
115 | out.append(entry)
116 |
117 | return out
118 |
119 | return self._relpath(ret)
120 |
121 | def walk(self, path, *args, **kwargs):
122 | for i in self.fs.walk(self._join(path), *args, **kwargs):
123 | root, dirs, files = i
124 | yield self._relpath(root), dirs, files
125 |
126 | def glob(self, path, **kwargs):
127 | detail = kwargs.get("detail", False)
128 | ret = self.fs.glob(self._join(path), **kwargs)
129 | if detail:
130 | return {self._relpath(path): info for path, info in ret.items()}
131 | return self._relpath(ret)
132 |
133 | def du(self, path, *args, **kwargs):
134 | total = kwargs.get("total", True)
135 | ret = self.fs.du(self._join(path), *args, **kwargs)
136 | if total:
137 | return ret
138 |
139 | return {self._relpath(path): size for path, size in ret.items()}
140 |
141 | def find(self, path, *args, **kwargs):
142 | detail = kwargs.get("detail", False)
143 | ret = self.fs.find(self._join(path), *args, **kwargs)
144 | if detail:
145 | return {self._relpath(path): info for path, info in ret.items()}
146 | return self._relpath(ret)
147 |
148 | def expand_path(self, path, *args, **kwargs):
149 | return self._relpath(
150 | self.fs.expand_path(self._join(path), *args, **kwargs))
151 |
152 | def mkdir(self, path, *args, **kwargs):
153 | return self.fs.mkdir(self._join(path), *args, **kwargs)
154 |
155 | def makedirs(self, path, *args, **kwargs):
156 | return self.fs.makedirs(self._join(path), *args, **kwargs)
157 |
158 | def rmdir(self, path):
159 | return self.fs.rmdir(self._join(path))
160 |
161 | def mv(self, path1, path2, **kwargs):
162 | return self.fs.mv(self._join(path1), self._join(path2), **kwargs)
163 |
164 | def touch(self, path, **kwargs):
165 | return self.fs.touch(self._join(path), **kwargs)
166 |
167 | def created(self, path):
168 | return self.fs.created(self._join(path))
169 |
170 | def modified(self, path):
171 | return self.fs.modified(self._join(path))
172 |
173 | def sign(self, path, *args, **kwargs):
174 | return self.fs.sign(self._join(path), *args, **kwargs)
175 |
176 | def __repr__(self):
177 | return "{}(path='{}', fs={})".format(
178 | self.__class__.__qualname__, self.path, self.fs)
179 |
180 | def open(self, path, *args, **kwargs):
181 | return self.fs.open(self._join(path), *args, **kwargs)
182 |
--------------------------------------------------------------------------------
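A short sketch of the wrapper behavior with a hypothetical local directory; by default it delegates to the local filesystem, but any fsspec filesystem can be passed via `fs=`:

```
from dwm.fs.dirfs import DirFileSystem

# The prefix below is hypothetical.
fs = DirFileSystem(path="/data/nuscenes")

# Relative paths are joined with the prefix before delegation, and listed
# names are returned relative to the prefix again.
print(fs.ls("samples", detail=False))
with fs.open("v1.0-trainval/category.json") as f:
    data = f.read()
```
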
/src/dwm/functional.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def create_frustum(
5 | frustum_depth_range: list, frustum_height: int, frustum_width: int,
6 | device=None
7 | ):
8 | """Create the lifted frustum from the camera view.
9 |
10 | Args:
11 | frustum_depth_range (list): 3 float numbers (start, stop, step) of the
12 |             depth range. The stop is exclusive, the same as in the common
13 |             definition of range functions.
14 | frustum_height: The frustum height.
15 | frustum_width: The frustum width.
16 |         device: The preferred device of the returned frustum tensor.
17 |
18 | Returns:
19 |         The lifted frustum tensor in the shape of [4, D, frustum_height,
20 |         frustum_width], where D = (stop - start) / step of the
21 |         frustum_depth_range. The 4 items in the first dimension are
22 |         (nx * z, ny * z, z, 1), where nx and ny are normalized coordinates
23 |         in the range of -1 ~ 1, and z is in the camera coordinate system.
24 | """
25 |
26 | depth_channels = int(
27 | (frustum_depth_range[1] - frustum_depth_range[0]) /
28 | frustum_depth_range[2])
29 | x = torch.arange(
30 | 1 / frustum_width - 1, 1, 2 / frustum_width, device=device)
31 | y = torch.arange(
32 | 1 / frustum_height - 1, 1, 2 / frustum_height, device=device)
33 | z = torch.arange(*frustum_depth_range, device=device)
34 | frustum = torch.stack([
35 | x.unsqueeze(0).unsqueeze(0).repeat(depth_channels, frustum_height, 1),
36 | y.unsqueeze(0).unsqueeze(-1).repeat(depth_channels, 1, frustum_width),
37 | z.unsqueeze(-1).unsqueeze(-1).repeat(1, frustum_height, frustum_width),
38 | torch.ones(
39 | (depth_channels, frustum_height, frustum_width), device=device)
40 | ])
41 | frustum[:2] *= frustum[2:3]
42 | return frustum
43 |
44 |
45 | def make_homogeneous_matrix(a: torch.Tensor):
46 | right = torch.cat([
47 | torch.zeros(
48 | tuple(a.shape[:-1]) + (1,), dtype=a.dtype, device=a.device),
49 | torch.ones(
50 | tuple(a.shape[:-2]) + (1, 1), dtype=a.dtype, device=a.device)
51 | ], -2)
52 | return torch.cat([torch.nn.functional.pad(a, (0, 0, 0, 1)), right], -1)
53 |
54 |
55 | def make_homogeneous_vector(a: torch.Tensor):
56 | return torch.cat([
57 | a,
58 | torch.ones(tuple(a.shape[:-1]) + (1,), dtype=a.dtype, device=a.device)
59 | ], -1)
60 |
61 |
62 | def make_transform_to_frustum(
63 | image_size: torch.Tensor, frustum_width: int, frustum_height: int
64 | ):
65 | """Make the 4x4 transform from the image coordinates to the frustum
66 | coordinates.
67 |
68 | Args:
69 | image_size (torch.Tensor): The image size tensor in the shape of
70 |             [..., 2], and the 2 numbers of the last dimension are
71 | (width, height).
72 | frustum_width (int): The width of the frustum.
73 | frustum_height (int): The height of the frustum.
74 |
75 | Returns:
76 | The transform matrix in the shape of [..., 4, 4].
77 | """
78 |
79 | base_shape = list(image_size.shape[:-1])
80 | frustum_size = torch.tensor(
81 | [frustum_width, frustum_height], dtype=image_size.dtype,
82 | device=image_size.device).view(*([1 for _ in base_shape] + [-1]))
83 | scale = frustum_size / image_size
84 | zeros = torch.zeros(
85 | base_shape + [4], dtype=scale.dtype, device=scale.device)
86 | ones = torch.ones(
87 | base_shape + [1], dtype=scale.dtype, device=scale.device)
88 | return torch.cat([
89 | scale[..., 0:1], zeros, scale[..., 1:2], zeros, ones, zeros, ones
90 | ], -1).unflatten(-1, (4, 4))
91 |
92 |
93 | def normalize_intrinsic_transform(
94 | image_sizes: torch.Tensor, instrinsics: torch.Tensor
95 | ):
96 | """Make the normalized 3x3 intrinsic transform from the camera coordinates
97 | to the normalized coordinates (-1 ~ 1).
98 |
99 | Args:
100 | image_sizes (torch.Tensor): The image size tensor in the shape of
101 |             [..., 2], and the 2 numbers of the last dimension are
102 | (width, height).
103 | instrinsics (torch.Tensor): The camera intrinsic transform from the
104 | camera coordinates (X-right, Y-down, Z-forward) to the image
105 | coordinates (X-right, Y-down).
106 |
107 | Returns:
108 | The transform matrix in the shape of [..., 3, 3].
109 | """
110 | base_shape = list(image_sizes.shape[:-1])
111 | scale = 2 / image_sizes
112 | translate = 1 / image_sizes - 1
113 | zeros = torch.zeros(
114 | base_shape + [2], dtype=scale.dtype, device=scale.device)
115 | ones = torch.ones(
116 | base_shape + [1], dtype=scale.dtype, device=scale.device)
117 | normalization_transform = torch.cat([
118 | scale[..., 0:1], zeros[..., :1], translate[..., 0:1], zeros[..., :1],
119 | scale[..., 1:2], translate[..., 1:2], zeros, ones
120 | ], -1).unflatten(-1, (3, 3))
121 | return normalization_transform @ instrinsics
122 |
123 |
124 | def grid_sample_sequence(
125 | input: torch.Tensor, sequence: torch.Tensor, bundle_size: int = 128,
126 | mode: str = "bilinear", padding_mode: str = "border"
127 | ):
128 | count = sequence.shape[-2]
129 | bundle_count = (count + bundle_size - 1) // bundle_size
130 | assert bundle_count * bundle_size == count
131 |
132 | samples = torch.nn.functional.grid_sample(
133 | input.flatten(0, -4),
134 | sequence.unflatten(-2, (bundle_count, bundle_size)), mode=mode,
135 | padding_mode=padding_mode, align_corners=False)
136 | return samples.unflatten(0, input.shape[:-3]).flatten(-2)
137 |
138 |
139 | def _sample_logistic(
140 | shape: torch.Size, out: torch.Tensor = None,
141 | generator: torch.Generator = None
142 | ):
143 | x = torch.rand(shape, generator=generator) if out is None else \
144 | out.resize_(shape).uniform_(generator=generator)
145 | return torch.log(x) - torch.log(1 - x)
146 |
147 |
148 | def _sigmoid_sample(
149 | logits: torch.Tensor, tau: float = 1.0, generator: torch.Generator = None
150 | ):
151 | # Refer to Bernouilli reparametrization based on Maddison et al. 2017
152 | noise = _sample_logistic(logits.size(), None, generator)
153 | y = logits + noise
154 | return torch.sigmoid(y / tau)
155 |
156 |
157 | def gumbel_sigmoid(
158 | logits, tau: float = 1.0, hard: bool = False,
159 | generator: torch.Generator = None
160 | ):
161 | # use CPU random generator
162 | y_soft = _sigmoid_sample(logits.cpu(), tau, generator).to(logits.device)
163 | if hard:
164 | y_hard = torch.where(
165 | y_soft > 0.5, torch.ones_like(y_soft), torch.zeros_like(y_soft))
166 | return y_hard.data - y_soft.data + y_soft
167 |
168 | else:
169 | return y_soft
170 |
171 |
172 | def take_sequence_clip(item, start: int, stop: int):
173 | if isinstance(item, (int, float, bool, str)):
174 | return item
175 | elif isinstance(item, torch.Tensor):
176 | return item if len(item.shape) <= 1 else item[:, start:stop]
177 | elif isinstance(item, list):
178 | assert len(item) > 0 and all([isinstance(i, list) for i in item])
179 | return [i[start:stop] for i in item]
180 | else:
181 | raise Exception("Unsupported type to take sequence clip.")
182 |
183 |
184 | def memory_efficient_split_call(
185 | block: torch.nn.Module, tensor: torch.Tensor, func, split_size: int
186 | ):
187 | if split_size == -1:
188 | return func(block, tensor)
189 | else:
190 | return torch.cat([
191 | func(block, i)
192 | for i in tensor.split(split_size)
193 | ])
194 |
--------------------------------------------------------------------------------
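A small sketch of the frustum helper, checking the documented output shape; the depth range and grid resolution are arbitrary example values:

```
import torch
from dwm.functional import create_frustum, make_homogeneous_vector

# Arbitrary example: depth from 1 to 61 with step 1 on a 32 x 88 grid,
# so D = (61 - 1) / 1 = 60.
frustum = create_frustum([1.0, 61.0, 1.0], 32, 88)
print(frustum.shape)  # torch.Size([4, 60, 32, 88])

# make_homogeneous_vector appends a constant 1 to the last dimension.
points = torch.rand(5, 3)
print(make_homogeneous_vector(points).shape)  # torch.Size([5, 4])
```
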
/src/dwm/datasets/waymo_common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_inclination(inclination_range, height, dtype=np.float32):
5 | diff = inclination_range[1] - inclination_range[0]
6 | normal_range = (np.arange(0, height, dtype=dtype) + 0.5) / height
7 | inclination = normal_range * diff + inclination_range[0]
8 | return inclination
9 |
10 |
11 | def compute_range_image_polar(
12 | range_image: np.array, extrinsic: np.array, inclination: np.array,
13 | dtype=np.float32
14 | ):
15 | _, width = range_image.shape
16 | az_correction = np.arctan2(extrinsic[1, 0], extrinsic[0, 0])
17 | ratios = (np.arange(width, 0, -1, dtype=dtype) - 0.5) / width
18 | azimuth = (ratios * 2 - 1) * np.pi - az_correction
19 | return azimuth[None, :], inclination[:, None], range_image
20 |
21 |
22 | def get_rotation_matrix(
23 | roll: np.array, pitch: np.array, yaw: np.array,
24 | ):
25 | """Gets a rotation matrix given roll, pitch, yaw.
26 |
27 | roll-pitch-yaw is z-y'-x'' intrinsic rotation which means we need to apply
28 | x(roll) rotation first, then y(pitch) rotation, then z(yaw) rotation.
29 |
30 | https://en.wikipedia.org/wiki/Euler_angles
31 | http://planning.cs.uiuc.edu/node102.html
32 |
33 | Args:
34 | roll : x-rotation in radians.
35 | pitch: y-rotation in radians. The shape must be the same as roll.
36 | yaw: z-rotation in radians. The shape must be the same as roll.
37 |
38 | Returns:
39 | A rotation tensor with the same data type of the input. Its shape is
40 | [3 ,3].
41 | """
42 |
43 | cos_roll = np.cos(roll)
44 | sin_roll = np.sin(roll)
45 | cos_yaw = np.cos(yaw)
46 | sin_yaw = np.sin(yaw)
47 | cos_pitch = np.cos(pitch)
48 | sin_pitch = np.sin(pitch)
49 | ones = np.ones_like(cos_yaw)
50 | zeros = np.zeros_like(cos_yaw)
51 |
52 | r_roll = np.stack([
53 | np.stack([ones, zeros, zeros], axis=-1),
54 | np.stack([zeros, cos_roll, -sin_roll], axis=-1),
55 | np.stack([zeros, sin_roll, cos_roll], axis=-1),
56 | ], axis=-2)
57 | r_pitch = np.stack([
58 | np.stack([cos_pitch, zeros, sin_pitch], axis=-1),
59 | np.stack([zeros, ones, zeros], axis=-1),
60 | np.stack([-sin_pitch, zeros, cos_pitch], axis=-1),
61 | ], axis=-2)
62 | r_yaw = np.stack([
63 | np.stack([cos_yaw, -sin_yaw, zeros], axis=-1),
64 | np.stack([sin_yaw, cos_yaw, zeros], axis=-1),
65 | np.stack([zeros, zeros, ones], axis=-1),
66 | ], axis=-2)
67 |
68 | return np.matmul(r_yaw, np.matmul(r_pitch, r_roll))
69 |
70 |
71 | def compute_range_image_cartesian(
72 | azimuth: np.array, inclination: np.array, range_image_range: np.array,
73 | extrinsic: np.array, pixel_pose=None, frame_pose=None,
74 | ):
75 | """Computes range image cartesian coordinates from polar ones.
76 |
77 |     Args:
78 |         azimuth: [1, W] float tensor. The azimuth angle of each range image
79 |             column in radians.
80 |         inclination: [H, 1] float tensor. The inclination angle of each
81 |             range image row in radians.
82 |         range_image_range: [H, W] float tensor. The range of each pixel.
83 |         extrinsic: [4, 4] float tensor. The LiDAR extrinsic.
84 |         pixel_pose: [H, W, 6] float tensor of per-pixel
85 |             (roll, pitch, yaw, x, y, z) poses, or None.
86 |         frame_pose: [4, 4] float tensor. This must be set when pixel_pose
87 |             is set. It decides the vehicle frame at which the cartesian
88 |             points are computed.
89 |     Returns:
90 |         range_image_points: [H, W, 3] cartesian coordinates.
91 | """
92 |
93 | cos_azimuth = np.cos(azimuth)
94 | sin_azimuth = np.sin(azimuth)
95 | cos_incl = np.cos(inclination)
96 | sin_incl = np.sin(inclination)
97 |
98 | # [H, W].
99 | x = cos_azimuth * cos_incl * range_image_range
100 | y = sin_azimuth * cos_incl * range_image_range
101 | z = sin_incl * range_image_range
102 |
103 | # [H, W, 3]
104 | range_image_points = np.stack([x, y, z], -1)
105 |
106 | # To vehicle frame.
107 | rotation = extrinsic[0:3, 0:3].T
108 | translation = extrinsic[0:3, 3:4].T
109 | range_image_points = range_image_points @ rotation + translation
110 |
111 | if pixel_pose is not None:
112 | # To global frame.
113 | # [H, W, 3, 3]
114 | pixel_pose_rotation = np.swapaxes(
115 | get_rotation_matrix(
116 | pixel_pose[..., 0], pixel_pose[..., 1], pixel_pose[..., 2]),
117 | -1, -2)
118 |
119 | # [H, W, 1, 3]
120 | pixel_pose_translation = pixel_pose[..., None, 3:]
121 | range_image_points = (
122 | range_image_points[..., None, :] @ pixel_pose_rotation +
123 | pixel_pose_translation)
124 |
125 | # [H, W, 3]
126 | range_image_points = range_image_points.reshape(
127 | *range_image_points.shape[:-2], -1)
128 |
129 | if frame_pose is None:
130 | raise ValueError('frame_pose must be set when pixel_pose is set.')
131 |
132 | # To vehicle frame corresponding to the given frame_pose
133 | # [4, 4]
134 | world_to_vehicle = np.linalg.inv(frame_pose)
135 | world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3].T
136 | world_to_vehicle_translation = world_to_vehicle[0:3, 3:4].T
137 |
138 | # [H, W, 3]
139 | range_image_points = range_image_points @ world_to_vehicle_rotation + \
140 | world_to_vehicle_translation
141 |
142 | return range_image_points
143 |
144 |
145 | def convert_range_image_to_cartesian(
146 | range_image: np.array, calibration: dict, lidar_pose=None, frame_pose=None,
147 | dtype=np.float32
148 | ):
149 | """Converts one range image from polar coordinates to Cartesian coordinates.
150 |
151 | Args:
152 | range_image: One range image return captured by a LiDAR sensor.
153 | calibration: Parameters for calibration of a LiDAR sensor.
154 |
155 | Returns:
156 | A [H, W, 3] image in Cartesian coordinates.
157 | """
158 |
159 | extrinsic = np.array(
160 | calibration["[LiDARCalibrationComponent].extrinsic.transform"],
161 | dtype).reshape(4, 4)
162 |
163 | # Compute inclinations mapping range image rows to circles in the 3D worlds.
164 | bi_values_key = "[LiDARCalibrationComponent].beam_inclination.values"
165 | bi_min_key = "[LiDARCalibrationComponent].beam_inclination.min"
166 | bi_max_key = "[LiDARCalibrationComponent].beam_inclination.max"
167 | if calibration[bi_values_key] is not None:
168 | inclination = np.array(calibration[bi_values_key], dtype)
169 | else:
170 | inclination = compute_inclination(
171 | (calibration[bi_min_key], calibration[bi_max_key]),
172 | range_image.shape[0], dtype)
173 |
174 | inclination = np.flip(inclination, axis=-1)
175 |
176 | # Compute points from the range image
177 | azimuth, inclination, range_image_range = compute_range_image_polar(
178 | range_image[..., 0], extrinsic, inclination, dtype)
179 | range_image_cartesian = compute_range_image_cartesian(
180 | azimuth, inclination, range_image_range, extrinsic,
181 | pixel_pose=lidar_pose, frame_pose=frame_pose)
182 | range_image_mask = range_image[..., 0] > 0
183 |
184 | flatten_range_image_cartesian = range_image_cartesian.reshape(-1, 3)
185 | flatten_range_image_mask = range_image_mask.reshape(-1)
186 | return flatten_range_image_cartesian[flatten_range_image_mask]
187 |
188 |
189 | def laser_calibration_to_dict(lc):
190 | lc.beam_inclination_max
191 | return {
192 | "key.laser_name": lc.name,
193 | "[LiDARCalibrationComponent].beam_inclination.values":
194 | lc.beam_inclinations,
195 | "[LiDARCalibrationComponent].beam_inclination.min":
196 | lc.beam_inclination_min,
197 | "[LiDARCalibrationComponent].beam_inclination.max":
198 | lc.beam_inclination_max,
199 | "[LiDARCalibrationComponent].extrinsic.transform":
200 | lc.extrinsic.transform
201 | }
202 |
--------------------------------------------------------------------------------
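A rough usage sketch with synthetic inputs; the calibration values are made up and only illustrate the expected dictionary keys and array shapes:

```
import numpy as np
from dwm.datasets.waymo_common import convert_range_image_to_cartesian

# Synthetic 64 x 2650 range image with the range value in channel 0.
range_image = np.random.rand(64, 2650, 4).astype(np.float32)

# Made-up calibration: identity extrinsic and a min/max beam inclination
# pair instead of per-beam values.
calibration = {
    "[LiDARCalibrationComponent].extrinsic.transform":
        np.eye(4, dtype=np.float32).reshape(-1).tolist(),
    "[LiDARCalibrationComponent].beam_inclination.values": None,
    "[LiDARCalibrationComponent].beam_inclination.min": -0.31,
    "[LiDARCalibrationComponent].beam_inclination.max": 0.04,
}

points = convert_range_image_to_cartesian(range_image, calibration)
print(points.shape)  # (N, 3), one row per pixel with a positive range
```
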
/src/dwm/models/maskgit_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from diffusers.models.attention import maybe_allow_in_graph
5 | from diffusers.models.attention import _chunked_feed_forward
6 | from diffusers.models.attention import Attention
7 | from diffusers.models.attention import FeedForward
8 | from diffusers.models.normalization import AdaLayerNormZero
9 | from diffusers.models.normalization import AdaLayerNormContinuous
10 | from diffusers.models.normalization import SD35AdaLayerNormZeroX
11 | from typing import Optional
12 | from einops import rearrange
13 |
14 |
15 | @maybe_allow_in_graph
16 | class TemporalTransformerBlock(nn.Module):
17 | r"""
18 |     A temporal Transformer block adapted from the MMDiT architecture introduced in Stable Diffusion 3.
19 |
20 | Reference: https://arxiv.org/abs/2403.03206
21 |
22 | Parameters:
23 | dim (`int`): The number of channels in the input and output.
24 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
25 | attention_head_dim (`int`): The number of channels in each head.
26 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
27 | processing of `context` conditions.
28 | """
29 |
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_attention_heads: int,
34 | attention_head_dim: int,
35 | context_pre_only: bool = False,
36 | qk_norm: Optional[str] = None,
37 | ):
38 | super().__init__()
39 |
40 | self.context_pre_only = context_pre_only
41 | self.norm1 = nn.LayerNorm(dim, eps=1e-6)
42 |
43 | if hasattr(F, "scaled_dot_product_attention"):
44 | processor = TemporalAttnProcessor()
45 | else:
46 | raise ValueError(
47 | "The current PyTorch version does not support the `scaled_dot_product_attention` function."
48 | )
49 | self.attn = Attention(
50 | query_dim=dim,
51 | cross_attention_dim=None,
52 | added_kv_proj_dim=None,
53 | dim_head=attention_head_dim,
54 | heads=num_attention_heads,
55 | out_dim=dim,
56 | context_pre_only=None,
57 | bias=True,
58 | processor=processor,
59 | qk_norm=qk_norm,
60 | eps=1e-6,
61 | )
62 |
63 | self.norm2 = nn.LayerNorm(dim, eps=1e-6)
64 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
65 |
66 | # let chunk size default to None
67 | self._chunk_size = None
68 | self._chunk_dim = 0
69 |
70 | # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
71 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
72 | # Sets chunk feed-forward
73 | self._chunk_size = chunk_size
74 | self._chunk_dim = dim
75 |
76 | def forward(
77 | self, hidden_states: torch.FloatTensor, rotary_emb: Optional[torch.nn.Module] = None
78 | ):
79 | norm_hidden_states = self.norm1(hidden_states)
80 | # Attention.
81 | attn_output = self.attn(
82 | hidden_states=norm_hidden_states,
83 | rotary_emb=rotary_emb
84 | )
85 |
86 | # Process attention outputs for the `hidden_states`.
87 | hidden_states = hidden_states + attn_output
88 |
89 |
90 | norm_hidden_states = self.norm2(hidden_states)
91 | if self._chunk_size is not None:
92 | # "feed_forward_chunk_size" can be used to save memory
93 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
94 | else:
95 | ff_output = self.ff(norm_hidden_states)
96 |
97 | hidden_states = hidden_states + ff_output
98 |
99 | return hidden_states
100 |
101 |
102 | class TemporalAttnProcessor:
103 |     """Attention processor typically used for the SD3-like self-attention projections."""
104 |
105 | def __init__(self):
106 | if not hasattr(F, "scaled_dot_product_attention"):
107 | raise ImportError("TemporalAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
108 |
109 | def __call__(
110 | self,
111 | attn: Attention,
112 | hidden_states: torch.FloatTensor,
113 | encoder_hidden_states: torch.FloatTensor = None,
114 | attention_mask: Optional[torch.FloatTensor] = None,
115 | rotary_emb: Optional[torch.nn.Module] = None,
116 | *args,
117 | **kwargs,
118 | ) -> torch.FloatTensor:
119 | residual = hidden_states
120 |
121 | batch_size = hidden_states.shape[0]
122 |
123 | # `sample` projections.
124 | query = attn.to_q(hidden_states)
125 | key = attn.to_k(hidden_states)
126 | value = attn.to_v(hidden_states)
127 |
128 | inner_dim = key.shape[-1]
129 | head_dim = inner_dim // attn.heads
130 |
131 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
132 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
133 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
134 |
135 | if rotary_emb is not None:
136 | query = rotary_emb(query)
137 | key = rotary_emb(key)
138 |
139 | if attn.norm_q is not None:
140 | query = attn.norm_q(query)
141 | if attn.norm_k is not None:
142 | key = attn.norm_k(key)
143 |
144 | # `context` projections.
145 | if encoder_hidden_states is not None:
146 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
147 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
148 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
149 |
150 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
151 | batch_size, -1, attn.heads, head_dim
152 | ).transpose(1, 2)
153 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
154 | batch_size, -1, attn.heads, head_dim
155 | ).transpose(1, 2)
156 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
157 | batch_size, -1, attn.heads, head_dim
158 | ).transpose(1, 2)
159 |
160 | # if attn.norm_added_q is not None:
161 | # encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
162 | # if attn.norm_added_k is not None:
163 | # encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
164 |
165 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
166 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
167 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
168 |
169 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
170 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
171 | hidden_states = hidden_states.to(query.dtype)
172 |
173 | if encoder_hidden_states is not None:
174 | # Split the attention outputs.
175 | hidden_states, encoder_hidden_states = (
176 | hidden_states[:, : residual.shape[1]],
177 | hidden_states[:, residual.shape[1] :],
178 | )
179 | # if not attn.context_pre_only:
180 | # encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
181 |
182 | # linear proj
183 | hidden_states = attn.to_out[0](hidden_states)
184 | # dropout
185 | hidden_states = attn.to_out[1](hidden_states)
186 |
187 | if encoder_hidden_states is not None:
188 | return hidden_states, encoder_hidden_states
189 | else:
190 | return hidden_states
191 |
192 |
--------------------------------------------------------------------------------
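A small sketch instantiating the temporal block on random features; the dimensions are arbitrary, no rotary embedding is passed, and a diffusers version supporting the `qk_norm` and `context_pre_only` arguments used above is assumed:

```
import torch
from dwm.models.maskgit_base import TemporalTransformerBlock

# Arbitrary example dimensions: 8 heads of 64 channels each.
block = TemporalTransformerBlock(
    dim=512, num_attention_heads=8, attention_head_dim=64)

# Temporal tokens flattened to [batch, sequence, channels].
hidden_states = torch.randn(2, 16, 512)
out = block(hidden_states)
print(out.shape)  # torch.Size([2, 16, 512])
```
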
/configs/experimental/simulation/make_carla_cameras_from_nuscenes.json:
--------------------------------------------------------------------------------
1 | {
2 | "CAM_FRONT_LEFT": {
3 | "image_size": [
4 | 1600,
5 | 900
6 | ],
7 | "intrinsic": [
8 | [
9 | 1272.5979470598488,
10 | 0.0,
11 | 826.6154927353808
12 | ],
13 | [
14 | 0.0,
15 | 1272.5979470598488,
16 | 479.75165386361925
17 | ],
18 | [
19 | 0.0,
20 | 0.0,
21 | 1.0
22 | ]
23 | ],
24 | "transform": [
25 | [
26 | 0.8207583481515547,
27 | -0.00034143667149916235,
28 | 0.5712754303840931,
29 | 1.52387798135
30 | ],
31 | [
32 | -0.5712716007618155,
33 | 0.0032195018514857843,
34 | 0.8207547703003994,
35 | 0.494631336551
36 | ],
37 | [
38 | -0.002119458082718406,
39 | -0.9999947591006804,
40 | 0.00244737994761024,
41 | 1.50932822144
42 | ],
43 | [
44 | 0.0,
45 | 0.0,
46 | 0.0,
47 | 1.0
48 | ]
49 | ]
50 | },
51 | "CAM_FRONT": {
52 | "image_size": [
53 | 1600,
54 | 900
55 | ],
56 | "intrinsic": [
57 | [
58 | 1252.8131021185304,
59 | 0.0,
60 | 826.588114781398
61 | ],
62 | [
63 | 0.0,
64 | 1252.8131021185304,
65 | 469.9846626224581
66 | ],
67 | [
68 | 0.0,
69 | 0.0,
70 | 1.0
71 | ]
72 | ],
73 | "transform": [
74 | [
75 | 0.01026020777540071,
76 | 0.008433448071667293,
77 | 0.9999117986552757,
78 | 1.72200568478
79 | ],
80 | [
81 | -0.9998725753702897,
82 | 0.012316255772487295,
83 | 0.01015592763520401,
84 | 0.00475453292289
85 | ],
86 | [
87 | -0.012229519973835201,
88 | -0.9998885871922779,
89 | 0.008558740785910124,
90 | 1.49491291905
91 | ],
92 | [
93 | 0.0,
94 | 0.0,
95 | 0.0,
96 | 1.0
97 | ]
98 | ]
99 | },
100 | "CAM_FRONT_RIGHT": {
101 | "image_size": [
102 | 1600,
103 | 900
104 | ],
105 | "intrinsic": [
106 | [
107 | 1256.7485116440405,
108 | 0.0,
109 | 817.7887570959712
110 | ],
111 | [
112 | 0.0,
113 | 1256.7485116440403,
114 | 451.9541780095127
115 | ],
116 | [
117 | 0.0,
118 | 0.0,
119 | 1.0
120 | ]
121 | ],
122 | "transform": [
123 | [
124 | -0.8439797263539992,
125 | 0.01645551436192705,
126 | 0.5361225956350965,
127 | 1.58082565783
128 | ],
129 | [
130 | -0.5361413772261798,
131 | 0.003621074712163108,
132 | -0.8441204365752225,
133 | -0.499078711449
134 | ],
135 | [
136 | -0.015831775940933213,
137 | -0.9998580418564494,
138 | 0.005766368488289819,
139 | 1.51749368405
140 | ],
141 | [
142 | 0.0,
143 | 0.0,
144 | 0.0,
145 | 1.0
146 | ]
147 | ]
148 | },
149 | "CAM_BACK_RIGHT": {
150 | "image_size": [
151 | 1600,
152 | 900
153 | ],
154 | "intrinsic": [
155 | [
156 | 1259.5137405846733,
157 | 0.0,
158 | 807.2529053838625
159 | ],
160 | [
161 | 0.0,
162 | 1259.5137405846733,
163 | 501.19579884916527
164 | ],
165 | [
166 | 0.0,
167 | 0.0,
168 | 1.0
169 | ]
170 | ],
171 | "transform": [
172 | [
173 | -0.9347755391782977,
174 | 0.01587583795544978,
175 | -0.3548839938953783,
176 | 1.0148780988
177 | ],
178 | [
179 | 0.3550745592741918,
180 | 0.011370495332616137,
181 | -0.9347688319537238,
182 | -0.480568219723
183 | ],
184 | [
185 | -0.010805031705694884,
186 | -0.9998093166224763,
187 | -0.01626596707041017,
188 | 1.56239545128
189 | ],
190 | [
191 | 0.0,
192 | 0.0,
193 | 0.0,
194 | 1.0
195 | ]
196 | ]
197 | },
198 | "CAM_BACK": {
199 | "image_size": [
200 | 1600,
201 | 900
202 | ],
203 | "intrinsic": [
204 | [
205 | 809.2209905677063,
206 | 0.0,
207 | 829.2196003259838
208 | ],
209 | [
210 | 0.0,
211 | 809.2209905677063,
212 | 481.77842384512485
213 | ],
214 | [
215 | 0.0,
216 | 0.0,
217 | 1.0
218 | ]
219 | ],
220 | "transform": [
221 | [
222 | 0.002421709860318977,
223 | -0.016753608478469628,
224 | -0.9998567156969553,
225 | 0.0283260309358
226 | ],
227 | [
228 | 0.9999890666843356,
229 | -0.003959107249965843,
230 | 0.002488369260045864,
231 | 0.00345136761476
232 | ],
233 | [
234 | -0.004000229136375599,
235 | -0.9998518100562368,
236 | 0.016743837496914438,
237 | 1.57910346144
238 | ],
239 | [
240 | 0.0,
241 | 0.0,
242 | 0.0,
243 | 1.0
244 | ]
245 | ]
246 | },
247 | "CAM_BACK_LEFT": {
248 | "image_size": [
249 | 1600,
250 | 900
251 | ],
252 | "intrinsic": [
253 | [
254 | 1256.7414812095406,
255 | 0.0,
256 | 792.1125740759628
257 | ],
258 | [
259 | 0.0,
260 | 1256.7414812095406,
261 | 492.7757465151356
262 | ],
263 | [
264 | 0.0,
265 | 0.0,
266 | 1.0
267 | ]
268 | ],
269 | "transform": [
270 | [
271 | 0.9477603556843314,
272 | 0.008665721547931132,
273 | -0.31886550999310576,
274 | 1.03569100218
275 | ],
276 | [
277 | 0.31896113144149074,
278 | -0.01397629983587656,
279 | 0.9476647401230364,
280 | 0.484795032713
281 | ],
282 | [
283 | 0.0037556387837154315,
284 | -0.9998647750135773,
285 | -0.016010211253286277,
286 | 1.59097014818
287 | ],
288 | [
289 | 0.0,
290 | 0.0,
291 | 0.0,
292 | 1.0
293 | ]
294 | ]
295 | }
296 | }
--------------------------------------------------------------------------------
/src/dwm/utils/carla_control.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import carla
3 | import tkinter
4 | import tkinter.font
5 | import tkinter.ttk
6 |
7 |
8 | class KeyPanel():
9 |
10 | def __init__(
11 | self, master, title: str, detail: str, style_class: str = "Inactivated"
12 | ):
13 | default_font = tkinter.font.nametofont("TkDefaultFont")
14 | default_font_family = default_font.cget("family")
15 | self.key_panel = tkinter.ttk.Frame(
16 | master, style="{}.TFrame".format(style_class))
17 | self.label_group = tkinter.ttk.Frame(
18 | self.key_panel, style="{}.TFrame".format(style_class))
19 | self.title = tkinter.ttk.Label(
20 | self.label_group, text=title,
21 | style="{}.TLabel".format(style_class),
22 | font=(default_font_family, 18), padding=(0, -4, 0, -4))
23 | self.detail = tkinter.ttk.Label(
24 | self.label_group, text=detail,
25 | style="{}.TLabel".format(style_class),
26 | font=(default_font_family, 10), padding=(0, -2, 0, -2))
27 |
28 | self.label_group.place(relx=0.5, rely=0.5, anchor="center")
29 | self.title.pack(anchor="center")
30 | self.detail.pack(anchor="center")
31 |
32 | def set_style_class(self, style_class: str):
33 | self.key_panel.configure(style="{}.TFrame".format(style_class))
34 | self.label_group.configure(style="{}.TFrame".format(style_class))
35 | self.title.configure(style="{}.TLabel".format(style_class))
36 | self.detail.configure(style="{}.TLabel".format(style_class))
37 |
38 |
39 | class KeyboardControlPanel():
40 |
41 | def __init__(self, master, hero_vehicle=None, style_config=None):
42 | self.master = master
43 | self.hero_vehicle = hero_vehicle
44 |
45 | default_style_config = {
46 | "Inactivated.TFrame": {
47 | "background": self.master.cget("background")
48 | },
49 | "Inactivated.TLabel": {
50 | "background": self.master.cget("background"),
51 | "foreground": "black",
52 | },
53 | "Activated.TFrame": {
54 | "background": "dimgray"
55 | },
56 | "Activated.TLabel": {
57 | "background": "dimgray",
58 | "foreground": "white",
59 | }
60 | }
61 |
62 | self.style = tkinter.ttk.Style()
63 | for k, v in (style_config or default_style_config).items():
64 | self.style.configure(k, **v)
65 |
66 | self.frame = tkinter.ttk.Frame(master, padding=2)
67 | self.label_reverse = KeyPanel(self.frame, title="Q", detail="Reverse")
68 | self.label_up = KeyPanel(self.frame, title="W", detail="Throttle")
69 | self.label_autopilot = KeyPanel(
70 | self.frame, title="E", detail="Auto pilot")
71 | self.label_left = KeyPanel(self.frame, title="A", detail="Left")
72 | self.label_down = KeyPanel(self.frame, title="S", detail="Brake")
73 | self.label_right = KeyPanel(self.frame, title="D", detail="Right")
74 |
75 | self.pressed_key = {}
76 | self.is_auto = False
77 | self.reverse = False
78 |
79 | def setup_layout(self):
80 | for i in range(2):
81 | self.frame.grid_rowconfigure(i, weight=1)
82 |
83 | for i in range(3):
84 | self.frame.grid_columnconfigure(i, weight=1)
85 |
86 | grid_args = {
87 | "padx": 2,
88 | "pady": 2,
89 | "sticky": tkinter.NSEW
90 | }
91 | self.frame.pack(fill=tkinter.BOTH, expand=True)
92 | self.label_reverse.key_panel.grid(column=0, row=0, **grid_args)
93 | self.label_up.key_panel.grid(column=1, row=0, **grid_args)
94 | self.label_autopilot.key_panel.grid(column=2, row=0, **grid_args)
95 | self.label_left.key_panel.grid(column=0, row=1, **grid_args)
96 | self.label_down.key_panel.grid(column=1, row=1, **grid_args)
97 | self.label_right.key_panel.grid(column=2, row=1, **grid_args)
98 |
99 | def update_manual_control(self):
100 | control = carla.VehicleControl()
101 | control.throttle = \
102 | 0.8 if any([i in self.pressed_key for i in ["w", "Up"]]) else 0
103 | control.steer = (
104 | -0.8 if any([i in self.pressed_key for i in ["a", "Left"]]) else 0
105 | ) + (
106 | 0.8 if any([i in self.pressed_key for i in ["d", "Right"]]) else 0
107 | )
108 | control.brake = \
109 | 1.0 if any([i in self.pressed_key for i in ["s", "Down"]]) else 0
110 | control.reverse = self.reverse
111 | self.hero_vehicle.apply_control(control)
112 |
113 | def on_key_pressed_event(self, event):
114 | self.pressed_key[event.keysym] = True
115 | if event.keysym in ["w", "Up"]:
116 | self.label_up.set_style_class("Activated")
117 | elif event.keysym in ["a", "Left"]:
118 | self.label_left.set_style_class("Activated")
119 | elif event.keysym in ["d", "Right"]:
120 | self.label_right.set_style_class("Activated")
121 | elif event.keysym in ["s", "Down"]:
122 | self.label_down.set_style_class("Activated")
123 |
124 | if self.hero_vehicle is not None and not self.is_auto:
125 | self.update_manual_control()
126 |
127 | def on_key_released_event(self, event):
128 | if event.keysym == "e":
129 | self.is_auto = not self.is_auto
130 | self.label_autopilot.set_style_class(
131 | "Activated" if self.is_auto else "Inactivated")
132 | if self.hero_vehicle is not None:
133 | self.hero_vehicle.set_autopilot(self.is_auto)
134 | elif event.keysym == "q":
135 | self.reverse = not self.reverse
136 | self.label_reverse.set_style_class(
137 | "Activated" if self.reverse else "Inactivated")
138 | elif event.keysym in ["w", "Up"]:
139 | self.label_up.set_style_class("Inactivated")
140 | elif event.keysym in ["a", "Left"]:
141 | self.label_left.set_style_class("Inactivated")
142 | elif event.keysym in ["d", "Right"]:
143 | self.label_right.set_style_class("Inactivated")
144 | elif event.keysym in ["s", "Down"]:
145 | self.label_down.set_style_class("Inactivated")
146 |
147 | if event.keysym in self.pressed_key:
148 | del self.pressed_key[event.keysym]
149 |
150 | if self.hero_vehicle is not None and not self.is_auto:
151 | self.update_manual_control()
152 |
153 |
154 | def create_parser():
155 | parser = argparse.ArgumentParser(
156 | description="Carla control Client")
157 | parser.add_argument(
158 | "--host", default="127.0.0.1", type=str,
159 | help="The host address of the Carla simulator.")
160 | parser.add_argument(
161 | "-p", "--port", default=2000, type=int,
162 | help="The port of the Carla simulator.")
163 | parser.add_argument(
164 | "--client-timeout", default=10.0, type=float,
165 | help="The timeout of the Carla client.")
166 | return parser
167 |
168 |
169 | if __name__ == "__main__":
170 | parser = create_parser()
171 | args = parser.parse_args()
172 |
173 | client = carla.Client(args.host, args.port, 1)
174 | client.set_timeout(args.client_timeout)
175 | world = client.get_world()
176 | world.wait_for_tick()
177 |
178 | hero_vehicle, = [
179 | i for i in world.get_actors()
180 | if (
181 | i.type_id.startswith("vehicle") and
182 | i.attributes.get("role_name") == "hero"
183 | )
184 | ]
185 | print("Hero vehicle: {}".format(hero_vehicle.id))
186 |
187 | window_args = {
188 | "title": "Carla Control",
189 | "geometry": "244x124"
190 | }
191 | window = tkinter.Tk()
192 | for k, v in window_args.items():
193 | getattr(window, k)(v)
194 |
195 | control_panel = KeyboardControlPanel(window, hero_vehicle)
196 | control_panel.setup_layout()
197 |     window.bind("<KeyPress>", control_panel.on_key_pressed_event)
198 |     window.bind("<KeyRelease>", control_panel.on_key_released_event)
199 | window.mainloop()
200 |
--------------------------------------------------------------------------------
/docs/Datasets.md:
--------------------------------------------------------------------------------
1 |
2 | # Datasets
3 |
4 | Currently we support 4 datasets: nuScenes, Waymo Perception, Argoverse 2 Sensor, and OpenDV.
5 |
6 | ## nuScenes
7 |
8 | 1. Download the [nuScenes](https://www.nuscenes.org/download) dataset files to `{NUSCENES_TGZ_ROOT}` on your file system. After the dataset is downloaded, there will be some `*.tgz` files under the path `{NUSCENES_TGZ_ROOT}`.
9 |
10 | 2. Since the TGZ format does not support random access to its content, we recommend converting these files to the ZIP format with the following commands:
11 |
12 | ```
13 | mkdir -p {NUSCENES_ZIP_ROOT}
14 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval_meta.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval_meta.zip
15 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval01_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval01_blobs.zip
16 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval02_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval02_blobs.zip
17 | ...
18 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval10_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval10_blobs.zip
19 | ```
20 |
21 | 3. Now you can use `{NUSCENES_ZIP_ROOT}` to update the nuScenes file system entry in your config file, for [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L12).
22 |
23 | 4. Prepare the HD map data.
24 |
25 | 1. Download the `nuScenes-map-expansion-v1.3.zip` file from the [nuScenes](https://www.nuscenes.org/download) to `{NUSCENES_ZIP_ROOT}`.
26 |
27 | 2. Add the file into the [config](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L27), so the dataset can load the map data.
28 |
29 | 5. *Optional*. When the 3D box conditions are used for training, the 12 Hz metadata is recommended.
30 |
31 |     1. Download the 12 Hz nuScenes metadata from [Corner Case Scene Generation](https://coda-dataset.github.io/w-coda2024/track2/). After the metadata is downloaded, there will be an `interp_12Hz.tar` file.
32 |
33 | 2. Extract and repack the 12 Hz metadata to `interp_12Hz_trainval.zip`, then update the [FS](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L15) and [dataset name](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L206) in the config.
34 |
35 | ```
36 | python -m tarfile -e interp_12Hz.tar
37 | cd data/nuscenes
38 | python -m zipfile -c ../../interp_12Hz_trainval.zip interp_12Hz_trainval/
39 | cd ../..
40 | rm -rf data/
41 | ```
42 |
43 | 6. *Alternative solution for 5*. In the case of a broken download link, you can also regenerate the 12 Hz annotations from the original nuScenes dataset by following the instructions of [ASAP](https://github.com/JeffWang987/ASAP/blob/main/docs/prepare_data.md).
44 |
45 | 7. Download the text prompt annotations and update the config following the section [text description for images](#text-description-for-images).
46 |
47 | ## Waymo
48 |
49 | There are two versions of the Waymo Perception dataset. This project uses version 1 (>= 1.4.2) because only version 1 provides HD map annotations.
50 |
51 | 1. *Optional*. Waymo Perception 1.x requires Protocol Buffers. If you want to avoid installing `waymo_open_dataset` and its dependencies, you can compile the proto files yourself. Install the [Protocol Buffers compiler](https://github.com/protocolbuffers/protobuf/releases/tag/v25.4), then run the following commands. After compilation, `import waymo_open_dataset.dataset_pb2` works once `externals/waymo-open-dataset/src` is added to the environment variable `PYTHONPATH`.
52 |
53 | ```
54 | cd externals/waymo-open-dataset/src
55 | protoc --proto_path=. --python_out=. waymo_open_dataset/*.proto
56 | protoc --proto_path=. --python_out=. waymo_open_dataset/protos/*.proto
57 | ```
58 |
59 | 2. Download the [Waymo Perception](https://waymo.com/open/download) dataset (>= 1.4.2 for the annotation of HD map) to `{WAYMO_ROOT}`. After the dataset is downloaded, there will be some `*.tfrecord` files under the path `{WAYMO_ROOT}/training` and `{WAYMO_ROOT}/validation`.
60 |
61 | 3. Then create the information JSON files to support random access within each scene:
62 |
63 | ```
64 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt waymo -i {WAYMO_ROOT}/training -o {WAYMO_ROOT}/training.info.json
65 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt waymo -i {WAYMO_ROOT}/validation -o {WAYMO_ROOT}/validation.info.json
66 | ```
67 |
68 | 4. Now `{WAYMO_ROOT}` and its information JSON files are ready. Update the Waymo dataset in your config file accordingly, see the [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_waymo.json#L182).
69 |
70 | 5. Download the text prompt annotations and update the config following the section [text description for images](#text-description-for-images).
71 |
72 | ## Argoverse
73 |
74 | 1. Download the [Argoverse 2 Sensor](https://www.argoverse.org/av2.html#download-link) dataset files to `{ARGOVERSE_ROOT}` on your file system. After the dataset is downloaded, there will be some `*.tar` files under path `{ARGOVERSE_ROOT}`.
75 |
76 | 2. Then create the information JSON files to accelerate data loading:
77 |
78 | ```
79 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt argoverse -i {ARGOVERSE_ROOT} -o {ARGOVERSE_ROOT}
80 | ```
81 |
82 | 3. Now `{ARGOVERSE_ROOT}` is ready. Update the Argoverse file system in your config file accordingly, see the [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_argo.json#L184).
83 |
84 | 4. Download the text prompt annotations and update the config following the section [text description for images](#text-description-for-images).
85 |
86 | ## OpenDV
87 |
88 | 1. Download the [OpenDV](https://github.com/OpenDriveLab/DriveAGI/tree/main/opendv) dataset video files to `{OPENDV_ORIGIN_ROOT}` on your file system, and the meta file, prepared in JSON format, to `{OPENDV_JSON_META_PATH}`. After the dataset is downloaded, there will be about 2K video files in `.mp4` and `.webm` format under the path `{OPENDV_ORIGIN_ROOT}`.
89 |
90 | 2. *Optional.* It is recommended to transcode the original video files for better read and seek performance during training, by:
91 |
92 | ```
93 | apt update && apt install -y ffmpeg
94 | python src/dwm/tools/transcode_video.py -c src/dwm/tools/transcode_video.json -i {OPENDV_ORIGIN_ROOT} -o {OPENDV_ROOT}
95 | ```
96 |
97 | 3. Now `{OPENDV_ORIGIN_ROOT}` (or `{OPENDV_ROOT}`) is ready. Update the OpenDV [file system config](../configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json#L31) accordingly, and use `{OPENDV_JSON_META_PATH}` to update the [dataset config](../configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json#L409).
98 |
99 | ## KITTI360
100 |
101 | Register an account and download the [KITTI360](https://www.cvlibs.net/datasets/kitti-360/download.php) dataset. We only require the LiDAR data from KITTI360, so you only need to download the Raw Velodyne Scans, 3D Bounding Boxes, and Vehicle Poses. The scenes `2013_05_28_drive_0000_sync` and `2013_05_28_drive_0002_sync` are used for validation, while all other scenes are used as training data. You may keep the files in their original zip format, as our code processes them automatically.
102 |
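If you want to sanity check a downloaded scan outside of the training pipeline, the sketch below assumes the KITTI-style binary layout (float32 records of x, y, z, reflectance); the file path is only an example of the archive layout.

```
# Illustrative only: load one raw Velodyne scan extracted from the KITTI360
# archives, assuming float32 (x, y, z, reflectance) records.
import numpy as np

points = np.fromfile(
    "data_3d_raw/2013_05_28_drive_0000_sync/velodyne_points/data/0000000000.bin",
    dtype=np.float32).reshape(-1, 4)
print(points.shape)  # (N, 4)
```
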
103 | ## Text description for images
104 |
105 | We generated the image captions for the nuScenes, Waymo, Argoverse, and OpenDV datasets with the [DriveMLM](https://arxiv.org/abs/2312.09245) model. The caption files are available below.
106 |
107 | | Dataset | Downloads |
108 | | :-: | :-: |
109 | | nuScenes | [mini](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/nuscenes_v1.0-mini_caption_v2.zip?download=true), [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/nuscenes_v1.0-trainval_caption_v2.zip?download=true) |
110 | | Waymo | [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/waymo_caption_v2.zip?download=true) |
111 | | Argoverse | [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/av2_sensor_caption_v2.zip?download=true) |
112 | | OpenDV | [all](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/opendv_caption.zip?download=true) |
113 |
114 | 1. Download the packages above and unzip them.
115 |
116 | 2. You will get JSON files such as `nuscenes_v1.0-trainval_caption_v2_train.json` (the image caption text) and `nuscenes_v1.0-trainval_caption_v2_times_train.json` (the timestamps of the frames selected for caption annotation in each scene).
117 |
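To confirm an extracted caption file loads correctly, a minimal check that makes no assumption about the exact schema is shown below; the file name is just an example.

```
# Illustrative only: load one extracted caption file and report its size.
import json

with open("nuscenes_v1.0-trainval_caption_v2_train.json") as f:
    data = json.load(f)
print(type(data), len(data))
```
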
118 | 3. Update the [paths](../configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json#L315) to those files in your config. Please note that the paths in the `image_description_settings` of all datasets should point to your locally downloaded and extracted files.
119 |
--------------------------------------------------------------------------------
/src/dwm/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import dwm.common
2 | import torch
3 | import random
4 | from collections import OrderedDict, defaultdict
5 | from typing import Iterator, List, Optional
6 | from torch.utils.data import Dataset, DistributedSampler
7 |
8 |
9 | class VariableVideoBatchSampler(DistributedSampler):
10 |
11 | def __init__(
12 | self,
13 | dataset,
14 | bucket_config: dict,
15 | num_replicas: Optional[int] = None,
16 | rank: Optional[int] = None,
17 | shuffle: bool = True,
18 | seed: int = 0,
19 | drop_last: bool = False,
20 | verbose: bool = False,
21 | num_bucket_build_workers: int = 1,
22 | ) -> None:
23 | super().__init__(
24 | dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last
25 | )
26 | self.dataset = dataset
27 | self.bucket = bucket_config
28 |
29 | self.res = [k for k in self.bucket.keys()]
30 | self.res_w = [v[0] for v in self.bucket.values()]
31 |
32 | self.res_tbw = {}
33 |
34 | for k, v in self.bucket.items():
35 | self.res_tbw[k] = {}
36 | self.res_tbw[k]["t_bs"] = [(tri[0], tri[1]) for tri in v[1]]
37 | self.res_tbw[k]["w"] = [tri[2] for tri in v[1]]
38 |
39 | self.verbose = verbose
40 | self.last_micro_batch_access_index = 0
41 | self.approximate_num_batch = None
42 |
43 | self._get_num_batch_cached_bucket_sample_dict = None
44 | self.num_bucket_build_workers = num_bucket_build_workers
45 |
46 | def __iter__(self) -> Iterator[List[int]]:
47 |
48 | if self._get_num_batch_cached_bucket_sample_dict is not None:
49 | bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict
50 | self._get_num_batch_cached_bucket_sample_dict = None
51 | else:
52 | bucket_sample_dict = self.group_by_bucket()
53 |
54 | g = torch.Generator()
55 | g.manual_seed(self.seed + self.epoch)
56 | bucket_micro_batch_count = OrderedDict()
57 | bucket_last_consumed = OrderedDict()
58 |
59 | # process the samples
60 | for bucket_id, data_list in bucket_sample_dict.items():
61 |
62 | # handle droplast
63 | bs_per_gpu = int(bucket_id.split("-")[-1])
64 | remainder = len(data_list) % bs_per_gpu
65 |
66 | if remainder > 0:
67 | if not self.drop_last:
68 | # if there is remainder, we pad to make it divisible
69 | data_list += data_list[: bs_per_gpu - remainder]
70 | else:
71 | # we just drop the remainder to make it divisible
72 | data_list = data_list[:-remainder]
73 |
74 | bucket_sample_dict[bucket_id] = data_list
75 |
76 | # handle shuffle
77 | if self.shuffle:
78 | data_indices = torch.randperm(
79 | len(data_list), generator=g).tolist()
80 | data_list = [data_list[i] for i in data_indices]
81 | bucket_sample_dict[bucket_id] = data_list
82 |
83 | # compute how many micro-batches each bucket has
84 | num_micro_batches = len(data_list) // bs_per_gpu
85 | bucket_micro_batch_count[bucket_id] = num_micro_batches
86 |
87 | # compute the bucket access order
88 | # each bucket may have more than one batch of data
89 | # thus bucket_id may appear more than 1 time
90 | bucket_id_access_order = []
91 | for bucket_id, num_micro_batch in bucket_micro_batch_count.items():
92 | bucket_id_access_order.extend([bucket_id] * num_micro_batch)
93 |
94 | # randomize the access order
95 | if self.shuffle:
96 | bucket_id_access_order_indices = torch.randperm(
97 | len(bucket_id_access_order), generator=g).tolist()
98 | bucket_id_access_order = [
99 | bucket_id_access_order[i] for i in bucket_id_access_order_indices]
100 |
101 | # make the number of bucket accesses divisible by dp size
102 | remainder = len(bucket_id_access_order) % self.num_replicas
103 | if remainder > 0:
104 | if self.drop_last:
105 | bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder]
106 | else:
107 | bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder]
108 |
109 | # prepare each batch from its bucket
110 | # according to the predefined bucket access order
111 | num_iters = len(bucket_id_access_order) // self.num_replicas
112 | start_iter_idx = self.last_micro_batch_access_index // self.num_replicas
113 |
114 | # re-compute the micro-batch consumption
115 | # this is useful when resuming from a state dict with a different number of GPUs
116 | self.last_micro_batch_access_index = start_iter_idx * self.num_replicas
117 | for i in range(self.last_micro_batch_access_index):
118 | bucket_id = bucket_id_access_order[i]
119 | bucket_bs = int(bucket_id.split("-")[-1])
120 | if bucket_id in bucket_last_consumed:
121 | bucket_last_consumed[bucket_id] += bucket_bs
122 | else:
123 | bucket_last_consumed[bucket_id] = bucket_bs
124 |
125 | for i in range(start_iter_idx, num_iters):
126 | bucket_access_list = bucket_id_access_order[
127 | i * self.num_replicas: (i + 1) * self.num_replicas]
128 | self.last_micro_batch_access_index += self.num_replicas
129 |
130 | # compute the data samples consumed by each access
131 | bucket_access_boundaries = []
132 | for bucket_id in bucket_access_list:
133 | bucket_bs = int(bucket_id.split("-")[-1])
134 | last_consumed_index = bucket_last_consumed.get(bucket_id, 0)
135 | bucket_access_boundaries.append(
136 | [last_consumed_index, last_consumed_index + bucket_bs])
137 |
138 | # update consumption
139 | if bucket_id in bucket_last_consumed:
140 | bucket_last_consumed[bucket_id] += bucket_bs
141 | else:
142 | bucket_last_consumed[bucket_id] = bucket_bs
143 |
144 | # compute the range of data accessed by each GPU
145 | bucket_id = bucket_access_list[self.rank]
146 | boundary = bucket_access_boundaries[self.rank]
147 | cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0]: boundary[1]]
148 |
149 | # encode t, h, w into the sample index
150 |
151 | b_id = bucket_id.split("-")
152 | real_t, real_h, real_w = b_id[-2], b_id[0], b_id[1]
153 | cur_micro_batch = [
154 | f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch]
155 |
156 | if len(cur_micro_batch) > 0:
157 | yield cur_micro_batch
158 |
159 | self.reset()
160 |
161 | def __len__(self) -> int:
162 | return self.get_num_batch() // self.num_replicas
163 |
164 | def group_by_bucket(self) -> dict:
165 |
166 | bucket_sample_dict = OrderedDict()
167 | for i in range(len(self.dataset)):
168 |
169 | res_i = random.choices(self.res, weights=self.res_w, k=1)[0]
170 | t_bs_i = random.choices(
171 | self.res_tbw[res_i]['t_bs'], weights=self.res_tbw[res_i]['w'], k=1)[0]
172 |
173 | bucket_id = f"{res_i}-{t_bs_i[0]}-{t_bs_i[1]}"
174 |
175 | if bucket_id not in bucket_sample_dict:
176 | bucket_sample_dict[bucket_id] = []
177 | bucket_sample_dict[bucket_id].append(i)
178 |
179 | return bucket_sample_dict
180 |
181 | def get_num_batch(self) -> int:
182 | bucket_sample_dict = self.group_by_bucket()
183 | self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict
184 | self.approximate_num_batch = sum(len(v) // int(k.split("-")[-1]) for k, v in bucket_sample_dict.items())
185 | return self.approximate_num_batch
186 |
187 | def reset(self):
188 | self.last_micro_batch_access_index = 0
189 |
190 | def state_dict(self, num_steps: int) -> dict:
191 | # the last_micro_batch_access_index in the __iter__ is often
192 | # not accurate during multi-workers and data prefetching
193 | # thus, we need the user to pass the actual steps which have been executed
194 | # to calculate the correct last_micro_batch_access_index
195 | return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas}
196 |
197 | def load_state_dict(self, state_dict: dict) -> None:
198 | self.__dict__.update(state_dict)
199 |
--------------------------------------------------------------------------------
/docs/LiDAR_Generation.md:
--------------------------------------------------------------------------------
1 | # Layout-Condition LiDAR Generation with Masked Generative Transformer
2 |
3 |
4 | ## Introduction
5 |
6 | We propose a pipeline for generating LiDAR data conditioned on layout information using the Masked Generative Image Transformer (MaskGIT) [[1]](#1). Our approach builds upon the models introduced in Copilot4D [[5]](#5) and UltraLiDAR [[4]](#4).
7 |
8 | ## Method
9 |
10 | We first train a Vector Quantized Variational AutoEncoder (VQ-VAE) to tokenize LiDAR data into a 2D latent space. Then,
11 | the LiDAR MaskGIT model incorporates layout information such as high-definition maps (HD maps) and 3D object
12 | bounding boxes (3D boxes) and guides the generation of LiDAR data in the latent space.
13 |
14 | ### LiDAR VQ-VAE
15 |
16 | We build our LiDAR VQ-VAE following the approach of UltraLiDAR [[4]](#4) and Copilot4D [[5]](#5). The point cloud $\mathbf{x}$ is fed into a voxelizer $V$ and converted into a Bird's-Eye-View (BEV) image. Additionally, we adopt an auxiliary depth rendering branch during decoding. Specifically, given a latent representation $\mathbf z$ of a LiDAR point cloud, the regular decoder transforms the latent back into the original voxel shape, and a binary cross-entropy loss is used to optimize the network. Furthermore, a depth rendering network $D$ decodes the latent into a 3D feature volume, which is used to render the depth of each point. For any point $p$ in $\mathbf x$, we sample a ray from the LiDAR center to $p$ and use the method from [[3]](#3) to calculate the depth value in this direction. For further details, please refer to [[5]](#5).
17 |
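For intuition, the sketch below illustrates the voxelization step $V$ as a simple binary occupancy grid; the grid extents, resolution, and function name are illustrative placeholders rather than the values used in our implementation.

```
# Illustrative sketch of voxelization: map LiDAR points into a binary
# occupancy grid whose height axis is later collapsed into the BEV image.
import torch

def voxelize_bev(points, pc_range=(-50.0, -50.0, -3.0, 50.0, 50.0, 3.0),
                 grid_size=(640, 640, 32)):
    # points: (N, 3) tensor of LiDAR (x, y, z) coordinates
    lo, hi = torch.tensor(pc_range[:3]), torch.tensor(pc_range[3:])
    inside = ((points >= lo) & (points < hi)).all(dim=1)
    p = points[inside]
    # map each in-range point to a voxel index and mark that voxel as occupied
    idx = ((p - lo) / (hi - lo) * torch.tensor(grid_size)).long()
    grid = torch.zeros(grid_size, dtype=torch.bool)
    grid[idx[:, 0], idx[:, 1], idx[:, 2]] = True
    return grid

print(voxelize_bev(torch.rand(1000, 3) * 6.0 - 3.0).shape)
```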
18 |
19 |
20 |
21 |
22 |
23 |
24 | ### LiDAR MaskGIT
25 |
26 | Our LiDAR MaskGIT is designed to generate sequences of LiDAR point clouds conditioned on layouts. Specifically, given the LiDAR point clouds from the first $k$ frames, our model predicts the subsequent LiDAR point clouds at the following timestamps, guided by layout conditions such as 3D bounding boxes (3D boxes) and high-definition maps (HD maps). Additionally, we extend our model to directly generate LiDAR data without reference frames.
27 |
28 | We follow the Copilot4D framework to build our LiDAR MaskGIT, as illustrated in Figure [1](#fig-main). The input to LiDAR MaskGIT is the masked quantized latent, where masked regions are filled with mask tokens, and the model outputs the probability distribution of VQ-VAE codebook IDs at each position. The architecture employs a spatio-temporal transformer that interleaves spatial and temporal attention. We also adopt free space suppression to encourage more diverse generation results. Please refer to Copilot4D [[5]](#5) for additional details. Furthermore, our model can generate LiDAR data based on layout conditions, which we describe below.
29 |
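For intuition, the sketch below shows a generic MaskGIT-style decoding loop over a grid of discrete codebook IDs. The `model` callable, the `mask_id` token, and the simplified linear unmasking schedule are illustrative assumptions; the actual pipeline follows Copilot4D and additionally applies the layout conditions and free space suppression described in this document.

```
# Generic MaskGIT-style decoding sketch (illustrative only): start from an
# all-MASK grid and, at each step, commit the most confident predictions.
import torch

def maskgit_decode(model, shape, mask_id, steps=8):
    # `model` is assumed to map a token grid to logits over the codebook,
    # with output shape (*shape, codebook_size).
    tokens = torch.full(shape, mask_id, dtype=torch.long)
    for step in range(steps):
        confidence, candidates = model(tokens).softmax(dim=-1).max(dim=-1)
        masked = tokens == mask_id
        if not masked.any():
            break
        # simplified linear schedule; MaskGIT itself uses a cosine schedule
        num_to_fix = max(1, int(masked.sum().item() * (step + 1) / steps))
        confidence = confidence.masked_fill(~masked, -1.0)
        fix = confidence.flatten().topk(num_to_fix).indices
        tokens.view(-1)[fix] = candidates.reshape(-1)[fix]
    return tokens
```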
30 | #### Layout Conditioning
31 | Since LiDAR point clouds are encoded into 2D BEV space, both 3D boxes and HD maps are also projected into BEV images. Similar to [[2]](#2), each instance is mapped into the color space. These two conditional images are concatenated and processed through a lightweight image adapter, generating multi-level features that are added to the latent representation at the corresponding transformer layers.
32 |
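As a concrete illustration of this conditioning path, the sketch below concatenates the two BEV condition images and reduces them into multi-level features; the channel counts, depths, and class name are placeholders rather than the modules used in our implementation.

```
# Illustrative layout adapter sketch: concatenate the HD map and 3D box BEV
# images and produce one feature map per conditioned transformer level.
import torch
import torch.nn as nn

class LayoutAdapter(nn.Module):
    def __init__(self, in_channels=6, widths=(64, 128, 256)):
        super().__init__()
        stages, c = [], in_channels
        for w in widths:
            stages.append(nn.Sequential(
                nn.Conv2d(c, w, kernel_size=3, stride=2, padding=1),
                nn.SiLU()))
            c = w
        self.stages = nn.ModuleList(stages)

    def forward(self, hdmap_bev, boxes_bev):
        x = torch.cat([hdmap_bev, boxes_bev], dim=1)
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)  # added to the latent at the matching level
        return features

adapter = LayoutAdapter()
feats = adapter(torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256))
print([tuple(f.shape) for f in feats])
```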
33 |
34 | ## Experiment
35 | We conduct our experiments on the nuScenes [[6]](#6) and KITTI360 [[7]](#7) datasets and report the quantitative results in the following table.
36 |
37 |
38 | | Dataset | IoU | CD | MMD | JSD |
39 | | :-: | :-: | :-: | :-: | :-: |
40 | | nuScenes | 0.055 | 4.438 | - | - |
41 | | KITTI360 | 0.045 | 5.838 | 0.00461 | 0.471 |
42 | | nuScenes Temporal | 0.126 | 3.487 | - | - |
43 | | KITTI360 Temporal | 0.117 | 3.347 | 0.00411 | 0.313 |
44 |
76 | ## Visualization
77 | In this section, we provide some qualitative results of our method. First, install `open3d` in your environment. The visualization code is provided in `src/dwm/utils/lidar_visualizer.py`. You can run the following command to generate visualization results.
78 | ```
79 | python src/dwm/utils/lidar_visualizer.py \
80 | --data_type nuscenes \
81 | --lidar_root /path/to/generated/lidar \
82 | --data_root /path/to/nuscenes/json \
83 | --output_path /path/to/output/folder
84 | ```
85 | ### Single Frame Generation
86 | #### NuScenes
87 |
88 | *(Image grid of single-frame generation results on nuScenes omitted.)*
89 |
97 | #### KITTI360
98 |
99 | *(Image grid of single-frame generation results on KITTI360 omitted.)*
100 |
109 | ### Autoregressive Generation
110 | #### NuScenes
111 |
112 | *(Image grid of autoregressive generation results on nuScenes omitted; panels: Reference Frame, Frame=2, Frame=4, Frame=6.)*
113 |
121 | #### KITTI360
122 |
123 | *(Image grid of autoregressive generation results on KITTI360 omitted; panels: Reference Frame, Frame=2, Frame=4, Frame=6.)*
124 |
132 | ## References
133 |
134 | [1] Huiwen Chang, Han Zhang, Lu Jiang, Ce Liu, and William T Freeman. Maskgit: Masked generative image transformer. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 11315–11325, 2022.
135 |
136 | [2] Rui Chen, Zehuan Wu, Yichen Liu, Yuxin Guo, Jingcheng Ni, Haifeng Xia, and Siyu Xia. Unimlvg: Unified framework for multi-view long video generation with comprehensive control capabilities for autonomous driving. arXiv preprint arXiv:2412.04842, 2024.
137 |
138 | [3] Cheng Sun, Min Sun, and Hwann-Tzong Chen. Improved direct voxel grid optimization for radiance fields reconstruction. arXiv preprint arXiv:2206.05085, 2022.
139 |
140 | [4] Yuwen Xiong, Wei-Chiu Ma, Jingkang Wang, and Raquel Urtasun. Ultralidar: Learning compact representations for lidar completion and generation. arXiv preprint arXiv:2311.01448, 2023.
141 |
142 | [5] Lunjun Zhang, Yuwen Xiong, Ze Yang, Sergio Casas, Rui Hu, and Raquel Urtasun. Copilot4d: Learning unsupervised world models for autonomous driving via discrete diffusion. arXiv preprint arXiv:2311.01017, 2023.
143 |
144 | [6] Holger Caesar, Varun Bankiti, Alex H Lang, Sourabh Vora, Venice Erin Liong, Qiang Xu, Anush Krishnan, Yu Pan, Giancarlo Baldan,
145 | and Oscar Beijbom. nuscenes: A multimodal dataset for autonomous driving. In Proceedings of the IEEE/CVF conference on computer
146 | vision and pattern recognition, pages 11621–11631, 2020.
147 |
148 | [7] Yiyi Liao, Jun Xie, and Andreas Geiger. Kitti-360: A novel dataset and benchmarks for urban scene understanding in 2d and 3d. IEEE
149 | Transactions on Pattern Analysis and Machine Intelligence, 45(3):3292–3310, 2022.
150 |
--------------------------------------------------------------------------------
/src/dwm/schedulers/temporal_independent.py:
--------------------------------------------------------------------------------
1 | import diffusers.schedulers
2 | import diffusers.utils.torch_utils
3 | import torch
4 |
5 |
6 | class DDPMScheduler(diffusers.schedulers.DDPMScheduler):
7 |
8 | def add_noise(
9 | self, original_samples: torch.Tensor, noise: torch.Tensor,
10 | timesteps: torch.IntTensor,
11 | ):
12 | while len(timesteps.shape) < len(original_samples.shape):
13 | timesteps = timesteps.unsqueeze(-1)
14 |
15 | # Make sure alphas_cumprod and timestep have same device and dtype as
16 | # original_samples Move the self.alphas_cumprod to device to avoid
17 | # redundant CPU to GPU data movement for the subsequent add_noise calls
18 | self.alphas_cumprod = self.alphas_cumprod\
19 | .to(device=original_samples.device)
20 | alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
21 | timesteps = timesteps.to(original_samples.device)
22 |
23 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
24 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
25 | noisy_samples = sqrt_alpha_prod * original_samples + \
26 | sqrt_one_minus_alpha_prod * noise
27 | return noisy_samples
28 |
29 | def get_velocity(
30 | self, sample: torch.Tensor, noise: torch.Tensor,
31 | timesteps: torch.IntTensor
32 | ):
33 | while len(timesteps.shape) < len(sample.shape):
34 | timesteps = timesteps.unsqueeze(-1)
35 |
36 | # Make sure alphas_cumprod and timestep have same device and dtype as
37 | # sample
38 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
39 | alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
40 | timesteps = timesteps.to(sample.device)
41 |
42 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
43 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
44 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
45 | return velocity
46 |
47 |
48 | class DDIMScheduler(diffusers.schedulers.DDIMScheduler):
49 |
50 | def _get_variance(self, timestep, prev_timestep):
51 | alpha_prod_t = self.alphas_cumprod[timestep]
52 | alpha_prod_t_prev = torch.where(
53 | prev_timestep >= 0,
54 | self.alphas_cumprod[prev_timestep],
55 | torch.ones(
56 | prev_timestep.shape, dtype=self.alphas_cumprod.dtype,
57 | device=self.alphas_cumprod.device) * self.final_alpha_cumprod
58 | )
59 | beta_prod_t = 1 - alpha_prod_t
60 | beta_prod_t_prev = 1 - alpha_prod_t_prev
61 |
62 | variance = (beta_prod_t_prev / beta_prod_t) * \
63 | (1 - alpha_prod_t / alpha_prod_t_prev)
64 |
65 | return variance
66 |
67 | def step(
68 | self, model_output: torch.Tensor, timestep: torch.IntTensor,
69 | sample: torch.Tensor, eta: float = 0.0,
70 | use_clipped_model_output: bool = False, generator=None,
71 | variance_noise=None, return_dict: bool = True,
72 | ):
73 | if self.num_inference_steps is None:
74 | raise ValueError(
75 | "Number of inference steps is 'None', you need to run "
76 | "'set_timesteps' after creating the scheduler")
77 |
78 | while len(timestep.shape) < len(sample.shape):
79 | timestep = timestep.unsqueeze(-1)
80 |
81 | # 1. get previous step value (=t-1)
82 | prev_timestep = timestep - self.config.num_train_timesteps // \
83 | self.num_inference_steps
84 |
85 | # 2. compute alphas, betas
86 | alpha_prod_t = self.alphas_cumprod[timestep].to(device=sample.device)
87 | alpha_prod_t_prev = torch.where(
88 | prev_timestep >= 0,
89 | self.alphas_cumprod[prev_timestep],
90 | torch.ones(
91 | prev_timestep.shape, dtype=self.alphas_cumprod.dtype,
92 | device=self.alphas_cumprod.device) * self.final_alpha_cumprod
93 | ).to(device=sample.device)
94 |
95 | beta_prod_t = 1 - alpha_prod_t
96 |
97 | # 3. compute predicted original sample from predicted noise also called
98 | # "predicted x_0" of formula (12) from
99 | # https://arxiv.org/pdf/2010.02502.pdf
100 | if self.config.prediction_type == "epsilon":
101 | pred_original_sample = \
102 | (sample - beta_prod_t ** (0.5) * model_output) / \
103 | alpha_prod_t ** (0.5)
104 | pred_epsilon = model_output
105 | elif self.config.prediction_type == "sample":
106 | pred_original_sample = model_output
107 | pred_epsilon = \
108 | (sample - alpha_prod_t ** (0.5) * pred_original_sample) / \
109 | beta_prod_t ** (0.5)
110 | elif self.config.prediction_type == "v_prediction":
111 | pred_original_sample = (alpha_prod_t ** 0.5) * sample - \
112 | (beta_prod_t ** 0.5) * model_output
113 | pred_epsilon = (alpha_prod_t ** 0.5) * model_output + \
114 | (beta_prod_t ** 0.5) * sample
115 | else:
116 | raise ValueError(
117 | "prediction_type given as {} must be one of `epsilon`, "
118 | "`sample`, or `v_prediction`"
119 | .format(self.config.prediction_type))
120 |
121 | # 4. Clip or threshold "predicted x_0"
122 | if self.config.thresholding:
123 | pred_original_sample = self._threshold_sample(pred_original_sample)
124 | elif self.config.clip_sample:
125 | pred_original_sample = pred_original_sample.clamp(
126 | -self.config.clip_sample_range, self.config.clip_sample_range)
127 |
128 | # 5. compute variance: "sigma_t(η)" -> see formula (16)
129 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
130 | variance = self._get_variance(timestep, prev_timestep)\
131 | .to(sample.device)
132 | std_dev_t = eta * variance ** (0.5)
133 |
134 | if use_clipped_model_output:
135 | # the pred_epsilon is always re-derived from the clipped x_0 in
136 | # Glide
137 | pred_epsilon = \
138 | (sample - alpha_prod_t ** (0.5) * pred_original_sample) / \
139 | beta_prod_t ** (0.5)
140 |
141 | # 6. compute "direction pointing to x_t" of formula (12) from
142 | # https://arxiv.org/pdf/2010.02502.pdf
143 | pred_sample_direction = \
144 | (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
145 |
146 | # 7. compute x_t without "random noise" of formula (12) from
147 | # https://arxiv.org/pdf/2010.02502.pdf
148 | prev_sample = alpha_prod_t_prev ** (0.5) * \
149 | pred_original_sample + pred_sample_direction
150 |
151 | if eta > 0:
152 | if variance_noise is not None and generator is not None:
153 | raise ValueError(
154 | "Cannot pass both generator and variance_noise. Please "
155 | "make sure that either `generator` or `variance_noise` "
156 | "stays `None`.")
157 |
158 | if variance_noise is None:
159 | variance_noise = diffusers.utils.torch_utils.randn_tensor(
160 | model_output.shape, generator=generator,
161 | device=model_output.device, dtype=model_output.dtype)
162 |
163 | variance = std_dev_t * variance_noise
164 | prev_sample = prev_sample + variance
165 |
166 | if not return_dict:
167 | return (prev_sample, pred_original_sample)
168 |
169 | return diffusers.schedulers.scheduling_ddim.DDIMSchedulerOutput(
170 | prev_sample=prev_sample, pred_original_sample=pred_original_sample)
171 |
172 |
173 | class FlowMatchEulerDiscreteScheduler(
174 | diffusers.schedulers.FlowMatchEulerDiscreteScheduler
175 | ):
176 | def step_by_indices(
177 | self, model_output: torch.FloatTensor, timestep_indices,
178 | sample: torch.FloatTensor, return_dict: bool = True
179 | ):
180 | if isinstance(timestep_indices, torch.Tensor):
181 | while len(timestep_indices.shape) < model_output.ndim:
182 | timestep_indices = timestep_indices.unsqueeze(-1)
183 |
184 | # Upcast to avoid precision issues when computing prev_sample
185 | sample = sample.to(torch.float32)
186 |
187 | sigma = self.sigmas[timestep_indices]
188 | sigma_next = self.sigmas[timestep_indices + 1]
189 | prev_sample = sample + (sigma_next - sigma) * model_output
190 |
191 | # Cast sample back to model compatible dtype
192 | prev_sample = prev_sample.to(model_output.dtype)
193 | if not return_dict:
194 | return (prev_sample,)
195 |
196 | return diffusers.schedulers.scheduling_flow_match_euler_discrete.\
197 | FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
198 |
--------------------------------------------------------------------------------