├── src └── dwm │ ├── __init__.py │ ├── fs │ ├── __init__.py │ ├── README.md │ ├── ctar.py │ ├── czip.py │ ├── s3fs.py │ └── dirfs.py │ ├── datasets │ ├── __init__.py │ ├── README.md │ └── waymo_common.py │ ├── models │ ├── __init__.py │ ├── adapters.py │ ├── voxelizer.py │ └── maskgit_base.py │ ├── utils │ ├── __init__.py │ ├── lidar.py │ ├── carla_actor_state_machines.py │ ├── make_carla_cameras.py │ ├── make_blank_code.py │ ├── preview.py │ ├── carla_simulation.py │ ├── carla_control.py │ └── sampler.py │ ├── pipelines │ └── __init__.py │ ├── tools │ ├── transcode_video.json │ ├── tar2zip.py │ ├── transcode_video.py │ ├── prepare_opendv.py │ ├── fs_make_info_json.py │ ├── export_nusc_2_preview_format.py │ └── dataset_make_info_json.py │ ├── metrics │ ├── general_metrics.py │ ├── voxel_metrics.py │ ├── fvd.py │ └── pc_metrics.py │ ├── distributed.py │ ├── evaluate.py │ ├── streaming.py │ ├── export_generation_result_as_nuscenes_data.py │ ├── preview.py │ ├── common.py │ ├── functional.py │ └── schedulers │ └── temporal_independent.py ├── configs ├── fs │ ├── local.json │ ├── local_nuscenes.json │ ├── s3_cn_sh.json │ ├── s3_st_sh.json │ ├── local_czip_nuscenes.json │ └── s3_czip_nuscenes.json ├── README.md └── experimental │ └── simulation │ └── make_carla_cameras_from_nuscenes.json ├── .gitignore ├── requirements.txt ├── .gitmodules ├── LICENSE ├── README_intro_zh.md ├── docs ├── CtsdPipelineFaqs.md ├── InteractiveGeneration.md ├── Datasets.md └── LiDAR_Generation.md └── examples └── ctsd_generation_example.py /src/dwm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dwm/fs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dwm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dwm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dwm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dwm/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/fs/local.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "fsspec.implementations.local.LocalFileSystem" 3 | } -------------------------------------------------------------------------------- /configs/fs/local_nuscenes.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "dwm.fs.dirfs.DirFileSystem", 3 | "path": "/mnt/storage/user/wuzehuan/Downloads/data/nuscenes" 4 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | build 4 | dist 5 | *.egg-info 6 | output 7 | /externals/chamferdist/chamferdist/*.so 8 | work_dirs 9 | wandb 10 | taming 11 | 
-------------------------------------------------------------------------------- /configs/fs/s3_cn_sh.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 3 | "endpoint_url": "http://aoss-internal.cn-sh-01.sensecoreapi-oss.cn", 4 | "aws_access_key_id": "AE939C3A07AE4E6D93908AA603B9F3A9", 5 | "aws_secret_access_key": "" 6 | } -------------------------------------------------------------------------------- /configs/fs/s3_st_sh.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 3 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn", 4 | "aws_access_key_id": "4853705CDE93446E8D902F70291C2C92", 5 | "aws_secret_access_key": "" 6 | } -------------------------------------------------------------------------------- /src/dwm/tools/transcode_video.json: -------------------------------------------------------------------------------- 1 | { 2 | "ffmpeg_args": [ 3 | "-r", 4 | "10", 5 | "-g", 6 | "20", 7 | "-vf", 8 | "scale=-1:720", 9 | "-row-mt", 10 | "1", 11 | "-b:v", 12 | "2M", 13 | "-crf", 14 | "10", 15 | "-pix_fmt", 16 | "yuv420p" 17 | ] 18 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | accelerate==1.4.0 3 | av==14.0.0 4 | bitsandbytes 5 | botocore 6 | diffusers==0.31.0 7 | einops 8 | easydict 9 | fsspec 10 | numpy 11 | pandas 12 | protobuf==4.21.6 13 | pyarrow 14 | safetensors 15 | sentencepiece 16 | scipy 17 | tensorboard 18 | torch-fidelity 19 | torchmetrics==1.6.0 20 | tqdm 21 | transformers==4.50.0 22 | transforms3d 23 | Ninja 24 | chamferdist 25 | timm 26 | rotary-embedding-torch 27 | opencv-python 28 | wandb 29 | git+https://github.com/autonomousvision/kitti360Scripts.git 30 | -------------------------------------------------------------------------------- /src/dwm/metrics/general_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | from torchmetrics import MeanMetric 4 | import torch.distributed 5 | 6 | 7 | class CustomMeanMetrics(MeanMetric): 8 | """ 9 | Description: 10 | calculate the mean value of certain metrics 11 | """ 12 | 13 | def __init__( 14 | self, **kwargs 15 | ): 16 | super().__init__(**kwargs) 17 | 18 | def compute(self, **kwargs): 19 | self.num_samples = int(self.weight) 20 | return super().compute(**kwargs) 21 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "externals/TATS"] 2 | path = externals/TATS 3 | url = https://github.com/songweige/TATS.git 4 | [submodule "externals/taming-transformers"] 5 | path = externals/taming-transformers 6 | url = https://github.com/CompVis/taming-transformers.git 7 | [submodule "externals/waymo-open-dataset"] 8 | path = externals/waymo-open-dataset 9 | url = https://github.com/waymo-research/waymo-open-dataset.git 10 | [submodule "externals/dvgo_cuda"] 11 | path = externals/dvgo_cuda 12 | url = https://github.com/sunset1995/DirectVoxGO.git 13 | -------------------------------------------------------------------------------- /configs/fs/local_czip_nuscenes.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"_class_name": "dwm.fs.czip.CombinedZipFileSystem", 3 | "fs": { 4 | "_class_name": "dwm.fs.dirfs.DirFileSystem", 5 | "path": "/mnt/storage/user/wuzehuan" 6 | }, 7 | "paths": [ 8 | "Downloads/data/nuscenes/v1.0-trainval_meta.zip", 9 | "Downloads/data/nuscenes/v1.0-trainval01_blobs.zip", 10 | "Downloads/data/nuscenes/v1.0-trainval02_blobs.zip", 11 | "Downloads/data/nuscenes/v1.0-trainval03_blobs.zip", 12 | "Downloads/data/nuscenes/v1.0-trainval04_blobs.zip", 13 | "Downloads/data/nuscenes/v1.0-trainval05_blobs.zip", 14 | "Downloads/data/nuscenes/v1.0-trainval06_blobs.zip", 15 | "Downloads/data/nuscenes/v1.0-trainval07_blobs.zip", 16 | "Downloads/data/nuscenes/v1.0-trainval08_blobs.zip", 17 | "Downloads/data/nuscenes/v1.0-trainval09_blobs.zip", 18 | "Downloads/data/nuscenes/v1.0-trainval10_blobs.zip" 19 | ] 20 | } -------------------------------------------------------------------------------- /configs/fs/s3_czip_nuscenes.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "dwm.fs.czip.CombinedZipFileSystem", 3 | "fs": { 4 | "_class_name": "dwm.fs.dirfs.DirFileSystem", 5 | "path": "users/wuzehuan", 6 | "fs": { 7 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 8 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn", 9 | "aws_access_key_id": "4853705CDE93446E8D902F70291C2C92", 10 | "aws_secret_access_key": "" 11 | } 12 | }, 13 | "paths": [ 14 | "data/nuscenes/v1.0-trainval_meta.zip", 15 | "data/nuscenes/v1.0-trainval01_blobs.zip", 16 | "data/nuscenes/v1.0-trainval02_blobs.zip", 17 | "data/nuscenes/v1.0-trainval03_blobs.zip", 18 | "data/nuscenes/v1.0-trainval04_blobs.zip", 19 | "data/nuscenes/v1.0-trainval05_blobs.zip", 20 | "data/nuscenes/v1.0-trainval06_blobs.zip", 21 | "data/nuscenes/v1.0-trainval07_blobs.zip", 22 | "data/nuscenes/v1.0-trainval08_blobs.zip", 23 | "data/nuscenes/v1.0-trainval09_blobs.zip", 24 | "data/nuscenes/v1.0-trainval10_blobs.zip" 25 | ] 26 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SenseTime Research. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/dwm/utils/lidar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dwm.functional 3 | 4 | 5 | def preprocess_points(batch, device): 6 | mhv = dwm.functional.make_homogeneous_vector 7 | return [ 8 | [ 9 | (mhv(p_j.to(device)) @ t_j.permute(1, 0))[:, :3] 10 | for p_j, t_j in zip(p_i, t_i.flatten(0, 1)) 11 | ] 12 | for p_i, t_i in zip( 13 | batch["lidar_points"], batch["lidar_transforms"].to(device)) 14 | ] 15 | 16 | 17 | def postprocess_points(batch, ego_space_points): 18 | return [ 19 | [ 20 | ( 21 | dwm.functional.make_homogeneous_vector(p_j.cpu()) @ 22 | torch.linalg.inv(t_j).permute(1, 0) 23 | )[:, :3] 24 | for p_j, t_j in zip(p_i, t_i.flatten(0, 1)) 25 | ] 26 | for p_i, t_i in zip( 27 | ego_space_points, batch["lidar_transforms"]) 28 | ] 29 | 30 | 31 | def voxels2points(grid_size, voxels): 32 | interval = torch.tensor([grid_size["interval"]]) 33 | min = torch.tensor([grid_size["min"]]) 34 | return [ 35 | [ 36 | torch.nonzero(v_j).flip(-1).cpu() * interval + min 37 | for v_j in v_i 38 | ] 39 | for v_i in voxels 40 | ] 41 | -------------------------------------------------------------------------------- /src/dwm/utils/carla_actor_state_machines.py: -------------------------------------------------------------------------------- 1 | import carla 2 | 3 | 4 | class ClassicPedestrian: 5 | 6 | def __init__(self, controller: carla.Actor): 7 | self.controller = controller 8 | self.state = "idle" 9 | 10 | def update(self): 11 | if self.state == "idle": 12 | world = self.controller.get_world() 13 | self.destination = world.get_random_location_from_navigation() 14 | 15 | # should start by 1 tick after the creation of actor 16 | self.controller.start() 17 | self.controller.go_to_location(self.destination) 18 | self.controller.set_max_speed( 19 | float(self.controller.parent.attributes["speed"])) 20 | 21 | self.state = "acting" 22 | 23 | elif self.state == "acting": 24 | # TODO: stop and transfer to idle when arrival 25 | pass 26 | 27 | 28 | class BevSpectator: 29 | 30 | def __init__(self, actor: carla.Actor): 31 | self.spectator = None 32 | self.hero = actor 33 | world = self.hero.get_world() 34 | self.spectator = world.get_spectator() 35 | 36 | def update(self): 37 | vehicle_transform = self.hero.get_transform() 38 | spectator_transform = carla.Transform( 39 | vehicle_transform.location + carla.Location(x=0, y=0, z=50), 40 | carla.Rotation(pitch=-90, yaw=0, roll=0) 41 | ) 42 | self.spectator.set_transform(spectator_transform) 43 | -------------------------------------------------------------------------------- /README_intro_zh.md: -------------------------------------------------------------------------------- 1 | # Open Driving World Models (OpenDWM) 2 | 3 | [[English README](README.md)] 4 | 5 | https://github.com/user-attachments/assets/649d3b81-3b1f-44f9-9f51-4d1ed7756476 6 | 7 | [视频链接](https://youtu.be/j9RRj-xzOA4) 8 | 9 | 欢迎来到 OpenDWM 项目!这是一个专注于自动驾驶视频生成的开源项目。我们的使命是提供一个高质量、可控的、使用最新技术的自动驾驶视频生成工具。我们的目标是构建一个既用户友好,又高度可复用的代码库,并希望通过聚集社区智慧,不断改进。 10 | 11 | 驾驶世界模型根据文本和道路环境布局条件,生成自动驾驶场景的多视角图像或视频。无论是环境、天气条件、车辆类型,还是驾驶路径,你都可以根据需求来调整。 12 | 13 | 亮点如下: 14 | 15 | 1. **透明且可复现的训练。** 我们提供完整的训练代码和配置,让大家可以根据需要进行实验复现、在自有数据上微调、定制开发功能。 16 | 17 | 2. **环境多样性的显著改进。** 通过对多个数据集的使用,模型的泛化能力得到前所未有的提升。以布局条件控制生成任务为例,下雪的城市街道,远处有雪山的湖边高速路,这些场景对于仅使用单一数据集训练的生成模型都是不可能的任务。 18 | 19 | 3. 
**大幅提升生成质量。** 对于流行模型架构(SD 2.1, 3.5)的支持,可以更便捷地利用社区内先进的预训练生成能力。包括多任务、自监督在内的多种训练技巧,让模型更有效地利用视频数据里的信息。 20 | 21 | 4. **方便测评。** 测评遵循流行框架 `torchmetrics`,易于配置、开发、并集成到已有管线。一些公开配置(例如在 nuScenes 验证集上的 FID, FVD)用于和其他研究工作对齐。 22 | 23 | 此外,我们设计的代码模块考虑到了相当程度的可复用性,以便于在其他项目中应用。 24 | 25 | 截止现在,本项目实现了以下论文中的技巧: 26 | 27 | > [UniMLVG: Unified Framework for Multi-view Long Video Generation with Comprehensive Control Capabilities for Autonomous Driving](https://sensetime-fvg.github.io/UniMLVG)
28 | > Rui Chen<sup>1,2</sup>, Zehuan Wu<sup>2</sup>, Yichen Liu<sup>2</sup>, Yuxin Guo<sup>2</sup>, Jingcheng Ni<sup>2</sup>, Haifeng Xia<sup>1</sup>, Siyu Xia<sup>1</sup><br>
29 | > <sup>1</sup>Southeast University <sup>2</sup>SenseTime Research 30 | 31 | > [MaskGWM: A Generalizable Driving World Model with Video Mask Reconstruction](https://sensetime-fvg.github.io/MaskGWM)<br>
32 | > Jingcheng Ni, Yuxin Guo, Yichen Liu, Rui Chen, Lewei Lu, Zehuan Wu<br>
33 | > SenseTime Research 34 | 35 | ## 设置和运行 36 | 37 | 请参考 [README](README.md#setup) 中的步骤。 38 | -------------------------------------------------------------------------------- /src/dwm/tools/tar2zip.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import tarfile 5 | import zipfile 6 | 7 | 8 | def create_parser(): 9 | parser = argparse.ArgumentParser( 10 | description="The tool to convert the TAR file to the ZIP file for the " 11 | "capability of random access in package.") 12 | parser.add_argument( 13 | "-i", "--input-path", type=str, required=True, 14 | help="The input path of TAR file.") 15 | parser.add_argument( 16 | "-o", "--output-path", type=str, required=True, 17 | help="The output path to save the converted ZIP file.") 18 | parser.add_argument( 19 | "-e", "--extensions-to-compress", default=".bin|.json|.pcd|.txt", 20 | type=str, help="The extension list to compress, split by ':'.") 21 | return parser 22 | 23 | 24 | def convert_content(tar_file_obj, zip_file_obj, extensions_to_compress: list): 25 | i = tarfile.TarInfo.fromtarfile(tar_file_obj) 26 | while i is not None: 27 | if i.isfile(): 28 | data = tar_file_obj.extractfile(i).read() 29 | modified = datetime.datetime.fromtimestamp(i.mtime) 30 | zi = zipfile.ZipInfo(i.name, modified.timetuple()[:6]) 31 | ext = os.path.splitext(i.name)[-1] 32 | compress_type = zipfile.ZIP_DEFLATED \ 33 | if ext in extensions_to_compress else zipfile.ZIP_STORED 34 | zip_file_obj.writestr(zi, data, compress_type) 35 | 36 | i = tar_file_obj.next() 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = create_parser() 41 | args = parser.parse_args() 42 | 43 | etc = args.extensions_to_compress.split("|") 44 | with tarfile.open(args.input_path) as f_in: 45 | with zipfile.ZipFile(args.output_path, "w") as f_out: 46 | convert_content(f_in, f_out, etc) 47 | -------------------------------------------------------------------------------- /src/dwm/tools/transcode_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import subprocess 5 | 6 | 7 | def create_parser(): 8 | parser = argparse.ArgumentParser( 9 | description="The script to convert videos for better performance.") 10 | parser.add_argument( 11 | "-c", "--config-path", type=str, required=True, 12 | help="The config path.") 13 | parser.add_argument( 14 | "-i", "--input-path", type=str, required=True, 15 | help="The input path of videos.") 16 | parser.add_argument( 17 | "-o", "--output-path", type=str, required=True, 18 | help="The output path to save the merged dict file.") 19 | parser.add_argument("-f", "--range-from", type=int, default=0) 20 | parser.add_argument("-t", "--range-to", type=int, default=-1) 21 | return parser 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = create_parser() 26 | args = parser.parse_args() 27 | 28 | with open(args.config_path, "r", encoding="utf-8") as f: 29 | config = json.load(f) 30 | 31 | files = os.listdir(args.input_path) 32 | range_to = range_to = len(files) if args.range_to == -1 else args.range_to 33 | print( 34 | "Dataset count: {}, processing range {} - {}".format( 35 | len(files), args.range_from, range_to)) 36 | actual_files = files[args.range_from:range_to] 37 | print(actual_files) 38 | os.makedirs(args.output_path, exist_ok=True) 39 | for i in actual_files: 40 | if not os.path.isfile(os.path.join(args.input_path, i)): 41 | continue 42 | 43 | if "reformat" in config: 44 
| output_file = os.path.splitext(i)[0] + config["reformat"] 45 | else: 46 | output_file = i 47 | 48 | subprocess.run([ 49 | "ffmpeg", 50 | "-hide_banner", 51 | "-y", 52 | "-i", 53 | os.path.join(args.input_path, i), 54 | *config["ffmpeg_args"], 55 | os.path.join(args.output_path, output_file) 56 | ], stdout=subprocess.PIPE, check=True) 57 | -------------------------------------------------------------------------------- /src/dwm/tools/prepare_opendv.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from datasets import load_dataset 4 | import argparse 5 | 6 | def create_parser(): 7 | parser = argparse.ArgumentParser( 8 | description="The script to make information JSON(s) for \ 9 | official opendv annotations") 10 | parser.add_argument( 11 | "--meta-path", type=str, required=True, 12 | help="The path to official video metas of OpenDV-2K.") 13 | parser.add_argument( 14 | "-o", "--output-path", type=str, required=True, 15 | help="The path to save the information JSON file(s) on the local file " 16 | "system.") 17 | return parser 18 | 19 | if __name__ == "__main__": 20 | parser = create_parser() 21 | args = parser.parse_args() 22 | 23 | ds = load_dataset("OpenDriveLab/OpenDV-YouTube-Language") 24 | split = None 25 | 26 | with open(args.meta_path, "r", encoding="utf-8") as f: 27 | meta_dict = { 28 | i["videoid"]: i 29 | for i in json.load(f) 30 | if split is None or i["split"] == split and 31 | i["videoid"] not in ignore_list 32 | } 33 | 34 | new_image_descriptions = dict() 35 | for sp in ["train", "validation"]: 36 | with tqdm( 37 | ds[sp], 38 | desc=f"Prepare {sp}", 39 | ) as pbar: 40 | for frame_annos in pbar: 41 | k1 = frame_annos['folder'].split('/')[-1] 42 | k2 = int(frame_annos['first_frame'].split('.')[0]) 43 | v = (frame_annos['blip'], frame_annos['cmd']) 44 | if k1 not in meta_dict: 45 | continue 46 | start_discard = meta_dict[k1]["start_discard"] 47 | default_fps, time_base = 10, 0.001 48 | t = int((k2/default_fps+start_discard)/time_base) 49 | new_image_descriptions[f"{k1}.{t}"] = { 50 | "image_description": v[0], 51 | "action": v[1] 52 | } 53 | with open( 54 | args.output_path, "w", encoding="utf-8" 55 | ) as f: 56 | json.dump(new_image_descriptions, f) -------------------------------------------------------------------------------- /src/dwm/models/adapters.py: -------------------------------------------------------------------------------- 1 | import diffusers.models.adapter 2 | import torch 3 | from typing import Optional 4 | 5 | 6 | class ImageAdapter(torch.nn.Module): 7 | def __init__( 8 | self, in_channels: int = 3, 9 | channels: list = [320, 320, 640, 1280, 1280], 10 | is_downblocks: list = [False, True, True, True, False], 11 | num_res_blocks: int = 2, downscale_factor: int = 8, 12 | use_zero_convs: bool = False, zero_gate_coef: Optional[float] = None, 13 | gradient_checkpointing: bool = True 14 | ): 15 | super().__init__() 16 | 17 | in_channels = in_channels * downscale_factor ** 2 18 | self.unshuffle = torch.nn.PixelUnshuffle(downscale_factor) 19 | self.body = torch.nn.ModuleList([ 20 | diffusers.models.adapter.AdapterBlock( 21 | in_channels if i == 0 else channels[i - 1], channels[i], 22 | num_res_blocks, down=is_downblocks[i]) 23 | for i in range(len(channels)) 24 | ]) 25 | self.gradient_checkpointing = gradient_checkpointing 26 | 27 | self.zero_convs = torch.nn.ModuleList([ 28 | torch.nn.Conv2d(channel, channel, 1) 29 | for channel in channels 30 | ]) if use_zero_convs else [None for _ in 
channels] 31 | for i in self.zero_convs: 32 | if i is not None: 33 | torch.nn.init.zeros_(i.weight) 34 | torch.nn.init.zeros_(i.bias) 35 | 36 | self.zero_gate_coef = zero_gate_coef 37 | self.zero_gates = torch.nn.Parameter(torch.zeros(len(channels))) \ 38 | if zero_gate_coef else None 39 | 40 | def forward(self, x: torch.Tensor, return_features: bool = False): 41 | base_shape = x.shape[:-3] 42 | x = self.unshuffle(x.flatten(0, -4)) 43 | features = [] 44 | for i, (block, zero_conv) in enumerate(zip(self.body, self.zero_convs)): 45 | if self.training and self.gradient_checkpointing: 46 | x = torch.utils.checkpoint.checkpoint( 47 | block, x, use_reentrant=False) 48 | else: 49 | x = block(x) 50 | 51 | x_out = x 52 | if zero_conv is not None: 53 | x_out = zero_conv(x_out) 54 | 55 | if self.zero_gates is not None: 56 | x_out = x_out * torch.tanh( 57 | self.zero_gate_coef * self.zero_gates[i]) 58 | 59 | features.append(x_out.view(*base_shape, *x_out.shape[1:])) 60 | return features if not return_features else features[-1] 61 | -------------------------------------------------------------------------------- /src/dwm/utils/make_carla_cameras.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import numpy as np 5 | import transforms3d 6 | 7 | 8 | def create_parser(): 9 | parser = argparse.ArgumentParser( 10 | description="Make Carla camera parameters from intrinsic matrices and " 11 | "transform matrices.") 12 | parser.add_argument( 13 | "-i", "--input-path", type=str, required=True, 14 | help="The path of input camera parameter file.") 15 | parser.add_argument( 16 | "-o", "--output-path", type=str, required=True, 17 | help="The path of output Carla parameter file.") 18 | return parser 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = create_parser() 23 | args = parser.parse_args() 24 | 25 | rear_ego_to_center_ego = [-1.5, 0, 0] 26 | lh_from_rh = rh_from_lh = np.diag([1, -1, 1, 1]) 27 | 28 | # z fontal camera (OpenCV style) is x-right, y-down, z-front. 29 | # x fontal camera (Y flipped Carla style) is x-front, y-left, z-up. 
30 | z_frontal_camera_from_x_frontal_camera = np.array([ 31 | [0, -1, 0, 0], 32 | [0, 0, -1, 0], 33 | [1, 0, 0, 0], 34 | [0, 0, 0, 1] 35 | ], np.float32) 36 | 37 | with open(args.input_path, "r", encoding="utf-8") as f: 38 | config = json.load(f) 39 | 40 | result = {} 41 | for k, v in config.items(): 42 | carla_transform = lh_from_rh @ np.array(v["transform"]) @ \ 43 | z_frontal_camera_from_x_frontal_camera @ rh_from_lh 44 | euler_rotation = transforms3d.euler.mat2euler( 45 | carla_transform[:3, :3], "szyx") 46 | result[k] = { 47 | "attributes": { 48 | "fov": str( 49 | math.degrees( 50 | math.atan(v["intrinsic"][0][2] / v["intrinsic"][0][0]) 51 | + math.atan( 52 | (v["image_size"][0] - v["intrinsic"][0][2]) / 53 | v["intrinsic"][0][0])) 54 | ), 55 | "role_name": k 56 | }, 57 | "spawn_transform": { 58 | "location": [ 59 | (carla_transform[i][3] + rear_ego_to_center_ego[i]).item() 60 | for i in range(3) 61 | ], 62 | "rotation": [ 63 | math.degrees(-euler_rotation[1]), 64 | math.degrees(euler_rotation[0]), 65 | math.degrees(-euler_rotation[2]) 66 | ] 67 | } 68 | } 69 | 70 | with open(args.output_path, "w", encoding="utf-8") as f: 71 | json.dump(result, f, indent=4) 72 | -------------------------------------------------------------------------------- /src/dwm/tools/fs_make_info_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import fsspec.implementations.local 3 | import json 4 | import os 5 | import tarfile 6 | import zipfile 7 | 8 | 9 | def create_parser(): 10 | parser = argparse.ArgumentParser( 11 | description="The script to make information JSON for TAR or ZIP files " 12 | "to accelerate file system initialization.") 13 | parser.add_argument( 14 | "-i", "--input-path", type=str, required=True, 15 | help="The path of the TAR or ZIP file.") 16 | parser.add_argument( 17 | "-o", "--output-path", type=str, required=True, 18 | help="The path to save the information JSON file on the local file " 19 | "system.") 20 | parser.add_argument( 21 | "-fs", "--fs-config-path", default=None, type=str, 22 | help="The path of file system JSON config to open the input file.") 23 | return parser 24 | 25 | 26 | def enum_tar_members(tarfile): 27 | tarfile._check() 28 | while True: 29 | tarinfo = tarfile.next() 30 | if tarinfo is None: 31 | break 32 | 33 | yield tarinfo 34 | 35 | 36 | def make_info_dict(ext, fobj, enable_tqdm: bool = True): 37 | if ext == ".zip": 38 | with zipfile.ZipFile(fobj) as zf: 39 | return { 40 | i.filename: [i.header_offset, i.file_size, i.is_dir()] 41 | for i in zf.infolist() 42 | } 43 | elif ext == ".tar": 44 | with tarfile.TarFile(fileobj=fobj) as tf: 45 | tar_member_generator = enum_tar_members(tf) 46 | if enable_tqdm: 47 | # Since TAR files do not have centralized index information, 48 | # the scanning process will take a longer time. 
49 | import tqdm 50 | tar_member_generator = tqdm.tqdm(tar_member_generator) 51 | 52 | return { 53 | i.name: [i.offset_data, i.size, not i.isfile()] 54 | for i in tar_member_generator 55 | } 56 | else: 57 | raise Exception("Unknown format {}.".format(ext)) 58 | 59 | 60 | if __name__ == "__main__": 61 | parser = create_parser() 62 | args = parser.parse_args() 63 | 64 | if args.fs_config_path is None: 65 | fs = fsspec.implementations.local.LocalFileSystem() 66 | else: 67 | import dwm.common 68 | with open(args.fs_config_path, "r", encoding="utf-8") as f: 69 | fs = dwm.common.create_instance_from_config(json.load(f)) 70 | 71 | with fs.open(args.input_path, "rb") as f: 72 | items = make_info_dict(os.path.splitext(args.input_path)[-1], f) 73 | 74 | with open(args.output_path, "w", encoding="utf-8") as f: 75 | json.dump(items, f) 76 | -------------------------------------------------------------------------------- /src/dwm/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | The datasets of this project are used to read multi-view video clips, point cloud sequence clips, as well as corresponding text descriptions for each frame, camera parameters, ego-vehicle pose transform, projected 3D box condition, and projected HD map condition. 4 | We have adapted the nuScenes, Waymo Percepton, and Argoverse 2 Sensor datasets. 5 | 6 | ## Terms 7 | 8 | **Sequence length** is the frame count of the temporal sequence, denoted as `t`. 9 | 10 | **View count** is the camera view count, denoted as `v`. 11 | 12 | **Sensor count** is the sum of camera view count and the LiDAR count. The LiDAR count can only be 0 or 1 in this version. 13 | 14 | **Image coordinate system** is x-right, y-down, with the origin at the top left corner of the image. 15 | 16 | **Camera coordinate system** is x-right, y-down, z-forward, with the origin at the camera's optical center. 17 | 18 | **Ego vehicle coordinate system** is x-forward, y-left, z-up, with the origin at the the midpoint of the rear vehicle axle. 19 | 20 | ## Data items 21 | 22 | ### Basic information 23 | 24 | `fps`. A float32 tensor in the shape `(1,)`. 25 | 26 | `pts`. A float32 tensor in the shape `(t, sensor_count)` 27 | 28 | `images`. A multi-level list of PIL images in the shape `(t, v)`, with the original image size. 29 | 30 | `lidar_points`. A list in the shape `(t,)` of float32 tensors in the shape `(point_count, 3)`. The point count in each frame of a sequence is not fixed. 31 | 32 | ### Transforms 33 | 34 | `camera_transforms`. A float32 tensor in the shape `(t, v, 4, 4)`, representing the 4x4 transformation matrix from the camera coordinate system to the ego vehicle coordinate system. 35 | 36 | `camera_intrinsics`. A float32 tensor in the shape `(t, v, 3, 3)`, representing the 3x3 transformation matrix from the camera coordinate system to the image coordinate system. 37 | 38 | `image_size`. A float32 tensor in the shape `(t, v, 2)` for the pixel width and height of each image. 39 | 40 | `lidar_transforms`. A float32 tensor in the shape `(t, 1, 4, 4)`, representing the 4x4 transformation matrix from the LiDAR coordinate system to the ego vehicle coordinate system. 41 | 42 | `ego_transforms`. A float32 tensor in the shape `(t, sensor_count, 4, 4)`, representing the 4x4 transformation matrix from the ego vehicle coordinate system to the world coordinate system. 43 | 44 | ### Conditions 45 | 46 | `3dbox_images`. 
A multi-level list of PIL images in the shape `(t, v)`, the same as the original image size, contains the projected 3D box of object annotations. 47 | 48 | `hdmap_images`. A multi-level list of PIL images in the shape `(t, v)`, the same as the original image size, contains the projected HD map of line annotations. 49 | 50 | `image_description`. A multi-level list of strings in the shape `(t, v)`, for the text descriptions per image. 51 | -------------------------------------------------------------------------------- /src/dwm/metrics/voxel_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | import torch.distributed 4 | 5 | 6 | class VoxelIoU(torchmetrics.Metric): 7 | def __init__( 8 | self, **kwargs 9 | ): 10 | super().__init__(**kwargs) 11 | self.iou_list = [] 12 | 13 | def update(self, 14 | gt_voxel: torch.tensor, 15 | pred_voxel: torch.tensor): 16 | if len(gt_voxel.shape) == 3: 17 | self.iou_list.append( 18 | (pred_voxel & gt_voxel).sum() / (pred_voxel | gt_voxel).sum()) 19 | else: 20 | for i, j in zip(gt_voxel, pred_voxel): 21 | self.iou_list.append(((i & j).sum() / (i | j).sum()).float()) 22 | 23 | def compute(self): 24 | iou_list = torch.stack(self.iou_list, dim=0) 25 | world_size = torch.distributed.get_world_size() \ 26 | if torch.distributed.is_initialized() else 1 27 | if world_size > 1: 28 | all_iou = iou_list.new_zeros( 29 | (len(iou_list)*world_size, ) + iou_list.shape[1:]) 30 | torch.distributed.all_gather_into_tensor( 31 | all_iou, iou_list) 32 | iou_list = all_iou 33 | num_samples = (~torch.isnan(iou_list) & ~torch.isinf(iou_list)).sum() 34 | iou_list = torch.nan_to_num(iou_list, nan=0.0, posinf=0.0, neginf=0.0) 35 | self.num_samples = num_samples 36 | return iou_list.sum() / num_samples 37 | 38 | def reset(self): 39 | self.iou_list.clear() 40 | super().reset() 41 | 42 | 43 | class VoxelDiff(torchmetrics.Metric): 44 | def __init__( 45 | self, **kwargs 46 | ): 47 | super().__init__(**kwargs) 48 | self.diff_list = [] 49 | 50 | def update(self, 51 | gt_voxel: torch.tensor, pred_voxel: torch.tensor): 52 | if len(gt_voxel.shape) == 3: 53 | self.diff_list.append(torch.logical_xor( 54 | pred_voxel, gt_voxel).sum().to(torch.float32)) 55 | else: 56 | for i, j in zip(gt_voxel, pred_voxel): 57 | self.diff_list.append(torch.logical_xor( 58 | i, j).sum().float()) 59 | 60 | def compute(self): 61 | diff_list = torch.stack(self.diff_list, dim=0) 62 | world_size = torch.distributed.get_world_size() \ 63 | if torch.distributed.is_initialized() else 1 64 | if world_size > 1: 65 | all_diff = diff_list.new_zeros( 66 | (len(diff_list)*world_size, ) + diff_list.shape[1:]) 67 | torch.distributed.all_gather_into_tensor( 68 | all_diff, diff_list) 69 | diff_list = all_diff 70 | self.num_samples = len(diff_list) 71 | return diff_list.mean() 72 | 73 | def reset(self): 74 | self.diff_list.clear() 75 | super().reset() 76 | -------------------------------------------------------------------------------- /src/dwm/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.distributed 3 | import torch.distributed.checkpoint.state_dict 4 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 5 | 6 | 7 | def distributed_save_optimizer_state(model, optimizer, folder, filename): 8 | if torch.distributed.is_initialized(): 9 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 10 | # DDP: rank 0 save 11 | if 
torch.distributed.get_rank() == 0: 12 | state_dict = optimizer.state_dict() 13 | torch.save( 14 | state_dict, 15 | os.path.join(folder, "{}.pth".format(filename))) 16 | 17 | elif isinstance(model, FSDP): 18 | # Zero2, Zero3: all ranks save 19 | # Hybird Zero2, Zero3: ranks in the 1st group save 20 | state_dict = torch.distributed.checkpoint.state_dict\ 21 | .get_optimizer_state_dict(model, optimizer) 22 | if model._device_mesh is None or ( 23 | torch.distributed.get_rank() in 24 | model._device_mesh.mesh[0].tolist() 25 | ): 26 | torch.distributed.checkpoint.state_dict_saver.save( 27 | state_dict, checkpoint_id=os.path.join(folder, filename), 28 | process_group=( 29 | None if model._device_mesh is None 30 | else model._device_mesh.get_group(mesh_dim=-1) 31 | )) 32 | 33 | else: 34 | raise Exception( 35 | "Unsupported distribution framework to save the optimizer " 36 | "state.") 37 | 38 | else: 39 | state_dict = optimizer.state_dict() 40 | torch.save(state_dict, os.path.join(folder, "{}.pth".format(filename))) 41 | 42 | 43 | def distributed_load_optimizer_state(model, optimizer, folder, filename): 44 | if torch.distributed.is_initialized(): 45 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 46 | # DDP: all ranks load the same 47 | state_dict = torch.load( 48 | os.path.join(folder, "{}.pth".format(filename)), 49 | map_location="cpu", weights_only=True) 50 | optimizer.load_state_dict(state_dict) 51 | elif isinstance(model, FSDP): 52 | # Zero2, Zero3: all ranks load 53 | # Hybird Zero2, Zero3: ranks in the 1st group load 54 | state_dict = torch.distributed.checkpoint.state_dict\ 55 | .get_optimizer_state_dict(model, optimizer) 56 | torch.distributed.checkpoint.state_dict_loader.load( 57 | state_dict, checkpoint_id=os.path.join(folder, filename), 58 | planner=torch.distributed.checkpoint.DefaultLoadPlanner( 59 | allow_partial_load=True)) 60 | 61 | else: 62 | state_dict = torch.load( 63 | os.path.join(folder, "{}.pth".format(filename)), 64 | map_location="cpu", weights_only=True) 65 | optimizer.load_state_dict(state_dict) 66 | -------------------------------------------------------------------------------- /configs/README.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | 3 | The configuration files are in the JSON format. They include settings for the models, datasets, pipelines, or any arguments for the program. 4 | 5 | ## Introduction 6 | 7 | In our code, we mainly use JSON objects in three ways: 8 | 9 | 1. As a dictionary 10 | 2. As a function's parameter list 11 | 3. As a constructor and parameter for objects 12 | 13 | ### As a dictionary 14 | 15 | The most common way for the config, for example: 16 | 17 | ```JSON 18 | { 19 | "guidance_scale": 4, 20 | "inference_steps": 40, 21 | "preview_image_size": [ 22 | 448, 23 | 252 24 | ] 25 | } 26 | ``` 27 | 28 | The pipeline finds the corresponding value variable in the dictionary through the key, which determines the behavior at runtime. 
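For example, a minimal sketch of how a pipeline might consume such a dictionary (the config file name here is only for illustration):

```Python
import json

# Load the deserialized JSON object as a plain dictionary.
with open("configs/example.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# Behavior is driven by key lookups at runtime.
guidance_scale = config.get("guidance_scale", 1.0)
inference_steps = config.get("inference_steps", 50)
preview_width, preview_height = config["preview_image_size"]
```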
29 | 30 | ### As a function's parameter list 31 | 32 | The content of a JSON object is passed into a function, for example: 33 | 34 | ```JSON 35 | { 36 | "num_workers": 3, 37 | "prefetch_factor": 3, 38 | "persistent_workers": true 39 | } 40 | ``` 41 | 42 | The PyTorch data loader will accept all the arguments by 43 | 44 | ```Python 45 | data_loader = torch.utils.data.DataLoader( 46 | dataset, **deserialized_json_object) 47 | ``` 48 | 49 | In this case, you can fill in the required parameters according to the reference documentation of the function (such as the [data loader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) here). 50 | 51 | ### As a constructor and parameter for objects 52 | 53 | The JSON object declares the name of the object to be created, as well as the parameters, for example: 54 | 55 | ```JSON 56 | { 57 | "_class_name": "torch.optim.AdamW", 58 | "lr": 6e-5, 59 | "betas": [ 60 | 0.9, 61 | 0.975 62 | ] 63 | } 64 | ``` 65 | 66 | The "_class_name" is in the format of `{name_space}.{class_or_function_name}`, and other key-value pairs are used as parameters for the class constructor (e.g. [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW) here) or the function. 67 | 68 | In the code, this type of object is parsed with `dwm.common.create_instance_from_config()` function. 69 | 70 | With this design, the configuration, framework, and components are **loosely coupled**. For example, user can easily switch to a third-party optimizer "bitsandbytes.optim.Adam8bit" without editing the code. Developers can provide any component class (e.g. dataset, data transforms) without having to register to a specific framework. 71 | 72 | ## Development 73 | 74 | ### Name convention 75 | 76 | The configs in this folder are mainly about the pipelines and consumed by the `src/dwm/train.py`. So they are named in the format of `{pipeline_name}_{model_config}_{condition_config}_{data_config}.json`. 77 | 78 | * Pipeline name: the python script name in the `src/dwm/pipelines`. 79 | * Model config: the most discriminative model arguments, such as `spatial`, `crossview`, `temporal` for the SD models. 80 | * Condition config: the additional input for the model, such as `ts` for the "text description per scene", `ti` for the "text description per image", `b` for the box condition, `m` for the map condition. 81 | * Data config: `mini` for the debug purpose. Combination of `nuscenes`, `argoverse`, `waymo`, `opendv` (or their initial letters), for the data components. 82 | -------------------------------------------------------------------------------- /docs/CtsdPipelineFaqs.md: -------------------------------------------------------------------------------- 1 | # CTSD pipeline FAQs 2 | 3 | The CTSD is short for cross-view temporal stable diffusion. This pipeline is an extension of the pre-trained Stable Diffusion model for autonomous driving multi-view tasks and video sequence tasks. 4 | 5 | Here are some frequently asked questions. 6 | 7 | * [Single GPU training](#single-gpu-training) 8 | * [Remove conditions](#remove-conditions) 9 | 10 | ## Single GPU training 11 | 12 | The default training configurations use the [Hybrid FSDP](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html#how-to-use-devicemesh-with-hsdp) to reduce the GPU memory usage for distributed training, which does not reduce memory usage on single GPU systems. 
13 | 14 | To reduce memory on single GPU system, we should edit the config to enable quantization and CUDA AMP. 15 | 16 | 1. Set `"quantization_config"` in the `"text_encoder_load_args"` for those models implementing the [HfQuantizer](https://huggingface.co/docs/transformers/v4.48.2/en/main_classes/quantization#transformers.quantizers.HfQuantizer) (make sure the [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/main/en/index) package is installed for the following config). 17 | 18 | ```JSON 19 | { 20 | "pipeline": { 21 | "common_config": { 22 | "text_encoder_load_args": { 23 | "variant": "fp16", 24 | "torch_dtype": { 25 | "_class_name": "get_class", 26 | "class_name": "torch.float16" 27 | }, 28 | "quantization_config": { 29 | "_class_name": "diffusers.quantizers.quantization_config.BitsAndBytesConfig", 30 | "load_in_4bit": true, 31 | "bnb_4bit_quant_type": "nf4", 32 | "bnb_4bit_compute_dtype": { 33 | "_class_name": "get_class", 34 | "class_name": "torch.float16" 35 | } 36 | } 37 | } 38 | } 39 | } 40 | } 41 | ``` 42 | 43 | 2. Switch to quantized optimizer. 44 | 45 | ```JSON 46 | { 47 | "optimizer": { 48 | "_class_name": "bitsandbytes.optim.Adam8bit", 49 | "lr": 5e-5 50 | } 51 | } 52 | ``` 53 | 54 | 3. Enable the CUDA AMP instead of FSDP mixed precision. 55 | 56 | ```JSON 57 | { 58 | "pipeline": { 59 | "common_config": { 60 | "autocast": { 61 | "device_type": "cuda" 62 | } 63 | } 64 | } 65 | } 66 | ``` 67 | 68 | 4. Remove all `"device_mesh"` related config items. 69 | 70 | ## Remove conditions 71 | 72 | Remove layout conditions (3D boxes, HD map): 73 | 74 | * Model: in the case of only removing one condition, adjust the input channel number of `"pipeline.model.condition_image_adapter_config.in_channels"` to 3; If both 3dbox and hdmap are removed, the `"pipeline.model.condition_image_adapter_config"` section should be completely deleted. 75 | 76 | * Dataset: remove `"_3dbox_image_settings"`, `"hdmap_image_settings"` 77 | 78 | * Dataset adapter: remove `torchvision.transforms` items of `"3dbox_images"`, `"hdmap_images"`. 79 | 80 | Dataset related settings, note that modifications to the training set and validation set must be consistent. 81 | 82 | Remove components of text prompt: 83 | 84 | * Refer to the [text condition processing function](../src/dwm/datasets/common.py#L316), in the dataset settings `"image_description_settings"`, add `"selected_keys": ["time", "weather", "environment"]` to exclude other text fields. 
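For example, a minimal sketch of the dataset setting (only `"selected_keys"` is the addition; keep your other existing `"image_description_settings"` entries as they are):

```JSON
{
    "image_description_settings": {
        "selected_keys": [
            "time",
            "weather",
            "environment"
        ]
    }
}
```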
85 | -------------------------------------------------------------------------------- /src/dwm/utils/make_blank_code.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.common 3 | import dwm.pipelines.lidar_vqvae 4 | import json 5 | import pickle 6 | import torch 7 | import torch.utils.data 8 | 9 | 10 | def create_parser(): 11 | parser = argparse.ArgumentParser( 12 | description="Make blank code file of the LiDAR codebook checkpoint.") 13 | parser.add_argument( 14 | "-c", "--config-path", type=str, required=True, 15 | help="The path of training config file.") 16 | parser.add_argument( 17 | "-i", "--input-path", type=str, required=True, 18 | help="The path of input checkpoint file.") 19 | parser.add_argument( 20 | "-o", "--output-path", type=str, required=True, 21 | help="The path of output blank code file.") 22 | parser.add_argument( 23 | "-it", "--iteration", default=100, type=int, 24 | help="The iteration count from the validation set for blank code.") 25 | parser.add_argument( 26 | "-s", "--sample-count", default=16, type=int, 27 | help="The count of blank code to sample.") 28 | return parser 29 | 30 | 31 | def count_code(indices): 32 | unique_elements, counts = torch.unique(indices, return_counts=True) 33 | sorted_indices = torch.argsort(counts, descending=True) 34 | sorted_elements = unique_elements[sorted_indices] 35 | sorted_counts = counts[sorted_indices] 36 | return sorted_elements, sorted_counts 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = create_parser() 41 | args = parser.parse_args() 42 | 43 | with open(args.config_path, "r", encoding="utf-8") as f: 44 | config = json.load(f) 45 | 46 | device = torch.device(config.get("device", "cpu")) 47 | dataset = dwm.common.create_instance_from_config( 48 | config["validation_dataset"]) 49 | 50 | vq_point_cloud = dwm.common.create_instance_from_config( 51 | config["pipeline"]["vq_point_cloud"]) 52 | vq_point_cloud.to(device) 53 | vq_point_cloud.load_state_dict( 54 | dwm.pipelines.lidar_vqvae.LidarCodebook.load_state(args.input_path), 55 | strict=False) 56 | 57 | dataloader = torch.utils.data.DataLoader( 58 | dataset, shuffle=True, 59 | **dwm.common.instantiate_config(config["validation_dataloader"])) 60 | 61 | iteration = 0 62 | code_dict = {} 63 | vq_point_cloud.eval() 64 | for batch in dataloader: 65 | with torch.no_grad(): 66 | points = dwm.pipelines.lidar_vqvae.LidarCodebook.get_points( 67 | batch, config["pipeline"]["common_config"], device) 68 | 69 | voxels = vq_point_cloud.voxelizer(points) 70 | lidar_feats = vq_point_cloud.lidar_encoder(voxels) 71 | _, _, code_indices = vq_point_cloud.vector_quantizer( 72 | lidar_feats, vq_point_cloud.code_age, 73 | vq_point_cloud.code_usage) 74 | 75 | codes, counts = count_code(code_indices) 76 | for code, count in zip(codes.tolist(), counts.tolist()): 77 | if code in code_dict: 78 | code_dict[code] += count 79 | else: 80 | code_dict[code] = count 81 | 82 | iteration += 1 83 | if iteration >= args.iteration: 84 | break 85 | 86 | blank_code = [ 87 | i[0] for i in sorted( 88 | code_dict.items(), 89 | key=lambda i: i[1], reverse=True)[:args.sample_count] 90 | ] 91 | with open(args.output_path, "wb") as f: 92 | pickle.dump(blank_code, f) 93 | -------------------------------------------------------------------------------- /src/dwm/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.common 3 | import json 4 | import os 5 | import torch 6 | 7 | 8 | 
def create_parser(): 9 | parser = argparse.ArgumentParser( 10 | description="The script to finetune a stable diffusion model to the " 11 | "driving dataset.") 12 | parser.add_argument( 13 | "-c", "--config-path", type=str, required=True, 14 | help="The config to load the train model and dataset.") 15 | parser.add_argument( 16 | "-o", "--output-path", type=str, required=True, 17 | help="The path to save checkpoint files."), 18 | parser.add_argument( 19 | "--resume-from", default=None, type=int, 20 | help="The step to resume from") 21 | return parser 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = create_parser() 26 | args = parser.parse_args() 27 | 28 | with open(args.config_path, "r", encoding="utf-8") as f: 29 | config = json.load(f) 30 | 31 | # set distributed training (if enabled), log, random number generator, and 32 | # load the checkpoint (if required). 33 | ddp = "LOCAL_RANK" in os.environ 34 | if ddp: 35 | local_rank = int(os.environ["LOCAL_RANK"]) 36 | device = torch.device(config["device"], local_rank) 37 | if config["device"] == "cuda": 38 | torch.cuda.set_device(local_rank) 39 | 40 | torch.distributed.init_process_group(backend=config["ddp_backend"]) 41 | else: 42 | device = torch.device(config["device"]) 43 | 44 | # setup the global state 45 | if "global_state" in config: 46 | for key, value in config["global_state"].items(): 47 | dwm.common.global_state[key] = \ 48 | dwm.common.create_instance_from_config(value) 49 | 50 | should_log = (ddp and local_rank == 0) or not ddp 51 | should_save = not torch.distributed.is_initialized() or \ 52 | torch.distributed.get_rank() == 0 53 | 54 | # load the pipeline including the models 55 | pipeline = dwm.common.create_instance_from_config( 56 | config["pipeline"], output_path=args.output_path, config=config, 57 | device=device, resume_from=args.resume_from) 58 | if should_log: 59 | print("The pipeline is loaded.") 60 | 61 | validation_dataset = dwm.common.create_instance_from_config( 62 | config["validation_dataset"]) 63 | if ddp: 64 | # make equal sample count for each process to simplify the result 65 | # gathering 66 | total_batch_size = int(os.environ["WORLD_SIZE"]) * \ 67 | config["validation_dataloader"]["batch_size"] 68 | dataset_length = len(validation_dataset) // \ 69 | total_batch_size * total_batch_size 70 | validation_dataset = torch.utils.data.Subset( 71 | validation_dataset, range(0, dataset_length)) 72 | validation_datasampler = \ 73 | torch.utils.data.distributed.DistributedSampler( 74 | validation_dataset) 75 | validation_dataloader = torch.utils.data.DataLoader( 76 | validation_dataset, 77 | **dwm.common.instantiate_config(config["validation_dataloader"]), 78 | sampler=validation_datasampler) 79 | else: 80 | validation_datasampler = None 81 | validation_dataloader = torch.utils.data.DataLoader( 82 | validation_dataset, 83 | **dwm.common.instantiate_config(config["validation_dataloader"])) 84 | 85 | if should_log: 86 | print("The validation dataset is loaded with {} items.".format( 87 | len(validation_dataset))) 88 | 89 | pipeline.evaluate_pipeline( 90 | 0, len(validation_dataset), validation_dataloader, 91 | validation_datasampler) 92 | 93 | if torch.distributed.is_initialized(): 94 | torch.distributed.destroy_process_group() 95 | -------------------------------------------------------------------------------- /docs/InteractiveGeneration.md: -------------------------------------------------------------------------------- 1 | # Interactive Generation 2 | 3 | 4x accelerated playing speed: 4 | 5 | 
https://github.com/user-attachments/assets/933b84d3-496a-41bd-b6ab-3022a0137062 6 | 7 | We've implemented the interactive generation with Carla (0.9.15). The main components: 8 | 9 | * Server-side [Carla](https://carla.org/) maintains the simulation state. 10 | * Server-side [simulation script](../src/dwm/utils/carla_simulation.py) configs the environment, ego car, sensors, and traffic manager. 11 | * Server-side [streaming generation](../src/dwm/streaming.py) reads condition data from the Carla, and write generated frames to video streaming server. 12 | * Server-side [video streaming server](https://github.com/bluenviron/mediamtx) to publish the video streaming for client video player. 13 | * Client-side [Carla control](../src/dwm/utils/carla_control.py) to control the ego car in the simulation world with kayboard. 14 | * Client-side [video player](https://ffmpeg.org/) to receive the generated result. 15 | 16 | The dataflow is: 17 | 18 | 1. Carla control 19 | 2. Carla (configured by the simulation script) 20 | 3. Streaming generation 21 | 4. Video streaming server 22 | 5. Video player 23 | 24 | ## Requirement 25 | 26 | The server requires: 27 | 28 | 1. GPU (nVidia A100 is recommended) 29 | 2. network accessibility. 30 | 3. Python in 3.9 or 3.10 31 | 4. Carla == 0.9.15 32 | 5. mediamtx 33 | 34 | The client requires: 35 | 1. Windows or Ubuntu (The supported platforms for the Carla Python API). 36 | 2. Python in 3.9 or 3.10 37 | 3. ffmpeg 38 | 39 | ## Models 40 | 41 | The interactive generative model is trained from scratch on autonomous driving data after the specification reduction (model size, view count, resolution) of CTSD 3.5, in order to reduce the overhead of model inference. 42 | 43 | | Base Model | Temporal Training Style | Prediction Style | Configs | Checkpoint Download | 44 | | :-: | :-: | :-: | :-: | :-: | 45 | | [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Diffusion forcing transformer](https://arxiv.org/abs/2502.06764) | [FIFO diffusion](https://arxiv.org/abs/2405.11473) | [Config](../configs/experimental/multi_datasets/ctsd_35_xs_df6v3_tirda_bm_nwao.json) | [Checkpoint](https://huggingface.co/wzhgba/opendwm-models/resolve/main/ctsd_35_xs_df6v3_tirda_bm_nwao_60k.pth?download=true) | 46 | 47 | ## Inference 48 | 49 | ### Server-side Setup 50 | 51 | 1. Download the base model (for VAE and text encoders) and model checkpoint, then edit the [config](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L168). 52 | 53 | 2. Launch the video streaming server following the [official guide](https://github.com/bluenviron/mediamtx?tab=readme-ov-file#installation). 54 | 55 | 3. Launch the Carla: `{CARLA_ROOT}/CarlaUE4.sh -RenderOffScreen -quality-level=Low` 56 | 57 | 4. Configure the Carla by editing the [config template](../configs/experimental/simulation/carla_simulation_town10_nusc_3views.json) and run: `PYTHONPATH=src python src/dwm/utils/carla_simulation.py -c configs/experimental/simulation/carla_simulation_town10_nusc_3views.json --client-timeout 150` 58 | 59 | 5. Edit the generation config template (e.g. 
[Carla endpoint](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L7), [video streaming options](../configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json#L268)) and run: `PYTHONPATH=src python src/dwm/streaming.py -c configs/experimental/streaming/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming.json -l output/ctsd_35_xs_df6v3_tirda_bm_nwao_streaming -s rtsp://{VIDEO_STREAMING_ENDPOINT}/live --fps 2` 60 | 61 | ### Client-side Setup 62 | 63 | 1. Launch the video player after the server-side streaming begin: `ffplay -fflags nobuffer -rtsp_transport tcp rtsp://{VIDEO_STREAMING_ENDPOINT}/live` 64 | 65 | 2. Launch the Carla control after the server-side streaming begin: `python src\dwm\utils\carla_control.py --host {CARLA_SERVER_ADDRESS} -p {CARLA_SERVER_PORT}` 66 | 67 | ## Known issues 68 | 69 | 1. Generation speed. 70 | 2. Latency due to the denoising queue. 71 | -------------------------------------------------------------------------------- /src/dwm/fs/README.md: -------------------------------------------------------------------------------- 1 | # File System 2 | 3 | This project uses the [fsspec](https://github.com/fsspec/filesystem_spec) to access the data from varied sources including the local file system, S3, ZIP content. 4 | 5 | 6 | ## Common usage 7 | 8 | You can open, read, seek the file with the file system objects. Check out the [usage](https://filesystem-spec.readthedocs.io/en/latest/usage.html) document to quick start. 9 | 10 | ### Create the file system 11 | 12 | ``` Python 13 | import fsspec.implementations.local 14 | fs = fsspec.implementations.local.LocalFileSystem() 15 | ``` 16 | 17 | ### Open and read the file 18 | 19 | ``` Python 20 | import json 21 | config_path = "configs/fs/local.json" 22 | with fs.open(config_path, "r", encoding="utf-8") as f: 23 | config = json.load(f) 24 | 25 | from PIL import Image 26 | image_path = "samples/CAM_FRONT/n008-2018-08-01-15-16-36-0400__CAM_FRONT__1533151603512404.jpg" 27 | with fs.open(image_path, "rb") as f: 28 | image = Image.open(f) 29 | ``` 30 | 31 | 32 | ## API reference 33 | 34 | ### dwm.fs.ctar.CombinedTarFileSystem 35 | 36 | This file system opens several TAR blobs from a given file system and provide the file access inside the TAR blobs. Please note that the required TAR file here refers only to the uncompressed format, as the TAR.GZ format does not support random access to one of the files within the blob. This file system is forkable and compatible with multi worker data loader of the PyTorch. 37 | 38 | ### dwm.fs.czip.CombinedZipFileSystem 39 | 40 | This file system opens several ZIP blobs from a given file system and provide the file access inside the ZIP blobs. It is forkable and compatible with multi worker data loader of the PyTorch. 41 | 42 | ### dwm.fs.s3fs.ForkableS3FileSystem 43 | 44 | This file system opens the S3 service and provide the file access on the service. It is forkable and compatible with multi worker data loader of the PyTorch. 45 | 46 | 47 | ## Configuration samples 48 | It is easy to initialize the file system object by `dwm.common.create_instance_from_config()` with following configurations by JSON. 
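For example, a minimal sketch that loads one of the configs below and opens a file through the resulting file system (the config and data paths are only illustrative):

``` Python
import dwm.common
import json

# Create the file system from a JSON config such as
# "configs/fs/local_nuscenes.json" below.
with open("configs/fs/local_nuscenes.json", "r", encoding="utf-8") as f:
    fs = dwm.common.create_instance_from_config(json.load(f))

# Read a file relative to the configured root.
with fs.open("v1.0-trainval/scene.json", "r", encoding="utf-8") as f:
    scenes = json.load(f)
```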
49 | 50 | ### Local file system 51 | 52 | ``` JSON 53 | { 54 | "_class_name": "fsspec.implementations.local.LocalFileSystem" 55 | } 56 | ``` 57 | 58 | **Relative directory on local file system** 59 | ``` JSON 60 | { 61 | "_class_name": "fsspec.implementations.dirfs.DirFileSystem", 62 | "path": "/mnt/storage/user/wuzehuan/Downloads/data/nuscenes", 63 | "fs": { 64 | "_class_name": "fsspec.implementations.local.LocalFileSystem" 65 | } 66 | } 67 | ``` 68 | 69 | ### S3 file system 70 | 71 | The parameters follow the [Botocore confiruation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html). 72 | 73 | ``` JSON 74 | { 75 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 76 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn", 77 | "aws_access_key_id": "YOUR_ACCESS_KEY", 78 | "aws_secret_access_key": "YOUR_SECRET_KEY" 79 | } 80 | ``` 81 | 82 | **Relative directory on S3 file system** 83 | 84 | ``` JSON 85 | { 86 | "_class_name": "fsspec.implementations.dirfs.DirFileSystem", 87 | "path": "users/wuzehuan/data/nuscenes", 88 | "fs": { 89 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 90 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn", 91 | "aws_access_key_id": "YOUR_ACCESS_KEY", 92 | "aws_secret_access_key": "YOUR_SECRET_KEY" 93 | } 94 | } 95 | ``` 96 | 97 | **Retry options on S3 file system** 98 | 99 | ``` JSON 100 | { 101 | "_class_name": "dwm.fs.s3fs.ForkableS3FileSystem", 102 | "endpoint_url": "http://aoss-internal-v2.st-sh-01.sensecoreapi-oss.cn", 103 | "aws_access_key_id": "YOUR_ACCESS_KEY", 104 | "aws_secret_access_key": "YOUR_SECRET_KEY", 105 | "config": { 106 | "_class_name": "botocore.config.Config", 107 | "retries": { 108 | "max_attempts": 8 109 | } 110 | } 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /src/dwm/models/voxelizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Voxelizer(torch.nn.Module): 5 | """Voxelizer for converting Lidar point cloud to image""" 6 | 7 | def __init__(self, x_min, x_max, y_min, y_max, step, z_min, z_max, z_step): 8 | super().__init__() 9 | 10 | self.x_min = x_min 11 | self.x_max = x_max 12 | self.y_min = y_min 13 | self.y_max = y_max 14 | self.step = step 15 | self.z_min = z_min 16 | self.z_max = z_max 17 | self.z_step = z_step 18 | 19 | self.width = round((self.x_max - self.x_min) / self.step) 20 | self.height = round((self.y_max - self.y_min) / self.step) 21 | self.z_depth = round((self.z_max - self.z_min) / self.z_step) 22 | self.depth = self.z_depth 23 | 24 | def voxelize_single(self, lidar, bev): 25 | """Voxelize a single lidar sweep into image frame 26 | Image frame: 27 | 1. Increasing depth indices corresponds to increasing real world z 28 | values. 29 | 2. Increasing height indices corresponds to decreasing real world y 30 | values. 31 | 3. Increasing width indices corresponds to increasing real world x 32 | values. 33 | Args: 34 | lidar (torch.Tensor N x 4 or N x 5) x, y, z, intensity, height_to_ground (optional) 35 | bev (torch.Tensor D x H x W) D = depth, the bird's eye view 36 | raster to populate 37 | """ 38 | 39 | # 1 & 2. Convert points to tensor index location. Clamp z indices to 40 | # valid range. 
41 | indices_h = torch.floor((lidar[:, 1] - self.y_min) / self.step).long() 42 | indices_w = torch.floor((lidar[:, 0] - self.x_min) / self.step).long() 43 | indices_d = torch.floor( 44 | (lidar[:, 2] - self.z_min) / self.z_step).long() 45 | 46 | # 3. Remove points out of bound 47 | valid_mask = ~torch.any( 48 | torch.stack( 49 | [ 50 | indices_h < 0, 51 | indices_h >= self.height, 52 | indices_w < 0, 53 | indices_w >= self.width, 54 | indices_d < 0, 55 | indices_d >= self.z_depth, 56 | ] 57 | ), 58 | dim=0, 59 | ) 60 | indices_h = indices_h[valid_mask] 61 | indices_w = indices_w[valid_mask] 62 | indices_d = indices_d[valid_mask] 63 | # 4. Assign indices to 1 64 | bev[indices_d, indices_h, indices_w] = 1.0 65 | 66 | def forward(self, lidars): 67 | """Voxelize multiple sweeps in the current vehicle frame into voxels 68 | in image frame 69 | Args: 70 | list(list(tensor)): B * T * tensor[N x 4], 71 | where B = batch_size, T = 5, N is variable, 72 | 4 = [x, y, z, intensity] 73 | Returns: 74 | tensor: [B x D x H x W], B = batch_size, D = T * depth, H = height, 75 | W = width 76 | """ 77 | batch_size = len(lidars) 78 | assert batch_size > 0 and len(lidars[0]) > 0 79 | num_sweep = len(lidars[0]) 80 | 81 | bev = torch.zeros( 82 | (batch_size, num_sweep, self.depth, self.height, self.width), 83 | dtype=torch.float, 84 | device=lidars[0][0][0].device, 85 | ) 86 | for b in range(batch_size): 87 | assert len(lidars[b]) == num_sweep 88 | for i in range(num_sweep): 89 | self.voxelize_single(lidars[b][i], bev[b][i]) 90 | 91 | return bev 92 | 93 | def get_voxel_coordinates(self, downsample_scale = 1): 94 | x_coord = torch.arange(self.x_min, self.x_max, ((self.x_max - self.x_min) / self.width) / downsample_scale) + self.step / 2 95 | y_coord = torch.arange(self.y_min, self.y_max, ((self.y_max - self.y_min) / self.height) / downsample_scale) + self.step / 2 96 | z_coord = torch.arange(self.z_min, self.z_max, ((self.z_max - self.z_min) / self.z_depth) / downsample_scale) + self.z_step / 2 97 | x_grid, y_grid, z_grid = torch.meshgrid(z_coord, y_coord, x_coord) 98 | 99 | return torch.stack([x_grid, y_grid, z_grid], dim = -1) 100 | -------------------------------------------------------------------------------- /examples/ctsd_generation_example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.distributed 3 | import dwm.common 4 | import dwm.utils.preview 5 | import json 6 | import os 7 | import torch 8 | import torchvision 9 | 10 | 11 | def create_parser(): 12 | parser = argparse.ArgumentParser( 13 | description="The script to finetune a stable diffusion model to the " 14 | "driving dataset.") 15 | parser.add_argument( 16 | "-c", "--config-path", type=str, required=True, 17 | help="The config to load the train model and dataset.") 18 | parser.add_argument( 19 | "-o", "--output-path", type=str, required=True, 20 | help="The path to save checkpoint files.") 21 | return parser 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = create_parser() 26 | args = parser.parse_args() 27 | 28 | with open(args.config_path, "r", encoding="utf-8") as f: 29 | config = json.load(f) 30 | 31 | ddp = "LOCAL_RANK" in os.environ 32 | if ddp: 33 | local_rank = int(os.environ["LOCAL_RANK"]) 34 | device = torch.device(config["device"], local_rank) 35 | if config["device"] == "cuda": 36 | torch.cuda.set_device(local_rank) 37 | 38 | torch.distributed.init_process_group(backend=config["ddp_backend"]) 39 | else: 40 | device = torch.device(config["device"]) 41 | 42 | 
pipeline = dwm.common.create_instance_from_config( 43 | config["pipeline"], output_path=None, config=config, device=device) 44 | print("The pipeline is loaded.") 45 | 46 | os.makedirs(args.output_path, exist_ok=True) 47 | for i_id, i in enumerate(config["inputs"]): 48 | i["batch"] = { 49 | k: torch.tensor(v) if k != "clip_text" else v 50 | for k, v in i["batch"].items() 51 | } 52 | with torch.no_grad(): 53 | if ( 54 | "sequence_length_per_iteration" in 55 | config["pipeline"]["inference_config"] 56 | ): 57 | latent_shape = tuple( 58 | i["latent_shape"][:1] + [ 59 | config["pipeline"]["inference_config"] 60 | ["sequence_length_per_iteration"], 61 | ] + i["latent_shape"][2:] 62 | ) 63 | pipeline_output = pipeline.autoregressive_inference_pipeline( 64 | **{ 65 | k: latent_shape if k == "latent_shape" else v 66 | for k, v in i.items() 67 | }) 68 | else: 69 | pipeline_output = pipeline.inference_pipeline(**i) 70 | 71 | output_images = pipeline_output["images"] 72 | collected_images = [ 73 | output_images.cpu().unflatten(0, i["latent_shape"][:3]) 74 | ] 75 | 76 | stacked_images = torch.stack(collected_images) 77 | resized_images = torch.nn.functional.interpolate( 78 | stacked_images.flatten(0, 3), 79 | tuple(pipeline.inference_config["preview_image_size"][::-1]) 80 | ) 81 | resized_images = resized_images.view( 82 | *stacked_images.shape[:4], -1, *resized_images.shape[-2:]) 83 | 84 | if not ddp or torch.distributed.get_rank() == 0: 85 | if i["latent_shape"][1] == 1: 86 | # [C, B * T * S * H, V * W] 87 | preview_tensor = resized_images.permute(4, 1, 2, 0, 5, 3, 6)\ 88 | .flatten(-2).flatten(1, 4) 89 | image_output_path = os.path.join( 90 | args.output_path, "{}.png".format(i_id)) 91 | torchvision.transforms.functional.to_pil_image(preview_tensor)\ 92 | .save(image_output_path) 93 | else: 94 | # [T, C, B * S * H, V * W] 95 | preview_tensor = resized_images.permute(2, 4, 1, 0, 5, 3, 6)\ 96 | .flatten(-2).flatten(2, 4) 97 | video_output_path = os.path.join( 98 | args.output_path, "{}.mp4".format(i_id)) 99 | dwm.utils.preview.save_tensor_to_video( 100 | video_output_path, "libx264", i["batch"]["fps"][0].item(), 101 | preview_tensor) 102 | 103 | print("{} done".format(i_id)) 104 | -------------------------------------------------------------------------------- /src/dwm/fs/ctar.py: -------------------------------------------------------------------------------- 1 | import dwm.common 2 | import fsspec.archive 3 | import json 4 | import os 5 | import re 6 | import tarfile 7 | 8 | 9 | class CombinedTarFileSystem(fsspec.archive.AbstractArchiveFileSystem): 10 | 11 | root_marker = "" 12 | protocol = "ctar" 13 | cachable = False 14 | 15 | def __init__( 16 | self, paths: list, fs=None, 17 | enable_cached_info: bool = False, **kwargs 18 | ): 19 | super().__init__(**kwargs) 20 | 21 | self.fs = fsspec.implementations.local.LocalFileSystem() \ 22 | if fs is None else fs 23 | self.paths = paths 24 | 25 | _belongs_to = {} 26 | self._info = {} 27 | pattern = re.compile(r"\.tar$") 28 | for path in paths: 29 | cached_info_path = re.sub(pattern, ".info.json", path) 30 | if enable_cached_info: 31 | with fs.open(cached_info_path, "r", encoding="utf-8") as f: 32 | infodict = json.load(f) 33 | 34 | for filename, i in infodict.items(): 35 | _belongs_to[filename] = path 36 | 37 | else: 38 | infodict = {} 39 | with fs.open(path) as f: 40 | with tarfile.TarFile(fileobj=f) as tf: 41 | for i in tf.getmembers(): 42 | _belongs_to[i.name] = path 43 | infodict[i.name] = [ 44 | i.offset_data, i.size, not i.isfile() 45 | ] 46 | 47 | 
self._info[path] = dwm.common.SerializedReadonlyDict(infodict) 48 | 49 | self._belongs_to = dwm.common.SerializedReadonlyDict(_belongs_to) 50 | 51 | self.dir_cache = None 52 | self.fp_cache = {} 53 | 54 | def close(self): 55 | # the file points are not forkable 56 | current_pid = os.getpid() 57 | if self._pid != current_pid: 58 | self.fp_cache.clear() 59 | self._pid = current_pid 60 | 61 | for fp in self.fp_cache.values(): 62 | fp.close() 63 | 64 | self.fp_cache.clear() 65 | 66 | def _get_dirs(self): 67 | if self.dir_cache is None: 68 | self.dir_cache = { 69 | dirname.rstrip("/"): { 70 | "name": dirname.rstrip("/"), 71 | "size": 0, 72 | "type": "directory", 73 | } 74 | for dirname in self._all_dirnames(self._belongs_to.keys()) 75 | } 76 | 77 | for file_name in self._belongs_to.keys(): 78 | blob_path = self._belongs_to[file_name] 79 | _, file_size, is_dir = self._info[blob_path][file_name] 80 | f = { 81 | "name": file_name.rstrip("/"), 82 | "size": file_size, 83 | "type": "directory" if is_dir else "file", 84 | } 85 | self.dir_cache[f["name"]] = f 86 | 87 | def exists(self, path, **kwargs): 88 | return path in self._belongs_to 89 | 90 | def _open( 91 | self, path, mode="rb", block_size=None, autocommit=True, 92 | cache_options=None, **kwargs 93 | ): 94 | if "w" in mode: 95 | raise OSError("Combined TAR file system is read only.") 96 | 97 | if "b" not in mode: 98 | raise OSError( 99 | "Combined TAR file system only support the mode of binary.") 100 | 101 | if not self.exists(path): 102 | raise FileNotFoundError(path) 103 | 104 | blob_path = self._belongs_to[path] 105 | offset, size, is_dir = self._info[blob_path][path] 106 | if is_dir: 107 | raise FileNotFoundError(path) 108 | 109 | # the file points are not forkable 110 | current_pid = os.getpid() 111 | if self._pid != current_pid: 112 | self.fp_cache.clear() 113 | self._pid = current_pid 114 | 115 | if blob_path in self.fp_cache: 116 | f = self.fp_cache[blob_path] 117 | else: 118 | f = self.fp_cache[blob_path] = \ 119 | self.fs.open(blob_path, mode, block_size, cache_options) 120 | 121 | return dwm.common.PartialReadableRawIO(f, offset, offset + size) 122 | -------------------------------------------------------------------------------- /src/dwm/metrics/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed 3 | import torchmetrics 4 | 5 | # PYTHONPATH includes ${workspaceFolder}/externals/TATS/tats/fvd 6 | import pytorch_i3d 7 | 8 | 9 | def _compute_fid( 10 | mu1: torch.Tensor, sigma1: torch.Tensor, mu2: torch.Tensor, 11 | sigma2: torch.Tensor 12 | ): 13 | # The same implementation as torchmetrics.image.fid._compute_fid() 14 | 15 | a = (mu1 - mu2).square().sum(dim=-1) 16 | b = sigma1.trace() + sigma2.trace() 17 | c = torch.linalg.eigvals(sigma1 @ sigma2).sqrt().real.sum(dim=-1) 18 | 19 | return a + b - 2 * c 20 | 21 | 22 | class FrechetVideoDistance(torchmetrics.Metric): 23 | 24 | target_resolution = (224, 224) 25 | i3d_min = 10 26 | 27 | def __init__( 28 | self, inception_3d_checkpoint_path: str, sequence_count: int = -1, 29 | num_classes: int = 400, **kwargs 30 | ): 31 | super().__init__(**kwargs) 32 | 33 | self.inception = pytorch_i3d.InceptionI3d(num_classes) 34 | self.inception.eval() 35 | state_dict = torch.load( 36 | inception_3d_checkpoint_path, map_location="cpu", 37 | weights_only=True) 38 | self.inception.load_state_dict(state_dict) 39 | self.sequence_count = sequence_count 40 | 41 | mx_num_feats = (num_classes, num_classes) 42 | self.add_state( 43 | 
"real_features_sum", torch.zeros(num_classes).double(), 44 | dist_reduce_fx="sum") 45 | self.add_state( 46 | "real_features_cov_sum", torch.zeros(mx_num_feats).double(), 47 | dist_reduce_fx="sum") 48 | self.add_state( 49 | "real_features_num_samples", torch.tensor(0).long(), 50 | dist_reduce_fx="sum") 51 | self.add_state( 52 | "fake_features_sum", torch.zeros(num_classes).double(), 53 | dist_reduce_fx="sum") 54 | self.add_state( 55 | "fake_features_cov_sum", torch.zeros(mx_num_feats).double(), 56 | dist_reduce_fx="sum") 57 | self.add_state( 58 | "fake_features_num_samples", torch.tensor(0).long(), 59 | dist_reduce_fx="sum") 60 | 61 | def update(self, frames, real=True): 62 | """Update the state with extracted features. 63 | 64 | Args: 65 | frames: The video frame tensor to evaluate in the shape of 66 | `(batch_size, sequence_length, channels, height, width)`. 67 | real: Whether given frames are real or fake. 68 | """ 69 | 70 | if self.sequence_count >= 0: 71 | frames = frames[:, :self.sequence_count] 72 | 73 | assert frames.shape[1] >= FrechetVideoDistance.i3d_min 74 | 75 | # normalize from [0, 1] to [-1, 1] 76 | frames = frames * 2 - 1 77 | 78 | frames = torch.nn.functional.interpolate( 79 | frames.flatten(0, 1), size=FrechetVideoDistance.target_resolution, 80 | mode="bilinear" 81 | ).unflatten(0, frames.shape[:2]) 82 | features = self.inception(frames.transpose(1, 2)) 83 | self.orig_dtype = features.dtype 84 | features = features.double() 85 | 86 | if real: 87 | self.real_features_sum += features.sum(dim=0) 88 | self.real_features_cov_sum += features.t().mm(features) 89 | self.real_features_num_samples += features.shape[0] 90 | else: 91 | self.fake_features_sum += features.sum(dim=0) 92 | self.fake_features_cov_sum += features.t().mm(features) 93 | self.fake_features_num_samples += features.shape[0] 94 | 95 | def compute(self): 96 | if ( 97 | self.real_features_num_samples < 2 or 98 | self.fake_features_num_samples < 2 99 | ): 100 | raise RuntimeError( 101 | "More than one sample is required for both the real and fake " 102 | "distributed to compute FVD") 103 | 104 | mean_real = ( 105 | self.real_features_sum / self.real_features_num_samples 106 | ).unsqueeze(0) 107 | mean_fake = ( 108 | self.fake_features_sum / self.fake_features_num_samples 109 | ).unsqueeze(0) 110 | 111 | cov_real_num = self.real_features_cov_sum - \ 112 | self.real_features_num_samples * mean_real.t().mm(mean_real) 113 | cov_real = cov_real_num / (self.real_features_num_samples - 1) 114 | cov_fake_num = self.fake_features_cov_sum - \ 115 | self.fake_features_num_samples * mean_fake.t().mm(mean_fake) 116 | cov_fake = cov_fake_num / (self.fake_features_num_samples - 1) 117 | return _compute_fid( 118 | mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake 119 | ).to(self.orig_dtype) 120 | -------------------------------------------------------------------------------- /src/dwm/tools/export_nusc_2_preview_format.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.common 3 | import json 4 | import os 5 | import copy 6 | import torch 7 | import torchvision 8 | 9 | 10 | def create_parser(): 11 | parser = argparse.ArgumentParser( 12 | description="The script is designed to convert the nuscenes " 13 | "dataset into independent data packets, suitable for data " 14 | "loading in preview.py.") 15 | parser.add_argument( 16 | "--reference-frame-count", type=int, required=True, 17 | help="Save the nums of reference frame.") 18 | parser.add_argument( 19 | 
"-c", "--config-path", type=str, required=True, 20 | help="The config to load the dataset.") 21 | parser.add_argument( 22 | "-o", "--output-path", type=str, required=True, 23 | help="The path to save data packets.") 24 | 25 | return parser 26 | 27 | if __name__ == "__main__": 28 | parser = create_parser() 29 | args = parser.parse_args() 30 | 31 | with open(args.config_path, "r", encoding="utf-8") as f: 32 | config = json.load(f) 33 | 34 | validation_dataset = dwm.common.create_instance_from_config( 35 | config["validation_dataset"]) 36 | validation_dataloader = torch.utils.data.DataLoader( 37 | validation_dataset, 38 | **dwm.common.instantiate_config(config["validation_dataloader"])) 39 | print("The validation dataset is loaded with {} items.".format( 40 | len(validation_dataset))) 41 | 42 | sensor_channels = [ 43 | "CAM_FRONT_LEFT", 44 | "CAM_FRONT", 45 | "CAM_FRONT_RIGHT", 46 | "CAM_BACK_RIGHT", 47 | "CAM_BACK", 48 | "CAM_BACK_LEFT" 49 | ] 50 | 51 | for batch in validation_dataloader: 52 | scene_name = batch["scene"]["name"][0] 53 | output_path = os.path.join(args.output_path, scene_name) 54 | os.makedirs(output_path, exist_ok=True) 55 | timestamp = 0 56 | 57 | json_path = os.path.join(output_path, "data.json") 58 | with open(json_path, "w") as json_file: pass 59 | 60 | data_package = dict() 61 | frame_data = dict() 62 | for frame in range(batch["vae_images"].shape[1]): 63 | frame_data["camera_infos"] = dict() 64 | for sensor_channel in sensor_channels: 65 | frame_data["camera_infos"][sensor_channel] = dict() 66 | 67 | for view in range(len(sensor_channels)): 68 | frame_data["camera_infos"][sensor_channels[view]]["extrin"] = \ 69 | batch["camera_transforms"][0, frame, view].tolist() 70 | frame_data["camera_infos"][sensor_channels[view]]["intrin"] = \ 71 | batch["camera_intrinsics"][0, frame, view].tolist() 72 | frame_data["camera_infos"][sensor_channels[view]]["image_description"] = \ 73 | batch["clip_text"][0][frame][view] 74 | 75 | if frame % batch["vae_images"].shape[1] < args.reference_frame_count: 76 | image_output_path = os.path.join( 77 | output_path, sensor_channels[view], "rgb", f"{timestamp}.png") 78 | os.makedirs(os.path.dirname(image_output_path), exist_ok=True) 79 | torchvision.transforms.functional.to_pil_image( 80 | batch["vae_images"][0, frame, view]).save(image_output_path) 81 | frame_data["camera_infos"][sensor_channels[view]]["rgb"] = \ 82 | os.path.relpath(image_output_path, output_path) 83 | else: 84 | frame_data["camera_infos"][sensor_channels[view]]["rgb"] = None 85 | 86 | _3dbox_output_path = os.path.join( 87 | output_path, sensor_channels[view], "3dbox", f"{timestamp}.png") 88 | os.makedirs(os.path.dirname(_3dbox_output_path), exist_ok=True) 89 | torchvision.transforms.functional.to_pil_image( 90 | batch["3dbox_images"][0, frame, view]).save(_3dbox_output_path) 91 | frame_data["camera_infos"][sensor_channels[view]]["3dbox"] = \ 92 | os.path.relpath(_3dbox_output_path, output_path) 93 | 94 | hdmap_output_path = os.path.join( 95 | output_path, sensor_channels[view], "hdmap", f"{timestamp}.png") 96 | os.makedirs(os.path.dirname(hdmap_output_path), exist_ok=True) 97 | torchvision.transforms.functional.to_pil_image( 98 | batch["hdmap_images"][0, frame, view]).save(hdmap_output_path) 99 | frame_data["camera_infos"][sensor_channels[view]]["hdmap"] = \ 100 | os.path.relpath(hdmap_output_path, output_path) 101 | 102 | frame_data["timestamp"] = timestamp 103 | timestamp += 1/int(batch["fps"]) 104 | timestamp = round(timestamp, 4) 105 | frame_data["ego_pose"] = 
batch["ego_transforms"][0, frame, 0].tolist() 106 | data_package[frame] = copy.deepcopy(frame_data) 107 | 108 | with open(json_path, "a") as json_file: 109 | json.dump(data_package, json_file, indent=4) 110 | json_file.write("\n") 111 | -------------------------------------------------------------------------------- /src/dwm/utils/preview.py: -------------------------------------------------------------------------------- 1 | import av 2 | import torch 3 | import torchvision 4 | 5 | 6 | def make_ctsd_preview_tensor(output_images, batch, inference_config): 7 | 8 | # The output image sequece length may be shorter than the input due to the 9 | # autoregressive inference, so use the output sequence length to clip batch 10 | # data. 11 | batch_size, _, view_count = batch["vae_images"].shape[:3] 12 | output_images = output_images\ 13 | .cpu().unflatten(0, (batch_size, -1, view_count)) 14 | sequence_length = output_images.shape[1] 15 | 16 | collected_images = [batch["vae_images"][:, :sequence_length]] 17 | if "3dbox_images" in batch: 18 | collected_images.append( 19 | batch["3dbox_images"][:, :sequence_length]) 20 | 21 | if "hdmap_images" in batch: 22 | collected_images.append( 23 | batch["hdmap_images"][:, :sequence_length]) 24 | 25 | collected_images.append(output_images) 26 | 27 | stacked_images = torch.stack(collected_images) 28 | resized_images = torch.nn.functional.interpolate( 29 | stacked_images.flatten(0, 3), 30 | tuple(inference_config["preview_image_size"][::-1]) 31 | ) 32 | resized_images = resized_images.view( 33 | *stacked_images.shape[:4], -1, *resized_images.shape[-2:]) 34 | if sequence_length == 1: 35 | # image preview with shape [C, B * T * S * H, V * W] 36 | preview_tensor = resized_images.permute(4, 1, 2, 0, 5, 3, 6)\ 37 | .flatten(-2).flatten(1, 4) 38 | else: 39 | # video preview with shape [T, C, B * S * H, V * W] 40 | preview_tensor = resized_images.permute(2, 4, 1, 0, 5, 3, 6)\ 41 | .flatten(-2).flatten(2, 4) 42 | 43 | return preview_tensor 44 | 45 | 46 | def make_lidar_preview_tensor( 47 | ground_truth_volumn, generated_volumn, batch, inference_config 48 | ): 49 | collected_images = [ 50 | ground_truth_volumn.amax(-3, keepdim=True).repeat_interleave(3, -3) 51 | .cpu() 52 | ] 53 | if "3dbox_bev_images_denorm" in batch: 54 | collected_images.append(batch["3dbox_bev_images_denorm"]) 55 | 56 | if "hdmap_bev_images_denorm" in batch: 57 | collected_images.append(batch["hdmap_bev_images_denorm"]) 58 | 59 | if isinstance(generated_volumn, list): 60 | for gv in generated_volumn: 61 | collected_images.append( 62 | gv.amax(-3, keepdim=True).repeat_interleave(3, -3).cpu()) 63 | else: 64 | collected_images.append( 65 | generated_volumn.amax(-3, keepdim=True).repeat_interleave(3, -3).cpu()) 66 | 67 | # assume all BEV images have the same size 68 | stacked_images = torch.stack(collected_images) 69 | if ground_truth_volumn.shape[1] == 1: 70 | # BEV image preview with shape [C, B * T * H, S * W] 71 | preview_tensor = stacked_images.permute(3, 1, 2, 4, 0, 5).flatten(-2)\ 72 | .flatten(1, 3) 73 | else: 74 | # BEV video preview with shape [T, C, B * H, S * W] 75 | preview_tensor = stacked_images.permute(2, 3, 1, 4, 0, 5).flatten(-2)\ 76 | .flatten(2, 3) 77 | 78 | return preview_tensor 79 | 80 | 81 | def save_tensor_to_video( 82 | path: str, video_encoder: str, fps, tensor_list, pix_fmt: str = "yuv420p", 83 | stream_options: dict = {"crf": "16"} 84 | ): 85 | tensor_shape = tensor_list[0].shape 86 | with av.open(path, mode="w") as container: 87 | stream = container.add_stream(video_encoder, 
int(fps)) 88 | stream.width = tensor_shape[-1] 89 | stream.height = tensor_shape[-2] 90 | stream.pix_fmt = pix_fmt 91 | stream.options = stream_options 92 | for i in tensor_list: 93 | frame = av.VideoFrame.from_image( 94 | torchvision.transforms.functional.to_pil_image(i)) 95 | for p in stream.encode(frame): 96 | container.mux(p) 97 | 98 | for p in stream.encode(): 99 | container.mux(p) 100 | 101 | 102 | def gray_to_colormap(img, cmap='rainbow', max_val=None): 103 | """ 104 | Transfer gray map to matplotlib colormap 105 | """ 106 | assert img.ndim == 2 107 | import matplotlib 108 | import matplotlib.cm 109 | 110 | img[img<0] = 0 111 | mask_invalid = img < 1e-10 112 | if max_val is None: 113 | img = img / (img.max() + 1e-8) 114 | else: 115 | img = img / (max_val + 1e-8) 116 | norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1) 117 | cmap_m = matplotlib.cm.get_cmap(cmap) 118 | map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m) 119 | colormap = map.to_rgba(img)[:, :, :3] 120 | colormap[mask_invalid] = 0 121 | return colormap 122 | 123 | def depths_to_colors(depths, concat="width", colormap="rainbow", max_val=None): 124 | colors = [] 125 | if isinstance(depths, list) or len(depths.shape) == 4: 126 | for depth in depths: 127 | color = gray_to_colormap(depth.detach().cpu().numpy(), cmap=colormap, max_val=max_val) 128 | colors.append(color.permute(2, 0, 1)) 129 | if concat == "width": 130 | colors = torch.cat(colors, dim=2) 131 | else: 132 | colors = torch.stack(colors) 133 | else: 134 | colors = gray_to_colormap(depths.detach().cpu().numpy(), cmap=colormap, max_val=max_val) 135 | colors = torch.from_numpy(colors).permute(2, 0, 1) 136 | return colors 137 | -------------------------------------------------------------------------------- /src/dwm/streaming.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import av 3 | import dwm.common 4 | import einops 5 | import json 6 | import numpy as np 7 | from PIL import Image 8 | import queue 9 | import time 10 | import torch 11 | 12 | 13 | def create_parser(): 14 | parser = argparse.ArgumentParser( 15 | description="The script to finetune a stable diffusion model to the " 16 | "driving dataset.") 17 | parser.add_argument( 18 | "-c", "--config-path", type=str, required=True, 19 | help="The config to load the train model and dataset.") 20 | parser.add_argument( 21 | "-l", "--log-path", type=str, required=True, 22 | help="The path to save log files.") 23 | parser.add_argument( 24 | "-s", "--streaming-path", type=str, required=True, 25 | help="The path to upload the video stream.") 26 | parser.add_argument( 27 | "-f", "--format", default="rtsp", type=str, 28 | help="The streaming format.") 29 | parser.add_argument( 30 | "--fps", default=2, type=int, 31 | help="The streaming FPS.") 32 | parser.add_argument( 33 | "-vcodec", "--video-encoder", default="libx264", type=str, 34 | help="The video encoder type.") 35 | parser.add_argument( 36 | "--pix-fmt", default="yuv420p", type=str, 37 | help="The pixel format.") 38 | return parser 39 | 40 | 41 | def merge_multiview_images(pipeline_frame, data_condition=None): 42 | image_data = np.concatenate([np.asarray(i) for i in pipeline_frame], 1) 43 | if data_condition is not None: 44 | _3dbox_data = torch.nn.functional.interpolate( 45 | einops.rearrange( 46 | data_condition["3dbox_images"], 47 | "b t v c h w -> b c h (t v w)"), 48 | image_data.shape[:2] 49 | )[0].permute(1, 2, 0).numpy() 50 | hdmap_data = torch.nn.functional.interpolate( 51 | 
einops.rearrange( 52 | data_condition["hdmap_images"], 53 | "b t v c h w -> b c h (t v w)"), 54 | image_data.shape[:2] 55 | )[0].permute(1, 2, 0).numpy() 56 | condition_data = np.maximum(_3dbox_data, hdmap_data) 57 | condition_ahpla = np.max(condition_data, -1, keepdims=True) * 0.6 58 | image_data = ( 59 | condition_data * 255 * condition_ahpla + 60 | image_data * (1 - condition_ahpla) 61 | ).astype(np.uint8) 62 | 63 | return Image.fromarray(image_data) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = create_parser() 68 | args = parser.parse_args() 69 | 70 | with open(args.config_path, "r", encoding="utf-8") as f: 71 | config = json.load(f) 72 | 73 | # setup the global state 74 | if "global_state" in config: 75 | for key, value in config["global_state"].items(): 76 | dwm.common.global_state[key] = \ 77 | dwm.common.create_instance_from_config(value) 78 | 79 | # load the pipeline including the models 80 | pipeline = dwm.common.create_instance_from_config( 81 | config["pipeline"], output_path=args.log_path, config=config, 82 | device=torch.device(config["device"])) 83 | print("The pipeline is loaded.") 84 | 85 | data_adapter = dwm.common.create_instance_from_config( 86 | config["data_adapter"]) 87 | 88 | size = pipeline.inference_config["preview_image_size"] 89 | latent_shape = ( 90 | 1, pipeline.inference_config["sequence_length_per_iteration"], 91 | len(data_adapter.sensor_channels), pipeline.vae.config.latent_channels, 92 | config["latent_size"][0], config["latent_size"][1] 93 | ) 94 | pipeline.reset_streaming(latent_shape, "pil") 95 | 96 | streaming_state = {} 97 | data_queue = queue.Queue() 98 | with av.open( 99 | args.streaming_path, mode="w", format=args.format, 100 | container_options=config.get("container_options", {}) 101 | ) as container: 102 | stream = container.add_stream(args.video_encoder, args.fps) 103 | stream.pix_fmt = args.pix_fmt 104 | stream.options = config.get("stream_options", {}) 105 | while True: 106 | data = data_adapter.query_data() 107 | data_queue.put_nowait(data) 108 | pipeline.send_frame_condition(data) 109 | pipeline_frame = pipeline.receive_frame() 110 | if pipeline_frame is None: 111 | continue 112 | 113 | matched_data = data_queue.get_nowait() 114 | image = merge_multiview_images( 115 | pipeline_frame, 116 | ( 117 | matched_data 118 | if config.get("preview_condition", False) 119 | else None 120 | )) 121 | if not streaming_state.get("is_frame_size_set", False): 122 | stream.width = image.width 123 | stream.height = image.height 124 | streaming_state["is_frame_size_set"] = True 125 | 126 | while ( 127 | "expected_time" in streaming_state and 128 | time.time() < streaming_state["expected_time"] 129 | ): 130 | time.sleep(0.01) 131 | 132 | frame = av.VideoFrame.from_image(image) 133 | for p in stream.encode(frame): 134 | container.mux(p) 135 | 136 | streaming_state["expected_time"] = ( 137 | time.time() 138 | if "expected_time" not in streaming_state 139 | else streaming_state["expected_time"] 140 | ) + 1 / args.fps 141 | print("{:.1f}".format(streaming_state["expected_time"])) 142 | -------------------------------------------------------------------------------- /src/dwm/metrics/pc_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchmetrics 3 | import torch.distributed 4 | from torchmetrics import Metric 5 | 6 | from dwm.utils.metrics_copilot4d import ( 7 | compute_chamfer_distance, 8 | compute_chamfer_distance_inner, 9 | compute_mmd, 10 | jsd_2d, 11 | gaussian, 12 | 
point_cloud_to_histogram 13 | ) 14 | 15 | class PointCloudChamfer(Metric): 16 | def __init__(self, inner_dist=None, **kwargs): 17 | super().__init__(**kwargs) 18 | self.inner_dist = inner_dist 19 | 20 | self.cd_func = compute_chamfer_distance if self.inner_dist is None else compute_chamfer_distance_inner 21 | self.chamfer_list = [] 22 | 23 | def update(self, pred_pcd, gt_pcd, device=None): 24 | for pred, gt in zip(pred_pcd, gt_pcd): 25 | for p, g in zip(pred, gt): 26 | if self.inner_dist is None: 27 | cd = self.cd_func(p.to(torch.float32), g.to(torch.float32), device=device) 28 | else: 29 | cd = self.cd_func(p.to(torch.float32), g.to(torch.float32), device=device, pc_range=[ 30 | -self.inner_dist, -self.inner_dist, -3, self.inner_dist, self.inner_dist, 5]) 31 | if not isinstance(cd, torch.Tensor): 32 | cd = torch.tensor(cd).to(device) 33 | self.chamfer_list.append(cd.float()) 34 | 35 | def compute(self): 36 | chamfer_list = torch.stack(self.chamfer_list, dim=0) 37 | world_size = torch.distributed.get_world_size() \ 38 | if torch.distributed.is_initialized() else 1 39 | if world_size > 1: 40 | all_chamfer = chamfer_list.new_zeros( 41 | (len(chamfer_list)*world_size, ) + chamfer_list.shape[1:]) 42 | torch.distributed.all_gather_into_tensor( 43 | all_chamfer, chamfer_list) 44 | chamfer_list = all_chamfer 45 | num_samples = (~torch.isnan(chamfer_list) & ~torch.isinf(chamfer_list)).sum() 46 | chamfer_list = torch.nan_to_num(chamfer_list, nan=0.0, posinf=0.0, neginf=0.0) 47 | return chamfer_list.sum() / num_samples 48 | 49 | def reset(self): 50 | self.chamfer_list.clear() 51 | super().reset() 52 | 53 | 54 | class PointCloudMMD(Metric): 55 | """ 56 | Compute the Maximum Mean Discrepancy (MMD) between two point clouds. 57 | """ 58 | def __init__(self, field_size=160, bins=100, **kwargs): 59 | super().__init__(**kwargs) 60 | self.field_size = field_size 61 | self.bins = bins 62 | self.mmd_list = [] 63 | def update(self, pred_pcd, gt_pcd, device=None): 64 | pred_hist = [] 65 | gt_hist = [] 66 | for pred, gt in zip(pred_pcd, gt_pcd): 67 | for p, g in zip(pred, gt): 68 | p = point_cloud_to_histogram(self.field_size, self.bins, p.to(torch.float32))[0] 69 | g = point_cloud_to_histogram(self.field_size, self.bins, g.to(torch.float32))[0] 70 | pred_hist.append(p) 71 | gt_hist.append(g) 72 | mmd = compute_mmd(pred_hist, gt_hist, kernel=gaussian, is_parallel=True) 73 | if not isinstance(mmd, torch.Tensor): 74 | mmd = torch.tensor(mmd).to(device) 75 | self.mmd_list.append(mmd.float()) 76 | 77 | def compute(self): 78 | mmd_list = torch.stack(self.mmd_list, dim=0) 79 | world_size = torch.distributed.get_world_size() \ 80 | if torch.distributed.is_initialized() else 1 81 | if world_size > 1: 82 | all_mmd = mmd_list.new_zeros( 83 | (len(mmd_list)*world_size, ) + mmd_list.shape[1:]) 84 | torch.distributed.all_gather_into_tensor( 85 | all_mmd, mmd_list) 86 | mmd_list = all_mmd 87 | num_samples =(~torch.isnan(mmd_list) & ~torch.isinf(mmd_list)).sum() 88 | mmd_list = torch.nan_to_num(mmd_list, nan=0.0, posinf=0.0, neginf=0.0) 89 | return mmd_list.sum() / num_samples 90 | 91 | def reset(self): 92 | self.mmd_list.clear() 93 | super().reset() 94 | 95 | class PointCloudJSD(Metric): 96 | def __init__(self, field_size=160, bins=100, **kwargs): 97 | super().__init__(**kwargs) 98 | self.field_size = field_size 99 | self.bins = bins 100 | self.jsd_list = [] 101 | 102 | def update(self, pred_pcd, gt_pcd, device=None): 103 | for pred, gt in zip(pred_pcd, gt_pcd): 104 | for p, g in zip(pred, gt): 105 | p = 
point_cloud_to_histogram(self.field_size, self.bins, p)[0] 106 | g = point_cloud_to_histogram(self.field_size, self.bins, g)[0] 107 | jsd = jsd_2d(p, g) 108 | if not isinstance(jsd, torch.Tensor): 109 | jsd = torch.tensor(jsd).to(device) 110 | self.jsd_list.append(jsd.float()) 111 | 112 | def compute(self): 113 | jsd_list = torch.stack(self.jsd_list, dim=0) 114 | world_size = torch.distributed.get_world_size() \ 115 | if torch.distributed.is_initialized() else 1 116 | if world_size > 1: 117 | all_jsd = jsd_list.new_zeros( 118 | (len(jsd_list)*world_size, ) + jsd_list.shape[1:]) 119 | torch.distributed.all_gather_into_tensor( 120 | all_jsd, jsd_list) 121 | jsd_list = all_jsd 122 | num_samples = (~torch.isnan(jsd_list) & ~torch.isinf(jsd_list)).sum() 123 | jsd_list = torch.nan_to_num(jsd_list, nan=0.0, posinf=0.0, neginf=0.0) 124 | return jsd_list.sum() / num_samples 125 | 126 | def reset(self): 127 | self.jsd_list.clear() 128 | super().reset() 129 | -------------------------------------------------------------------------------- /src/dwm/tools/dataset_make_info_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.tools.fs_make_info_json 3 | import fsspec.implementations.local 4 | import json 5 | import os 6 | import struct 7 | import re 8 | 9 | 10 | def create_parser(): 11 | parser = argparse.ArgumentParser( 12 | description="The script to make information JSON(s) for dataset to " 13 | "accelerate initialization.") 14 | parser.add_argument( 15 | "-dt", "--dataset-type", type=str, 16 | choices=["nuscenes", "waymo", "argoverse"], required=True, 17 | help="The dataset type.") 18 | parser.add_argument( 19 | "-s", "--split", default=None, type=str, 20 | help="The split, optional depending on the dataset type.") 21 | parser.add_argument( 22 | "-i", "--input-path", type=str, required=True, 23 | help="The path of the dataset root.") 24 | parser.add_argument( 25 | "-o", "--output-path", type=str, required=True, 26 | help="The path to save the information JSON file(s) on the local file " 27 | "system.") 28 | parser.add_argument( 29 | "-fs", "--fs-config-path", default=None, type=str, 30 | help="The path of file system JSON config to open the dataset.") 31 | return parser 32 | 33 | 34 | if __name__ == "__main__": 35 | import tqdm 36 | 37 | parser = create_parser() 38 | args = parser.parse_args() 39 | 40 | if args.fs_config_path is None: 41 | fs = fsspec.implementations.local.LocalFileSystem() 42 | else: 43 | import dwm.common 44 | with open(args.fs_config_path, "r", encoding="utf-8") as f: 45 | fs = dwm.common.create_instance_from_config(json.load(f)) 46 | 47 | if args.dataset_type == "nuscenes": 48 | files = [ 49 | os.path.relpath(i, args.input_path) 50 | for i in fs.ls(args.input_path, detail=False) 51 | ] 52 | filtered_files = [ 53 | i for i in files 54 | if ( 55 | (args.split is None or i.startswith(args.split)) and 56 | i.endswith(".zip") 57 | ) 58 | ] 59 | assert len(filtered_files) > 0, ( 60 | "No files detected, please check the split (one of \"v1.0-mini\", " 61 | "\"v1.0-trainval\", \"v1.0-test\") is correct, and ensure the " 62 | "blob files are already converted to the ZIP format." 
63 | ) 64 | 65 | os.makedirs(args.output_path, exist_ok=True) 66 | for i in tqdm.tqdm(filtered_files): 67 | with fs.open("{}/{}".format(args.input_path, i)) as f: 68 | items = dwm.tools.fs_make_info_json.make_info_dict( 69 | os.path.splitext(i)[-1], f) 70 | 71 | with open( 72 | os.path.join( 73 | args.output_path, i.replace(".zip", ".info.json")), 74 | "w", encoding="utf-8" 75 | ) as f: 76 | json.dump(items, f) 77 | 78 | elif args.dataset_type == "waymo": 79 | import waymo_open_dataset.dataset_pb2 as waymo_pb 80 | 81 | files = [ 82 | os.path.relpath(i, args.input_path) 83 | for i in fs.ls(args.input_path, detail=False) 84 | if i.endswith(".tfrecord") 85 | ] 86 | assert len(files) > 0, "No files detected." 87 | 88 | pattern = re.compile( 89 | "^segment-(?P.*)_with_camera_labels.tfrecord$") 90 | info_dict = {} 91 | for i in tqdm.tqdm(files): 92 | match = re.match(pattern, i) 93 | scene = match.group("scene") 94 | pt = 0 95 | info_list = [] 96 | with fs.open("{}/{}".format(args.input_path, i)) as f: 97 | while True: 98 | start = f.read(8) 99 | if len(start) == 0: 100 | break 101 | 102 | size, = struct.unpack(" 0, ( 130 | "No files detected, please check the split (one of \"train\", " 131 | "\"val\", \"test\") is correct." 132 | ) 133 | 134 | os.makedirs(args.output_path, exist_ok=True) 135 | for i in tqdm.tqdm(files): 136 | with fs.open("{}/{}".format(args.input_path, i)) as f: 137 | items = dwm.tools.fs_make_info_json.make_info_dict( 138 | os.path.splitext(i)[-1], f, enable_tqdm=False) 139 | 140 | with open( 141 | os.path.join( 142 | args.output_path, i.replace(".tar", ".info.json")), 143 | "w", encoding="utf-8" 144 | ) as f: 145 | json.dump(items, f) 146 | 147 | else: 148 | raise Exception("Unknown dataset type {}.".format(args.dataset_type)) 149 | -------------------------------------------------------------------------------- /src/dwm/fs/czip.py: -------------------------------------------------------------------------------- 1 | import dwm.common 2 | import fsspec 3 | import fsspec.archive 4 | import io 5 | import json 6 | import os 7 | import re 8 | import struct 9 | import zipfile 10 | import zlib 11 | 12 | 13 | class CombinedZipFileSystem(fsspec.archive.AbstractArchiveFileSystem): 14 | 15 | """This file system can be used to access files inside ZIP files, and it is 16 | also compatible with the process fork under the multi-worker situation of 17 | the PyTorch data loader. 18 | 19 | Args: 20 | paths (list): The path list of ZIP files to open. 21 | fs (fsspec.AbstractFileSystem or None): The file system to open the ZIP 22 | blobs, or local file system as default. 23 | enable_cached_info (bool): Load the byte offset of ZIP entries from 24 | cached info file with the extension `.info.json` to accelerate 25 | opening large ZIP files. 
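    Example:
        A minimal sketch (the ZIP paths below are placeholders)::

            fs = CombinedZipFileSystem(
                ["data/part-0.zip", "data/part-1.zip"],
                enable_cached_info=False)
            with fs.open("samples/0001.jpg", "rb") as f:
                data = f.read()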
26 | """ 27 | 28 | root_marker = "" 29 | protocol = "czip" 30 | cachable = False 31 | 32 | def __init__( 33 | self, paths: list, fs=None, 34 | enable_cached_info: bool = False, **kwargs 35 | ): 36 | super().__init__(**kwargs) 37 | 38 | self.fs = fsspec.implementations.local.LocalFileSystem() \ 39 | if fs is None else fs 40 | self.paths = paths 41 | 42 | _belongs_to = {} 43 | self._info = {} 44 | pattern = re.compile(r"\.zip$") 45 | for path in paths: 46 | cached_info_path = re.sub(pattern, ".info.json", path) 47 | if enable_cached_info: 48 | with fs.open(cached_info_path, "r", encoding="utf-8") as f: 49 | infodict = json.load(f) 50 | 51 | for filename, i in infodict.items(): 52 | _belongs_to[filename] = path 53 | 54 | else: 55 | with fs.open(path) as f: 56 | with zipfile.ZipFile(f) as zf: 57 | infodict = {} 58 | for i in zf.infolist(): 59 | _belongs_to[i.filename] = path 60 | infodict[i.filename] = [ 61 | i.header_offset, i.file_size, i.is_dir() 62 | ] 63 | 64 | self._info[path] = dwm.common.SerializedReadonlyDict(infodict) 65 | 66 | self._belongs_to = dwm.common.SerializedReadonlyDict(_belongs_to) 67 | 68 | self.dir_cache = None 69 | self.fp_cache = {} 70 | 71 | @classmethod 72 | def _strip_protocol(cls, path): 73 | # zip file paths are always relative to the archive root 74 | return super()._strip_protocol(path).lstrip("/") 75 | 76 | def close(self): 77 | # the file points are not forkable 78 | current_pid = os.getpid() 79 | if self._pid != current_pid: 80 | self.fp_cache.clear() 81 | self._pid = current_pid 82 | 83 | for fp in self.fp_cache.values(): 84 | fp.close() 85 | 86 | self.fp_cache.clear() 87 | 88 | def _get_dirs(self): 89 | if self.dir_cache is None: 90 | self.dir_cache = { 91 | dirname.rstrip("/"): { 92 | "name": dirname.rstrip("/"), 93 | "size": 0, 94 | "type": "directory", 95 | } 96 | for dirname in self._all_dirnames(self._belongs_to.keys()) 97 | } 98 | 99 | for file_name in self._belongs_to.keys(): 100 | zip_path = self._belongs_to[file_name] 101 | _, file_size, is_dir = self._info[zip_path][file_name] 102 | f = { 103 | "name": file_name.rstrip("/"), 104 | "size": file_size, 105 | "type": "directory" if is_dir else "file", 106 | } 107 | self.dir_cache[f["name"]] = f 108 | 109 | def exists(self, path, **kwargs): 110 | return path in self._belongs_to 111 | 112 | def _open( 113 | self, path, mode="rb", block_size=None, autocommit=True, 114 | cache_options=None, **kwargs, 115 | ): 116 | if "w" in mode: 117 | raise OSError("Combined stateless ZIP file system is read only.") 118 | 119 | if "b" not in mode: 120 | raise OSError( 121 | "Combined stateless ZIP file system only support the mode of " 122 | "binary.") 123 | 124 | if not self.exists(path): 125 | raise FileNotFoundError(path) 126 | 127 | zip_path = self._belongs_to[path] 128 | header_offset, size, is_dir = self._info[zip_path][path] 129 | if is_dir: 130 | raise FileNotFoundError(path) 131 | 132 | # the file points are not forkable 133 | current_pid = os.getpid() 134 | if self._pid != current_pid: 135 | self.fp_cache.clear() 136 | self._pid = current_pid 137 | 138 | if zip_path in self.fp_cache: 139 | f = self.fp_cache[zip_path] 140 | else: 141 | f = self.fp_cache[zip_path] = \ 142 | self.fs.open(zip_path, mode, block_size, cache_options) 143 | 144 | f.seek(header_offset) 145 | fh = struct.unpack(zipfile.structFileHeader, f.read(30)) 146 | method = fh[zipfile._FH_COMPRESSION_METHOD] 147 | offset = header_offset + 30 + fh[zipfile._FH_FILENAME_LENGTH] + \ 148 | fh[zipfile._FH_EXTRA_FIELD_LENGTH] 149 | 150 | if method == 
zipfile.ZIP_STORED: 151 | return dwm.common.PartialReadableRawIO(f, offset, offset + size) 152 | 153 | elif method == zipfile.ZIP_DEFLATED: 154 | f.seek(offset) 155 | data = f.read(size) 156 | result = io.BytesIO(zlib.decompress(data, -15)) 157 | return result 158 | 159 | else: 160 | raise NotImplementedError("The compression method is unsupported") 161 | -------------------------------------------------------------------------------- /src/dwm/utils/carla_simulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import carla 3 | import dwm.common 4 | import json 5 | import random 6 | import time 7 | 8 | 9 | def create_parser(): 10 | parser = argparse.ArgumentParser( 11 | description="The tool to setup the Carla simulation.") 12 | parser.add_argument( 13 | "-c", "--config-path", type=str, required=True, 14 | help="The path of the simulation config about the environment, ego " 15 | "vehicle, and scenario.") 16 | parser.add_argument( 17 | "--host", default="127.0.0.1", type=str, 18 | help="The host address of the Carla simulator.") 19 | parser.add_argument( 20 | "-p", "--port", default=2000, type=int, 21 | help="The port of the Carla simulator.") 22 | parser.add_argument( 23 | "-tp", "--traffic-port", default=8000, type=int, 24 | help="The port of the Traffic manager.") 25 | parser.add_argument( 26 | "--client-timeout", default=10.0, type=float, 27 | help="The timeout of the Carla client.") 28 | parser.add_argument( 29 | "--step-sleep", default=0.0, type=float, 30 | help="The time to sleep for each step.") 31 | return parser 32 | 33 | 34 | def make_actor( 35 | world: carla.World, blueprint_library: carla.BlueprintLibrary, 36 | spawn_points: list, actor_config: dict, random_state: random.Random, 37 | attach_to=None 38 | ): 39 | # prepare the blueprint 40 | if "pattern" in actor_config: 41 | bp_list = blueprint_library.filter(actor_config["pattern"]) 42 | bp = ( 43 | bp_list[actor_config["matched_index"]] 44 | if "matched_index" in actor_config 45 | else random_state.choice(bp_list) 46 | ) 47 | else: 48 | bp = blueprint_library.find(actor_config["id"]) 49 | 50 | for k, v in actor_config.get("attributes", {}).items(): 51 | bp.set_attribute(k, v) 52 | 53 | # prepare the spawn location for vehicles, pedestrians, cameras 54 | if "spawn_index" in actor_config: 55 | spawn_transform = spawn_points[ 56 | actor_config["spawn_index"] % len(spawn_points) 57 | ] 58 | elif "spawn_from_navigation" in actor_config: 59 | location = world.get_random_location_from_navigation() 60 | spawn_transform = carla.Transform(location, carla.Rotation(0, 0, 0)) 61 | else: 62 | spawn_transform = actor_config["spawn_transform"] 63 | spawn_transform = carla.Transform( 64 | carla.Location(*spawn_transform.get("location", [0, 0, 0])), 65 | carla.Rotation(*spawn_transform.get("rotation", [0, 0, 0]))) 66 | 67 | # instantiate the actor, set attributes and apply custom setup functions 68 | actor = world.try_spawn_actor(bp, spawn_transform, attach_to) 69 | 70 | if actor is not None: 71 | if actor.attributes.get("role_name") == "autopilot": 72 | actor.set_autopilot(True) 73 | 74 | if actor is None: 75 | print("Warning: failed to spawn {}".format(bp.id)) 76 | return None, None, None 77 | 78 | if actor_config.get("report_actor_id", False): 79 | attributes = actor.attributes 80 | report_text = "{}{}: {}".format( 81 | actor_config["id"], 82 | " ({})".format(attributes["role_name"]) 83 | if "role_name" in attributes else "", actor.id) 84 | print(report_text) 85 | 86 | if 
actor_config.get("report_actor_attributes", False): 87 | print("{}: {}".format(actor.type_id, actor.attributes)) 88 | 89 | if "state_machine" in actor_config: 90 | _class = dwm.common.get_class(actor_config["state_machine"]) 91 | state_machine = _class( 92 | actor, **actor_config.get("state_machine_args", {})) 93 | else: 94 | state_machine = None 95 | 96 | children = [ 97 | make_actor( 98 | world, blueprint_library, spawn_points, i, rs, actor) 99 | for i in actor_config["child_configs"] 100 | ] if "child_configs" in actor_config else None 101 | 102 | return actor, state_machine, children 103 | 104 | 105 | def update_actor_state(actors: list): 106 | for _, state_machine, children in actors: 107 | if state_machine is not None: 108 | state_machine.update() 109 | 110 | if children is not None: 111 | update_actor_state(children) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = create_parser() 116 | args = parser.parse_args() 117 | 118 | with open(args.config_path, "r", encoding="utf-8") as f: 119 | config = json.load(f) 120 | 121 | rs = random.Random(config.get("seed", None)) 122 | client = carla.Client(args.host, args.port, 1) 123 | client.set_timeout(args.client_timeout) 124 | 125 | # This script just use the activated map. To config the map, please use the 126 | # {CarlaRoot}/PythonAPI/util/config.py 127 | world = client.get_world() 128 | traffic_manager = client.get_trafficmanager(args.traffic_port) 129 | 130 | if config.get("master", False): 131 | traffic_manager.set_synchronous_mode(True) 132 | 133 | if "world_settings" in config: 134 | settings = world.get_settings() 135 | for k, v in config["world_settings"].items(): 136 | setattr(settings, k, v) 137 | 138 | world.apply_settings(settings) 139 | 140 | if "traffic_manager_settings" in config: 141 | for k, v in config["traffic_manager_settings"].items(): 142 | getattr(traffic_manager, k)(v) 143 | 144 | actors = [ 145 | make_actor( 146 | world, world.get_blueprint_library(), 147 | world.get_map().get_spawn_points(), i, rs) 148 | for i in config["actor_configs"] 149 | ] 150 | 151 | step = 0 152 | total_steps = config.get("total_steps", -1) 153 | while total_steps == -1 or step < total_steps: 154 | if args.step_sleep > 0.0: 155 | time.sleep(args.step_sleep) 156 | 157 | if config.get("master", False): 158 | world.tick() 159 | else: 160 | world.wait_for_tick() 161 | 162 | update_actor_state(actors) 163 | step += 1 164 | -------------------------------------------------------------------------------- /src/dwm/export_generation_result_as_nuscenes_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.common 3 | import json 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | 9 | def create_parser(): 10 | parser = argparse.ArgumentParser( 11 | description="The script to run the diffusion model to generate data for" 12 | "detection evaluation.") 13 | parser.add_argument( 14 | "-c", "--config-path", type=str, required=True, 15 | help="The config to load the train model and dataset.") 16 | parser.add_argument( 17 | "-o", "--output-path", type=str, required=True, 18 | help="The path to save checkpoint files.") 19 | return parser 20 | 21 | 22 | """ 23 | This script requires: 24 | 1. The validation dataset should be only nuScenes. 25 | 2. Add `"enable_sample_data": true` to the dataset arguments, so the Dataset 26 | load the filename (path) to save the generated images. 27 | 3. 
Add `"sample_data"` to "validation_dataloader.collate_fn.keys" to pass the 28 | object data directly to the script here, no need to collate to tensors. 29 | 30 | Note: 31 | *. Set the "model_checkpoint_path" with the trained checkpoint, rather than 32 | the pretrained checkpoint. 33 | *. Set the fps_stride to [0, 1] for the image dataset, or 34 | [2, 0.5 * sequence_length] for the origin video dataset, or 35 | [12, 0.1 * sequence_length] for the 12Hz video dataset. Make sure no 36 | sample is missed. 37 | """ 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = create_parser() 42 | args = parser.parse_args() 43 | 44 | with open(args.config_path, "r", encoding="utf-8") as f: 45 | config = json.load(f) 46 | 47 | # set distributed training (if enabled), log, random number generator, and 48 | # load the checkpoint (if required). 49 | ddp = "LOCAL_RANK" in os.environ 50 | if ddp: 51 | local_rank = int(os.environ["LOCAL_RANK"]) 52 | device = torch.device(config["device"], local_rank) 53 | if config["device"] == "cuda": 54 | torch.cuda.set_device(local_rank) 55 | 56 | torch.distributed.init_process_group(backend=config["ddp_backend"]) 57 | else: 58 | device = torch.device(config["device"]) 59 | 60 | # setup the global state 61 | if "global_state" in config: 62 | for key, value in config["global_state"].items(): 63 | dwm.common.global_state[key] = \ 64 | dwm.common.create_instance_from_config(value) 65 | 66 | should_log = (ddp and local_rank == 0) or not ddp 67 | 68 | pipeline = dwm.common.create_instance_from_config( 69 | config["pipeline"], output_path=args.output_path, config=config, 70 | device=device) 71 | if should_log: 72 | print("The pipeline is loaded.") 73 | 74 | # load the dataset 75 | validation_dataset = dwm.common.create_instance_from_config( 76 | config["validation_dataset"]) 77 | if ddp: 78 | validation_datasampler = \ 79 | torch.utils.data.distributed.DistributedSampler( 80 | validation_dataset) 81 | validation_dataloader = torch.utils.data.DataLoader( 82 | validation_dataset, 83 | **dwm.common.instantiate_config(config["validation_dataloader"]), 84 | sampler=validation_datasampler) 85 | else: 86 | validation_datasampler = None 87 | validation_dataloader = torch.utils.data.DataLoader( 88 | validation_dataset, 89 | **dwm.common.instantiate_config(config["validation_dataloader"])) 90 | 91 | if should_log: 92 | print("The validation dataset is loaded with {} items.".format( 93 | len(validation_dataset))) 94 | 95 | if ddp: 96 | validation_datasampler.set_epoch(0) 97 | 98 | for batch in validation_dataloader: 99 | batch_size, sequence_length, view_count = batch["vae_images"].shape[:3] 100 | latent_height = batch["vae_images"].shape[-2] // \ 101 | (2 ** (len(pipeline.vae.config.down_block_types) - 1)) 102 | latent_width = batch["vae_images"].shape[-1] // \ 103 | (2 ** (len(pipeline.vae.config.down_block_types) - 1)) 104 | latent_shape = ( 105 | batch_size, sequence_length, view_count, 106 | pipeline.vae.config.latent_channels, latent_height, 107 | latent_width 108 | ) 109 | 110 | with torch.no_grad(): 111 | pipeline_output = pipeline.inference_pipeline( 112 | latent_shape, batch, "pil") 113 | 114 | if "images" in pipeline_output: 115 | paths = [ 116 | os.path.join(args.output_path, k["filename"]) 117 | for i in batch["sample_data"] 118 | for j in i 119 | for k in j if not k["filename"].endswith(".bin") 120 | ] 121 | image_results = pipeline_output["images"] 122 | image_sizes = batch["image_size"].flatten(0, 2) 123 | for path, image, image_size in zip(paths, image_results, 
image_sizes): 124 | dir = os.path.dirname(path) 125 | os.makedirs(dir, exist_ok=True) 126 | image.resize(tuple(image_size.int().tolist()))\ 127 | .save(path, quality=95) 128 | 129 | if "raw_points" in pipeline_output: 130 | paths = [ 131 | os.path.join(args.output_path, k["filename"]) 132 | for i in batch["sample_data"] 133 | for j in i 134 | for k in j if k["filename"].endswith(".bin") 135 | ] 136 | raw_points = [ 137 | j 138 | for i in pipeline_output["raw_points"] 139 | for j in i 140 | ] 141 | for path, points in zip(paths, raw_points): 142 | os.makedirs(os.path.dirname(path), exist_ok=True) 143 | points = points.numpy() 144 | padded_points = np.concatenate([ 145 | points, np.zeros((points.shape[0], 2), dtype=np.float32) 146 | ], axis=-1) 147 | with open(path, "wb") as f: 148 | f.write(padded_points.tobytes()) 149 | -------------------------------------------------------------------------------- /src/dwm/preview.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dwm.common 3 | import json 4 | import os 5 | import torch 6 | 7 | 8 | def customize_text(clip_text, preview_config): 9 | 10 | # text 11 | if preview_config["text"] is not None: 12 | text_config = preview_config["text"] 13 | 14 | if text_config["type"] == "add": 15 | new_clip_text = \ 16 | [ 17 | [ 18 | [ 19 | text_config["prompt"] + k 20 | for k in j 21 | ] 22 | for j in i 23 | ] 24 | for i in clip_text 25 | ] 26 | 27 | elif text_config["type"] == "replace": 28 | new_clip_text = \ 29 | [ 30 | [ 31 | [ 32 | text_config["prompt"] 33 | for k in j 34 | ] 35 | for j in i 36 | ] 37 | for i in clip_text 38 | ] 39 | 40 | elif text_config["type"] == "template": 41 | time = text_config["time"] 42 | weather = text_config["weather"] 43 | new_clip_text = \ 44 | [ 45 | [ 46 | [ 47 | text_config["template"][time][weather][idx][0] 48 | for idx, k in enumerate(j) 49 | ] 50 | for j in i 51 | ] 52 | for i in clip_text 53 | ] 54 | 55 | else: 56 | raise NotImplementedError( 57 | f"{text_config['type']}has not been implemented yet.") 58 | 59 | return new_clip_text 60 | 61 | else: 62 | 63 | return clip_text 64 | 65 | 66 | def create_parser(): 67 | parser = argparse.ArgumentParser( 68 | description="The script to finetune a stable diffusion model to the " 69 | "driving dataset.") 70 | parser.add_argument( 71 | "-c", "--config-path", type=str, required=True, 72 | help="The config to load the train model and dataset.") 73 | parser.add_argument( 74 | "-o", "--output-path", type=str, required=True, 75 | help="The path to save checkpoint files.") 76 | parser.add_argument( 77 | "-pc", "--preview-config-path", default=None, type=str, 78 | help="The config for preview setting") 79 | parser.add_argument( 80 | "-eic", "--export-item-config", default=False, type=bool, 81 | help="The flag to export the item config as JSON") 82 | return parser 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = create_parser() 87 | args = parser.parse_args() 88 | 89 | with open(args.config_path, "r", encoding="utf-8") as f: 90 | config = json.load(f) 91 | 92 | if args.preview_config_path is not None: 93 | with open(args.preview_config_path, "r", encoding="utf-8") as f: 94 | preview_config = json.load(f) 95 | else: 96 | preview_config = None 97 | 98 | # set distributed training (if enabled), log, random number generator, and 99 | # load the checkpoint (if required). 
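    # LOCAL_RANK is set by distributed launchers such as torchrun, so the DDP
    # branch below is taken automatically under a multi-process launch;
    # otherwise the script runs on a single device.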
100 | ddp = "LOCAL_RANK" in os.environ 101 | if ddp: 102 | local_rank = int(os.environ["LOCAL_RANK"]) 103 | device = torch.device(config["device"], local_rank) 104 | if config["device"] == "cuda": 105 | torch.cuda.set_device(local_rank) 106 | 107 | torch.distributed.init_process_group(backend=config["ddp_backend"]) 108 | else: 109 | device = torch.device(config["device"]) 110 | 111 | # setup the global state 112 | if "global_state" in config: 113 | for key, value in config["global_state"].items(): 114 | dwm.common.global_state[key] = \ 115 | dwm.common.create_instance_from_config(value) 116 | 117 | should_log = (ddp and local_rank == 0) or not ddp 118 | should_save = not torch.distributed.is_initialized() or \ 119 | torch.distributed.get_rank() == 0 120 | 121 | # load the pipeline including the models 122 | pipeline = dwm.common.create_instance_from_config( 123 | config["pipeline"], output_path=args.output_path, config=config, 124 | device=device) 125 | if should_log: 126 | print("The pipeline is loaded.") 127 | 128 | validation_dataset = dwm.common.create_instance_from_config( 129 | config["validation_dataset"]) 130 | 131 | preview_dataloader = torch.utils.data\ 132 | .DataLoader( 133 | validation_dataset, 134 | **dwm.common.instantiate_config(config["preview_dataloader"])) if \ 135 | "preview_dataloader" in config else None 136 | 137 | if should_log: 138 | print("The validation dataset is loaded with {} items.".format( 139 | len(validation_dataset))) 140 | 141 | export_batch_except = ["vae_images"] 142 | output_path = args.output_path 143 | global_step = 0 144 | for batch in preview_dataloader: 145 | if ddp: 146 | torch.distributed.barrier() 147 | 148 | if preview_config is not None: 149 | new_clip_text = customize_text(batch["clip_text"], preview_config) 150 | batch["clip_text"] = new_clip_text 151 | 152 | pipeline.preview_pipeline( 153 | batch, output_path, global_step) 154 | 155 | if args.export_item_config: 156 | with open( 157 | os.path.join( 158 | output_path, "preview", 159 | "{}.json".format(global_step)), 160 | "w", encoding="utf-8" 161 | ) as f: 162 | json.dump({ 163 | k: v.tolist() if isinstance(v, torch.Tensor) else v 164 | for k, v in batch.items() 165 | if k not in export_batch_except 166 | }, f, indent=4) 167 | 168 | global_step += 1 169 | if should_log: 170 | print(f"preview: {global_step}") 171 | 172 | if torch.distributed.is_initialized(): 173 | torch.distributed.destroy_process_group() 174 | -------------------------------------------------------------------------------- /src/dwm/common.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import importlib 3 | import io 4 | import numpy as np 5 | import os 6 | import pickle 7 | 8 | 9 | class PartialReadableRawIO(io.RawIOBase): 10 | def __init__( 11 | self, base_io_object: io.RawIOBase, start: int, end: int, 12 | close_with_this_object: bool = False 13 | ): 14 | super().__init__() 15 | self.base_io_object = base_io_object 16 | self.p = self.start = start 17 | self.end = end 18 | self.close_with_this_object = close_with_this_object 19 | self.base_io_object.seek(start) 20 | 21 | def close(self): 22 | if self.close_with_this_object: 23 | self.base_io_object.close() 24 | 25 | @property 26 | def closed(self): 27 | return self.base_io_object.closed if self.close_with_this_object \ 28 | else False 29 | 30 | def readable(self): 31 | return self.base_io_object.readable() 32 | 33 | def read(self, size=-1): 34 | read_count = min(size, self.end - self.p) \ 35 | if size >= 0 else 
self.end - self.p 36 | data = self.base_io_object.read(read_count) 37 | self.p += read_count 38 | return data 39 | 40 | def readall(self): 41 | return self.read(-1) 42 | 43 | def seek(self, offset, whence=os.SEEK_SET): 44 | if whence == os.SEEK_SET: 45 | p = max(0, min(self.end - self.start, offset)) 46 | elif whence == os.SEEK_CUR: 47 | p = max( 48 | 0, min(self.end - self.start, self.p - self.start + offset)) 49 | elif whence == os.SEEK_END: 50 | p = max( 51 | 0, min(self.end - self.start, self.end - self.start + offset)) 52 | 53 | self.p = self.base_io_object.seek(self.start + p, os.SEEK_SET) 54 | return self.p 55 | 56 | def seekable(self): 57 | return self.base_io_object.seekable() 58 | 59 | def tell(self): 60 | return self.p - self.start 61 | 62 | def writable(self): 63 | return False 64 | 65 | 66 | class ReadonlyDictIndices: 67 | def __init__(self, base_dict_keys): 68 | sorted_table = sorted( 69 | enumerate(base_dict_keys), key=lambda i: i[1]) 70 | self.sorted_keys = [i[1] for i in sorted_table] 71 | self.key_indices = np.array([i[0] for i in sorted_table], np.int64) 72 | 73 | def __len__(self): 74 | return len(self.sorted_keys) 75 | 76 | def __contains__(self, key): 77 | i = bisect.bisect_left(self.sorted_keys, key) 78 | in_range = i >= 0 and i < len(self.sorted_keys) 79 | return in_range and self.sorted_keys[i] == key 80 | 81 | def __getitem__(self, key): 82 | i = bisect.bisect_left(self.sorted_keys, key) 83 | if i < 0 or i >= len(self.sorted_keys) or self.sorted_keys[i] != key: 84 | raise KeyError("{} not found".format(key)) 85 | 86 | return self.key_indices[i] 87 | 88 | def get_all_indices(self, key): 89 | i0 = bisect.bisect_left(self.sorted_keys, key) 90 | i1 = bisect.bisect_right(self.sorted_keys, key) 91 | return [self.key_indices[i] for i in range(i0, i1)] if i1 > i0 else [] 92 | 93 | 94 | class SerializedReadonlyList: 95 | """A list to prevent memory divergence accessed by forked process. 96 | """ 97 | 98 | def __init__(self, items: list): 99 | serialized_items = [pickle.dumps(i) for i in items] 100 | self.offsets = np.cumsum([len(i) for i in serialized_items]) 101 | self.blob = b''.join(serialized_items) 102 | 103 | def __len__(self): 104 | return len(self.offsets) 105 | 106 | def __getitem__(self, index: int): 107 | start = 0 if index == 0 else self.offsets[index - 1] 108 | end = self.offsets[index] 109 | return pickle.loads(self.blob[start:end]) 110 | 111 | 112 | class SerializedReadonlyDict: 113 | """A dict to prevent memory divergence accessed by forked process. 114 | """ 115 | 116 | def __init__(self, base_dict: dict): 117 | self.indices = ReadonlyDictIndices(base_dict.keys()) 118 | self.values = SerializedReadonlyList(base_dict.values()) 119 | 120 | def __len__(self): 121 | return len(self.indices) 122 | 123 | def __contains__(self, key): 124 | return key in self.indices 125 | 126 | def __getitem__(self, key): 127 | return self.values[self.indices[key]] 128 | 129 | def keys(self): 130 | return self.indices.sorted_keys 131 | 132 | 133 | def get_class(class_name: str): 134 | if "." 
in class_name: 135 | i = class_name.rfind(".") 136 | module_name = class_name[:i] 137 | class_name = class_name[i+1:] 138 | 139 | module_type = importlib.import_module(module_name, package=None) 140 | class_type = getattr(module_type, class_name) 141 | elif class_name in globals(): 142 | class_type = globals()[class_name] 143 | else: 144 | raise RuntimeError("Failed to find the class {}.".format(class_name)) 145 | 146 | return class_type 147 | 148 | 149 | def create_instance(class_name: str, **kwargs): 150 | class_type = get_class(class_name) 151 | return class_type(**kwargs) 152 | 153 | 154 | def create_instance_from_config(_config: dict, level: int = 0, **kwargs): 155 | if isinstance(_config, dict): 156 | if "_class_name" in _config: 157 | args = instantiate_config(_config, level) 158 | if level == 0: 159 | args.update(kwargs) 160 | 161 | if _config["_class_name"] == "get_class": 162 | return get_class(**args) 163 | else: 164 | return create_instance(_config["_class_name"], **args) 165 | 166 | else: 167 | return instantiate_config(_config, level) 168 | 169 | elif isinstance(_config, list): 170 | return [create_instance_from_config(i, level + 1) for i in _config] 171 | else: 172 | return _config 173 | 174 | 175 | def instantiate_config(_config: dict, level: int = 0): 176 | return { 177 | k: create_instance_from_config(v, level + 1) 178 | for k, v in _config.items() if k != "_class_name" 179 | } 180 | 181 | 182 | def get_state(key: str): 183 | return global_state[key] 184 | 185 | 186 | global_state = {} 187 | -------------------------------------------------------------------------------- /src/dwm/fs/s3fs.py: -------------------------------------------------------------------------------- 1 | import botocore.session 2 | import fsspec 3 | import io 4 | import os 5 | import re 6 | 7 | 8 | class S3File(io.RawIOBase): 9 | 10 | @staticmethod 11 | def find_bucket_key(s3_path): 12 | """ 13 | This is a helper function that given an s3 path such that the path is of 14 | the form: bucket/key 15 | It will return the bucket and the key represented by the s3 path 16 | """ 17 | 18 | bucket_format_list = [ 19 | re.compile( 20 | r"^(?Parn:(aws).*:s3:[a-z\-0-9]*:[0-9]{12}:accesspoint[:/][^/]+)/?" 
21 | r"(?P.*)$" 22 | ), 23 | re.compile( 24 | r"^(?Parn:(aws).*:s3-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]" 25 | r"[a-zA-Z0-9\-]{1,63}[/:](bucket|accesspoint)[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P.*)$" 26 | ), 27 | re.compile( 28 | r"^(?Parn:(aws).*:s3-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]" 29 | r"[a-zA-Z0-9\-]{1,63}[/:]bucket[/:]" 30 | r"[a-zA-Z0-9\-]{1,63})[/:]?(?P.*)$" 31 | ), 32 | re.compile( 33 | r"^(?Parn:(aws).*:s3-object-lambda:[a-z\-0-9]+:[0-9]{12}:" 34 | r"accesspoint[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P.*)$" 35 | ), 36 | ] 37 | for bucket_format in bucket_format_list: 38 | match = bucket_format.match(s3_path) 39 | if match: 40 | return match.group("bucket"), match.group("key") 41 | 42 | s3_components = s3_path.split("/", 1) 43 | bucket = s3_components[0] 44 | s3_key = "" 45 | if len(s3_components) > 1: 46 | s3_key = s3_components[1] 47 | 48 | return bucket, s3_key 49 | 50 | def __init__(self, client, path): 51 | super().__init__() 52 | self.client = client 53 | self.bucket, self.key = S3File.find_bucket_key(path) 54 | self.p = 0 55 | self.head = client.head_object(Bucket=self.bucket, Key=self.key) 56 | 57 | def readable(self): 58 | return True 59 | 60 | def read(self, size=-1): 61 | read_count = min(size, self.head["ContentLength"] - self.p) \ 62 | if size >= 0 else self.head["ContentLength"] - self.p 63 | 64 | if read_count == 0: 65 | return b"" 66 | 67 | end = self.p + read_count - 1 68 | response = self.client.get_object( 69 | Bucket=self.bucket, Key=self.key, 70 | Range="bytes={}-{}".format(self.p, end)) 71 | data = response["Body"].read() 72 | self.p += read_count 73 | return data 74 | 75 | def readall(self): 76 | return self.read(-1) 77 | 78 | def seek(self, offset, whence=os.SEEK_SET): 79 | if whence == os.SEEK_SET: 80 | self.p = max(0, min(offset, self.head["ContentLength"])) 81 | elif whence == os.SEEK_CUR: 82 | self.p = max(0, min(self.p + offset, self.head["ContentLength"])) 83 | elif whence == os.SEEK_END: 84 | self.p = max( 85 | 0, min( 86 | self.head["ContentLength"] + offset, 87 | self.head["ContentLength"])) 88 | 89 | return self.p 90 | 91 | def seekable(self): 92 | return True 93 | 94 | def tell(self): 95 | return self.p 96 | 97 | def writable(self): 98 | return False 99 | 100 | 101 | class ForkableS3FileSystem(fsspec.AbstractFileSystem): 102 | 103 | root_marker = "" 104 | protocol = "s3" 105 | cachable = False 106 | 107 | """This file system can be used to access files on the S3 service, and it 108 | is also compatible with the process fork under the multi-worker situation 109 | of the PyTorch data loader. 110 | 111 | Args: 112 | kwargs: The parameters follow the 113 | [Botocore confiruation](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html). 
114 | """ 115 | 116 | def __init__(self, **kwargs): 117 | # readonly, TODO: support list 118 | 119 | super().__init__() 120 | self.kwargs = kwargs 121 | self.client = botocore.session.get_session()\ 122 | .create_client("s3", **kwargs) 123 | 124 | def reinit_if_forked(self): 125 | current_pid = os.getpid() 126 | if self._pid != current_pid: 127 | self.client = botocore.session.get_session()\ 128 | .create_client("s3", **self.kwargs) 129 | self._pid = current_pid 130 | 131 | def _open( 132 | self, path, mode="rb", block_size=None, autocommit=True, 133 | cache_options=None, **kwargs, 134 | ): 135 | self.reinit_if_forked() 136 | return S3File(self.client, path) 137 | 138 | def ls(self, path, detail=True, **kwargs): 139 | self.reinit_if_forked() 140 | bucket, key = S3File.find_bucket_key(path) 141 | if len(key) > 0 and not key.endswith("/"): 142 | key = key + "/" 143 | 144 | # NOTE: only files are listed 145 | paths = [] 146 | continuation_token = None 147 | while True: 148 | if continuation_token is None: 149 | response = self.client.list_objects( 150 | Bucket=bucket, Delimiter="/", Prefix=key) 151 | else: 152 | response = self.client.list_objects( 153 | Bucket=bucket, Delimiter="/", Prefix=key, 154 | Marker=continuation_token) 155 | 156 | if "Contents" in response: 157 | for i in response["Contents"]: 158 | paths.append({ 159 | "name": "{}/{}".format(bucket, i["Key"]), 160 | "size": i["Size"], 161 | "type": "file", 162 | "Owner": i["Owner"], 163 | "ETag": i["ETag"], 164 | "LastModified": i["LastModified"] 165 | }) 166 | 167 | if response["IsTruncated"]: 168 | continuation_token = response["NextMarker"] 169 | else: 170 | break 171 | 172 | if detail: 173 | return sorted(paths, key=lambda i: i["name"]) 174 | else: 175 | return sorted([i["name"] for i in paths]) 176 | -------------------------------------------------------------------------------- /src/dwm/fs/dirfs.py: -------------------------------------------------------------------------------- 1 | import fsspec 2 | import fsspec.implementations.local 3 | 4 | 5 | class DirFileSystem(fsspec.AbstractFileSystem): 6 | """Directory prefix filesystem 7 | 8 | The DirFileSystem is a filesystem-wrapper. It assumes every path it is 9 | dealing with is relative to the `path`. After performing the necessary 10 | paths operation it delegates everything to the wrapped filesystem. 
11 | """ 12 | 13 | protocol = "dir" 14 | 15 | def __init__(self, path=None, fs=None, **kwargs): 16 | super().__init__(**kwargs) 17 | self.path = path 18 | self.fs = fsspec.implementations.local.LocalFileSystem() \ 19 | if fs is None else fs 20 | 21 | @property 22 | def sep(self): 23 | return self.fs.sep 24 | 25 | def _join(self, path): 26 | if isinstance(path, str): 27 | if not self.path: 28 | return path 29 | 30 | if not path: 31 | return self.path 32 | 33 | return self.fs.sep.join((self.path, self._strip_protocol(path))) 34 | 35 | return [self._join(_path) for _path in path] 36 | 37 | def _relpath(self, path): 38 | if isinstance(path, str): 39 | if not self.path: 40 | return path 41 | 42 | if path == self.path: 43 | return "" 44 | 45 | prefix = self.path + self.fs.sep 46 | assert path.startswith(prefix) 47 | return path[len(prefix):] 48 | 49 | return [self._relpath(_path) for _path in path] 50 | 51 | def rm_file(self, path, **kwargs): 52 | return self.fs.rm_file(self._join(path), **kwargs) 53 | 54 | def rm(self, path, *args, **kwargs): 55 | return self.fs.rm(self._join(path), *args, **kwargs) 56 | 57 | def cp_file(self, path1, path2, **kwargs): 58 | return self.fs.cp_file(self._join(path1), self._join(path2), **kwargs) 59 | 60 | def copy(self, path1, path2, *args, **kwargs): 61 | return self.fs.copy( 62 | self._join(path1), self._join(path2), *args, **kwargs) 63 | 64 | def pipe(self, path, *args, **kwargs): 65 | return self.fs.pipe(self._join(path), *args, **kwargs) 66 | 67 | def pipe_file(self, path, *args, **kwargs): 68 | return self.fs.pipe_file(self._join(path), *args, **kwargs) 69 | 70 | def cat_file(self, path, *args, **kwargs): 71 | return self.fs.cat_file(self._join(path), *args, **kwargs) 72 | 73 | def cat(self, path, *args, **kwargs): 74 | ret = self.fs.cat(self._join(path), *args, **kwargs) 75 | 76 | if isinstance(ret, dict): 77 | return {self._relpath(key): value for key, value in ret.items()} 78 | 79 | return ret 80 | 81 | def put_file(self, lpath, rpath, **kwargs): 82 | return self.fs.put_file(lpath, self._join(rpath), **kwargs) 83 | 84 | def put(self, lpath, rpath, *args, **kwargs): 85 | return self.fs.put(lpath, self._join(rpath), *args, **kwargs) 86 | 87 | def get_file(self, rpath, lpath, **kwargs): 88 | return self.fs.get_file(self._join(rpath), lpath, **kwargs) 89 | 90 | def get(self, rpath, *args, **kwargs): 91 | return self.fs.get(self._join(rpath), *args, **kwargs) 92 | 93 | def isfile(self, path): 94 | return self.fs.isfile(self._join(path)) 95 | 96 | def isdir(self, path): 97 | return self.fs.isdir(self._join(path)) 98 | 99 | def size(self, path): 100 | return self.fs.size(self._join(path)) 101 | 102 | def exists(self, path): 103 | return self.fs.exists(self._join(path)) 104 | 105 | def info(self, path, **kwargs): 106 | return self.fs.info(self._join(path), **kwargs) 107 | 108 | def ls(self, path, detail=True, **kwargs): 109 | ret = self.fs.ls(self._join(path), detail=detail, **kwargs).copy() 110 | if detail: 111 | out = [] 112 | for entry in ret: 113 | entry = entry.copy() 114 | entry["name"] = self._relpath(entry["name"]) 115 | out.append(entry) 116 | 117 | return out 118 | 119 | return self._relpath(ret) 120 | 121 | def walk(self, path, *args, **kwargs): 122 | for i in self.fs.walk(self._join(path), *args, **kwargs): 123 | root, dirs, files = i 124 | yield self._relpath(root), dirs, files 125 | 126 | def glob(self, path, **kwargs): 127 | detail = kwargs.get("detail", False) 128 | ret = self.fs.glob(self._join(path), **kwargs) 129 | if detail: 130 | return 
{self._relpath(path): info for path, info in ret.items()} 131 | return self._relpath(ret) 132 | 133 | def du(self, path, *args, **kwargs): 134 | total = kwargs.get("total", True) 135 | ret = self.fs.du(self._join(path), *args, **kwargs) 136 | if total: 137 | return ret 138 | 139 | return {self._relpath(path): size for path, size in ret.items()} 140 | 141 | def find(self, path, *args, **kwargs): 142 | detail = kwargs.get("detail", False) 143 | ret = self.fs.find(self._join(path), *args, **kwargs) 144 | if detail: 145 | return {self._relpath(path): info for path, info in ret.items()} 146 | return self._relpath(ret) 147 | 148 | def expand_path(self, path, *args, **kwargs): 149 | return self._relpath( 150 | self.fs.expand_path(self._join(path), *args, **kwargs)) 151 | 152 | def mkdir(self, path, *args, **kwargs): 153 | return self.fs.mkdir(self._join(path), *args, **kwargs) 154 | 155 | def makedirs(self, path, *args, **kwargs): 156 | return self.fs.makedirs(self._join(path), *args, **kwargs) 157 | 158 | def rmdir(self, path): 159 | return self.fs.rmdir(self._join(path)) 160 | 161 | def mv(self, path1, path2, **kwargs): 162 | return self.fs.mv(self._join(path1), self._join(path2), **kwargs) 163 | 164 | def touch(self, path, **kwargs): 165 | return self.fs.touch(self._join(path), **kwargs) 166 | 167 | def created(self, path): 168 | return self.fs.created(self._join(path)) 169 | 170 | def modified(self, path): 171 | return self.fs.modified(self._join(path)) 172 | 173 | def sign(self, path, *args, **kwargs): 174 | return self.fs.sign(self._join(path), *args, **kwargs) 175 | 176 | def __repr__(self): 177 | return "{}(path='{}', fs={})".format( 178 | self.__class__.__qualname__, self.path, self.fs) 179 | 180 | def open(self, path, *args, **kwargs): 181 | return self.fs.open(self._join(path), *args, **kwargs) 182 | -------------------------------------------------------------------------------- /src/dwm/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def create_frustum( 5 | frustum_depth_range: list, frustum_height: int, frustum_width: int, 6 | device=None 7 | ): 8 | """Create the lifted frustum from the camera view. 9 | 10 | Args: 11 | frustum_depth_range (list): 3 float numbers (start, stop, step) of the 12 | depth range. The stop is exclusive as same as the common definition 13 | of range functions. 14 | frustum_height: The frustum height. 15 | frustum_width: The frustum width. 16 | device: The prefered device of the returned frustum tensor. 17 | 18 | Returns: 19 | The lifted frustum tensor in the shape of [4, D, frustum_height, 20 | frustum_width]. The D = (stop - start) / step of the 21 | frustum_depth_range. The 4 items in the first dimension are the (nx * z, 22 | ny * z, z, 1), where nx and ny are the normalized coordinate in the 23 | range of -1 ~ 1, z is in the unit of camera coordinate system. 
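    Example (illustrative sizes):

        # 64 depth bins from 1 m to 65 m on a 32 x 88 feature grid
        frustum = create_frustum([1.0, 65.0, 1.0], 32, 88)
        assert frustum.shape == (4, 64, 32, 88)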
24 | """ 25 | 26 | depth_channels = int( 27 | (frustum_depth_range[1] - frustum_depth_range[0]) / 28 | frustum_depth_range[2]) 29 | x = torch.arange( 30 | 1 / frustum_width - 1, 1, 2 / frustum_width, device=device) 31 | y = torch.arange( 32 | 1 / frustum_height - 1, 1, 2 / frustum_height, device=device) 33 | z = torch.arange(*frustum_depth_range, device=device) 34 | frustum = torch.stack([ 35 | x.unsqueeze(0).unsqueeze(0).repeat(depth_channels, frustum_height, 1), 36 | y.unsqueeze(0).unsqueeze(-1).repeat(depth_channels, 1, frustum_width), 37 | z.unsqueeze(-1).unsqueeze(-1).repeat(1, frustum_height, frustum_width), 38 | torch.ones( 39 | (depth_channels, frustum_height, frustum_width), device=device) 40 | ]) 41 | frustum[:2] *= frustum[2:3] 42 | return frustum 43 | 44 | 45 | def make_homogeneous_matrix(a: torch.Tensor): 46 | right = torch.cat([ 47 | torch.zeros( 48 | tuple(a.shape[:-1]) + (1,), dtype=a.dtype, device=a.device), 49 | torch.ones( 50 | tuple(a.shape[:-2]) + (1, 1), dtype=a.dtype, device=a.device) 51 | ], -2) 52 | return torch.cat([torch.nn.functional.pad(a, (0, 0, 0, 1)), right], -1) 53 | 54 | 55 | def make_homogeneous_vector(a: torch.Tensor): 56 | return torch.cat([ 57 | a, 58 | torch.ones(tuple(a.shape[:-1]) + (1,), dtype=a.dtype, device=a.device) 59 | ], -1) 60 | 61 | 62 | def make_transform_to_frustum( 63 | image_size: torch.Tensor, frustum_width: int, frustum_height: int 64 | ): 65 | """Make the 4x4 transform from the image coordinates to the frustum 66 | coordinates. 67 | 68 | Args: 69 | image_size (torch.Tensor): The image size tensor in the shape of 70 | [..., 2], and the 2 numbers of the last dimension is 71 | (width, height). 72 | frustum_width (int): The width of the frustum. 73 | frustum_height (int): The height of the frustum. 74 | 75 | Returns: 76 | The transform matrix in the shape of [..., 4, 4]. 77 | """ 78 | 79 | base_shape = list(image_size.shape[:-1]) 80 | frustum_size = torch.tensor( 81 | [frustum_width, frustum_height], dtype=image_size.dtype, 82 | device=image_size.device).view(*([1 for _ in base_shape] + [-1])) 83 | scale = frustum_size / image_size 84 | zeros = torch.zeros( 85 | base_shape + [4], dtype=scale.dtype, device=scale.device) 86 | ones = torch.ones( 87 | base_shape + [1], dtype=scale.dtype, device=scale.device) 88 | return torch.cat([ 89 | scale[..., 0:1], zeros, scale[..., 1:2], zeros, ones, zeros, ones 90 | ], -1).unflatten(-1, (4, 4)) 91 | 92 | 93 | def normalize_intrinsic_transform( 94 | image_sizes: torch.Tensor, instrinsics: torch.Tensor 95 | ): 96 | """Make the normalized 3x3 intrinsic transform from the camera coordinates 97 | to the normalized coordinates (-1 ~ 1). 98 | 99 | Args: 100 | image_sizes (torch.Tensor): The image size tensor in the shape of 101 | [..., 2], and the 2 numbers of the last dimension is 102 | (width, height). 103 | instrinsics (torch.Tensor): The camera intrinsic transform from the 104 | camera coordinates (X-right, Y-down, Z-forward) to the image 105 | coordinates (X-right, Y-down). 106 | 107 | Returns: 108 | The transform matrix in the shape of [..., 3, 3]. 
109 | """ 110 | base_shape = list(image_sizes.shape[:-1]) 111 | scale = 2 / image_sizes 112 | translate = 1 / image_sizes - 1 113 | zeros = torch.zeros( 114 | base_shape + [2], dtype=scale.dtype, device=scale.device) 115 | ones = torch.ones( 116 | base_shape + [1], dtype=scale.dtype, device=scale.device) 117 | normalization_transform = torch.cat([ 118 | scale[..., 0:1], zeros[..., :1], translate[..., 0:1], zeros[..., :1], 119 | scale[..., 1:2], translate[..., 1:2], zeros, ones 120 | ], -1).unflatten(-1, (3, 3)) 121 | return normalization_transform @ instrinsics 122 | 123 | 124 | def grid_sample_sequence( 125 | input: torch.Tensor, sequence: torch.Tensor, bundle_size: int = 128, 126 | mode: str = "bilinear", padding_mode: str = "border" 127 | ): 128 | count = sequence.shape[-2] 129 | bundle_count = (count + bundle_size - 1) // bundle_size 130 | assert bundle_count * bundle_size == count 131 | 132 | samples = torch.nn.functional.grid_sample( 133 | input.flatten(0, -4), 134 | sequence.unflatten(-2, (bundle_count, bundle_size)), mode=mode, 135 | padding_mode=padding_mode, align_corners=False) 136 | return samples.unflatten(0, input.shape[:-3]).flatten(-2) 137 | 138 | 139 | def _sample_logistic( 140 | shape: torch.Size, out: torch.Tensor = None, 141 | generator: torch.Generator = None 142 | ): 143 | x = torch.rand(shape, generator=generator) if out is None else \ 144 | out.resize_(shape).uniform_(generator=generator) 145 | return torch.log(x) - torch.log(1 - x) 146 | 147 | 148 | def _sigmoid_sample( 149 | logits: torch.Tensor, tau: float = 1.0, generator: torch.Generator = None 150 | ): 151 | # Refer to Bernouilli reparametrization based on Maddison et al. 2017 152 | noise = _sample_logistic(logits.size(), None, generator) 153 | y = logits + noise 154 | return torch.sigmoid(y / tau) 155 | 156 | 157 | def gumbel_sigmoid( 158 | logits, tau: float = 1.0, hard: bool = False, 159 | generator: torch.Generator = None 160 | ): 161 | # use CPU random generator 162 | y_soft = _sigmoid_sample(logits.cpu(), tau, generator).to(logits.device) 163 | if hard: 164 | y_hard = torch.where( 165 | y_soft > 0.5, torch.ones_like(y_soft), torch.zeros_like(y_soft)) 166 | return y_hard.data - y_soft.data + y_soft 167 | 168 | else: 169 | return y_soft 170 | 171 | 172 | def take_sequence_clip(item, start: int, stop: int): 173 | if isinstance(item, (int, float, bool, str)): 174 | return item 175 | elif isinstance(item, torch.Tensor): 176 | return item if len(item.shape) <= 1 else item[:, start:stop] 177 | elif isinstance(item, list): 178 | assert len(item) > 0 and all([isinstance(i, list) for i in item]) 179 | return [i[start:stop] for i in item] 180 | else: 181 | raise Exception("Unsupported type to take sequence clip.") 182 | 183 | 184 | def memory_efficient_split_call( 185 | block: torch.nn.Module, tensor: torch.Tensor, func, split_size: int 186 | ): 187 | if split_size == -1: 188 | return func(block, tensor) 189 | else: 190 | return torch.cat([ 191 | func(block, i) 192 | for i in tensor.split(split_size) 193 | ]) 194 | -------------------------------------------------------------------------------- /src/dwm/datasets/waymo_common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_inclination(inclination_range, height, dtype=np.float32): 5 | diff = inclination_range[1] - inclination_range[0] 6 | normal_range = (np.arange(0, height, dtype=dtype) + 0.5) / height 7 | inclination = normal_range * diff + inclination_range[0] 8 | return 
inclination 9 | 10 | 11 | def compute_range_image_polar( 12 | range_image: np.array, extrinsic: np.array, inclination: np.array, 13 | dtype=np.float32 14 | ): 15 | _, width = range_image.shape 16 | az_correction = np.arctan2(extrinsic[1, 0], extrinsic[0, 0]) 17 | ratios = (np.arange(width, 0, -1, dtype=dtype) - 0.5) / width 18 | azimuth = (ratios * 2 - 1) * np.pi - az_correction 19 | return azimuth[None, :], inclination[:, None], range_image 20 | 21 | 22 | def get_rotation_matrix( 23 | roll: np.array, pitch: np.array, yaw: np.array, 24 | ): 25 | """Gets a rotation matrix given roll, pitch, yaw. 26 | 27 | roll-pitch-yaw is z-y'-x'' intrinsic rotation which means we need to apply 28 | x(roll) rotation first, then y(pitch) rotation, then z(yaw) rotation. 29 | 30 | https://en.wikipedia.org/wiki/Euler_angles 31 | http://planning.cs.uiuc.edu/node102.html 32 | 33 | Args: 34 | roll : x-rotation in radians. 35 | pitch: y-rotation in radians. The shape must be the same as roll. 36 | yaw: z-rotation in radians. The shape must be the same as roll. 37 | 38 | Returns: 39 | A rotation tensor with the same data type of the input. Its shape is 40 | [3 ,3]. 41 | """ 42 | 43 | cos_roll = np.cos(roll) 44 | sin_roll = np.sin(roll) 45 | cos_yaw = np.cos(yaw) 46 | sin_yaw = np.sin(yaw) 47 | cos_pitch = np.cos(pitch) 48 | sin_pitch = np.sin(pitch) 49 | ones = np.ones_like(cos_yaw) 50 | zeros = np.zeros_like(cos_yaw) 51 | 52 | r_roll = np.stack([ 53 | np.stack([ones, zeros, zeros], axis=-1), 54 | np.stack([zeros, cos_roll, -sin_roll], axis=-1), 55 | np.stack([zeros, sin_roll, cos_roll], axis=-1), 56 | ], axis=-2) 57 | r_pitch = np.stack([ 58 | np.stack([cos_pitch, zeros, sin_pitch], axis=-1), 59 | np.stack([zeros, ones, zeros], axis=-1), 60 | np.stack([-sin_pitch, zeros, cos_pitch], axis=-1), 61 | ], axis=-2) 62 | r_yaw = np.stack([ 63 | np.stack([cos_yaw, -sin_yaw, zeros], axis=-1), 64 | np.stack([sin_yaw, cos_yaw, zeros], axis=-1), 65 | np.stack([zeros, zeros, ones], axis=-1), 66 | ], axis=-2) 67 | 68 | return np.matmul(r_yaw, np.matmul(r_pitch, r_roll)) 69 | 70 | 71 | def compute_range_image_cartesian( 72 | azimuth: np.array, inclination: np.array, range_image_range: np.array, 73 | extrinsic: np.array, pixel_pose=None, frame_pose=None, 74 | ): 75 | """Computes range image cartesian coordinates from polar ones. 76 | 77 | Args: 78 | range_image_polar: [B, H, W, 3] float tensor. Lidar range image in 79 | polar coordinate in sensor frame. 80 | extrinsic: [B, 4, 4] float tensor. Lidar extrinsic. 81 | pixel_pose: [B, H, W, 4, 4] float tensor. If not None, it sets pose for 82 | each range image pixel. 83 | frame_pose: [B, 4, 4] float tensor. This must be set when pixel_pose is 84 | set. It decides the vehicle frame at which the cartesian points are 85 | computed. 86 | dtype: float type to use internally. This is needed as extrinsic and 87 | inclination sometimes have higher resolution than range_image. 88 | 89 | Returns: 90 | range_image_cartesian: [B, H, W, 3] cartesian coordinates. 91 | """ 92 | 93 | cos_azimuth = np.cos(azimuth) 94 | sin_azimuth = np.sin(azimuth) 95 | cos_incl = np.cos(inclination) 96 | sin_incl = np.sin(inclination) 97 | 98 | # [H, W]. 99 | x = cos_azimuth * cos_incl * range_image_range 100 | y = sin_azimuth * cos_incl * range_image_range 101 | z = sin_incl * range_image_range 102 | 103 | # [H, W, 3] 104 | range_image_points = np.stack([x, y, z], -1) 105 | 106 | # To vehicle frame. 
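    # Row-vector convention: p_vehicle = p_sensor @ R.T + t.T is equivalent
    # to the column-vector form p_vehicle = R @ p_sensor + t, with R and t
    # taken from the sensor-to-vehicle extrinsic.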
107 | rotation = extrinsic[0:3, 0:3].T 108 | translation = extrinsic[0:3, 3:4].T 109 | range_image_points = range_image_points @ rotation + translation 110 | 111 | if pixel_pose is not None: 112 | # To global frame. 113 | # [H, W, 3, 3] 114 | pixel_pose_rotation = np.swapaxes( 115 | get_rotation_matrix( 116 | pixel_pose[..., 0], pixel_pose[..., 1], pixel_pose[..., 2]), 117 | -1, -2) 118 | 119 | # [H, W, 1, 3] 120 | pixel_pose_translation = pixel_pose[..., None, 3:] 121 | range_image_points = ( 122 | range_image_points[..., None, :] @ pixel_pose_rotation + 123 | pixel_pose_translation) 124 | 125 | # [H, W, 3] 126 | range_image_points = range_image_points.reshape( 127 | *range_image_points.shape[:-2], -1) 128 | 129 | if frame_pose is None: 130 | raise ValueError('frame_pose must be set when pixel_pose is set.') 131 | 132 | # To vehicle frame corresponding to the given frame_pose 133 | # [4, 4] 134 | world_to_vehicle = np.linalg.inv(frame_pose) 135 | world_to_vehicle_rotation = world_to_vehicle[0:3, 0:3].T 136 | world_to_vehicle_translation = world_to_vehicle[0:3, 3:4].T 137 | 138 | # [H, W, 3] 139 | range_image_points = range_image_points @ world_to_vehicle_rotation + \ 140 | world_to_vehicle_translation 141 | 142 | return range_image_points 143 | 144 | 145 | def convert_range_image_to_cartesian( 146 | range_image: np.array, calibration: dict, lidar_pose=None, frame_pose=None, 147 | dtype=np.float32 148 | ): 149 | """Converts one range image from polar coordinates to Cartesian coordinates. 150 | 151 | Args: 152 | range_image: One range image return captured by a LiDAR sensor. 153 | calibration: Parameters for calibration of a LiDAR sensor. 154 | 155 | Returns: 156 | A [H, W, 3] image in Cartesian coordinates. 157 | """ 158 | 159 | extrinsic = np.array( 160 | calibration["[LiDARCalibrationComponent].extrinsic.transform"], 161 | dtype).reshape(4, 4) 162 | 163 | # Compute inclinations mapping range image rows to circles in the 3D worlds. 
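    # The calibration either lists per-beam inclinations explicitly or
    # provides a (min, max) range from which evenly spaced values are
    # interpolated; the list is then reversed to follow the top-to-bottom
    # row order of the range image.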
164 | bi_values_key = "[LiDARCalibrationComponent].beam_inclination.values" 165 | bi_min_key = "[LiDARCalibrationComponent].beam_inclination.min" 166 | bi_max_key = "[LiDARCalibrationComponent].beam_inclination.max" 167 | if calibration[bi_values_key] is not None: 168 | inclination = np.array(calibration[bi_values_key], dtype) 169 | else: 170 | inclination = compute_inclination( 171 | (calibration[bi_min_key], calibration[bi_max_key]), 172 | range_image.shape[0], dtype) 173 | 174 | inclination = np.flip(inclination, axis=-1) 175 | 176 | # Compute points from the range image 177 | azimuth, inclination, range_image_range = compute_range_image_polar( 178 | range_image[..., 0], extrinsic, inclination, dtype) 179 | range_image_cartesian = compute_range_image_cartesian( 180 | azimuth, inclination, range_image_range, extrinsic, 181 | pixel_pose=lidar_pose, frame_pose=frame_pose) 182 | range_image_mask = range_image[..., 0] > 0 183 | 184 | flatten_range_image_cartesian = range_image_cartesian.reshape(-1, 3) 185 | flatten_range_image_mask = range_image_mask.reshape(-1) 186 | return flatten_range_image_cartesian[flatten_range_image_mask] 187 | 188 | 189 | def laser_calibration_to_dict(lc): 190 | lc.beam_inclination_max 191 | return { 192 | "key.laser_name": lc.name, 193 | "[LiDARCalibrationComponent].beam_inclination.values": 194 | lc.beam_inclinations, 195 | "[LiDARCalibrationComponent].beam_inclination.min": 196 | lc.beam_inclination_min, 197 | "[LiDARCalibrationComponent].beam_inclination.max": 198 | lc.beam_inclination_max, 199 | "[LiDARCalibrationComponent].extrinsic.transform": 200 | lc.extrinsic.transform 201 | } 202 | -------------------------------------------------------------------------------- /src/dwm/models/maskgit_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from diffusers.models.attention import maybe_allow_in_graph 5 | from diffusers.models.attention import _chunked_feed_forward 6 | from diffusers.models.attention import Attention 7 | from diffusers.models.attention import FeedForward 8 | from diffusers.models.normalization import AdaLayerNormZero 9 | from diffusers.models.normalization import AdaLayerNormContinuous 10 | from diffusers.models.normalization import SD35AdaLayerNormZeroX 11 | from typing import Optional 12 | from einops import rearrange 13 | 14 | 15 | @maybe_allow_in_graph 16 | class TemporalTransformerBlock(nn.Module): 17 | r""" 18 | A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3. 19 | 20 | Reference: https://arxiv.org/abs/2403.03206 21 | 22 | Parameters: 23 | dim (`int`): The number of channels in the input and output. 24 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 25 | attention_head_dim (`int`): The number of channels in each head. 26 | context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the 27 | processing of `context` conditions. 
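    Example (illustrative sizes; the sequence axis is assumed to be the
    temporal token axis):

        block = TemporalTransformerBlock(
            dim=512, num_attention_heads=8, attention_head_dim=64)
        x = torch.randn(2, 16, 512)  # (batch, sequence, channels)
        y = block(x)                 # same shape as x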
28 | """ 29 | 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_attention_heads: int, 34 | attention_head_dim: int, 35 | context_pre_only: bool = False, 36 | qk_norm: Optional[str] = None, 37 | ): 38 | super().__init__() 39 | 40 | self.context_pre_only = context_pre_only 41 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 42 | 43 | if hasattr(F, "scaled_dot_product_attention"): 44 | processor = TemporalAttnProcessor() 45 | else: 46 | raise ValueError( 47 | "The current PyTorch version does not support the `scaled_dot_product_attention` function." 48 | ) 49 | self.attn = Attention( 50 | query_dim=dim, 51 | cross_attention_dim=None, 52 | added_kv_proj_dim=None, 53 | dim_head=attention_head_dim, 54 | heads=num_attention_heads, 55 | out_dim=dim, 56 | context_pre_only=None, 57 | bias=True, 58 | processor=processor, 59 | qk_norm=qk_norm, 60 | eps=1e-6, 61 | ) 62 | 63 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 64 | self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") 65 | 66 | # let chunk size default to None 67 | self._chunk_size = None 68 | self._chunk_dim = 0 69 | 70 | # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward 71 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 72 | # Sets chunk feed-forward 73 | self._chunk_size = chunk_size 74 | self._chunk_dim = dim 75 | 76 | def forward( 77 | self, hidden_states: torch.FloatTensor, rotary_emb: Optional[torch.nn.Module] = None 78 | ): 79 | norm_hidden_states = self.norm1(hidden_states) 80 | # Attention. 81 | attn_output = self.attn( 82 | hidden_states=norm_hidden_states, 83 | rotary_emb=rotary_emb 84 | ) 85 | 86 | # Process attention outputs for the `hidden_states`. 87 | hidden_states = hidden_states + attn_output 88 | 89 | 90 | norm_hidden_states = self.norm2(hidden_states) 91 | if self._chunk_size is not None: 92 | # "feed_forward_chunk_size" can be used to save memory 93 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 94 | else: 95 | ff_output = self.ff(norm_hidden_states) 96 | 97 | hidden_states = hidden_states + ff_output 98 | 99 | return hidden_states 100 | 101 | 102 | class TemporalAttnProcessor: 103 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 104 | 105 | def __init__(self): 106 | if not hasattr(F, "scaled_dot_product_attention"): 107 | raise ImportError("TemporalAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 108 | 109 | def __call__( 110 | self, 111 | attn: Attention, 112 | hidden_states: torch.FloatTensor, 113 | encoder_hidden_states: torch.FloatTensor = None, 114 | attention_mask: Optional[torch.FloatTensor] = None, 115 | rotary_emb: Optional[torch.nn.Module] = None, 116 | *args, 117 | **kwargs, 118 | ) -> torch.FloatTensor: 119 | residual = hidden_states 120 | 121 | batch_size = hidden_states.shape[0] 122 | 123 | # `sample` projections. 
124 | query = attn.to_q(hidden_states) 125 | key = attn.to_k(hidden_states) 126 | value = attn.to_v(hidden_states) 127 | 128 | inner_dim = key.shape[-1] 129 | head_dim = inner_dim // attn.heads 130 | 131 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 132 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 133 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 134 | 135 | if rotary_emb is not None: 136 | query = rotary_emb(query) 137 | key = rotary_emb(key) 138 | 139 | if attn.norm_q is not None: 140 | query = attn.norm_q(query) 141 | if attn.norm_k is not None: 142 | key = attn.norm_k(key) 143 | 144 | # `context` projections. 145 | if encoder_hidden_states is not None: 146 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 147 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 148 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 149 | 150 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 151 | batch_size, -1, attn.heads, head_dim 152 | ).transpose(1, 2) 153 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 154 | batch_size, -1, attn.heads, head_dim 155 | ).transpose(1, 2) 156 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 157 | batch_size, -1, attn.heads, head_dim 158 | ).transpose(1, 2) 159 | 160 | # if attn.norm_added_q is not None: 161 | # encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 162 | # if attn.norm_added_k is not None: 163 | # encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 164 | 165 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) 166 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) 167 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) 168 | 169 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 170 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 171 | hidden_states = hidden_states.to(query.dtype) 172 | 173 | if encoder_hidden_states is not None: 174 | # Split the attention outputs. 
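            # The sample tokens were concatenated before the context tokens
            # along the sequence axis, so the first residual.shape[1]
            # positions belong to hidden_states and the remainder to
            # encoder_hidden_states.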
175 | hidden_states, encoder_hidden_states = ( 176 | hidden_states[:, : residual.shape[1]], 177 | hidden_states[:, residual.shape[1] :], 178 | ) 179 | # if not attn.context_pre_only: 180 | # encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 181 | 182 | # linear proj 183 | hidden_states = attn.to_out[0](hidden_states) 184 | # dropout 185 | hidden_states = attn.to_out[1](hidden_states) 186 | 187 | if encoder_hidden_states is not None: 188 | return hidden_states, encoder_hidden_states 189 | else: 190 | return hidden_states 191 | 192 | -------------------------------------------------------------------------------- /configs/experimental/simulation/make_carla_cameras_from_nuscenes.json: -------------------------------------------------------------------------------- 1 | { 2 | "CAM_FRONT_LEFT": { 3 | "image_size": [ 4 | 1600, 5 | 900 6 | ], 7 | "intrinsic": [ 8 | [ 9 | 1272.5979470598488, 10 | 0.0, 11 | 826.6154927353808 12 | ], 13 | [ 14 | 0.0, 15 | 1272.5979470598488, 16 | 479.75165386361925 17 | ], 18 | [ 19 | 0.0, 20 | 0.0, 21 | 1.0 22 | ] 23 | ], 24 | "transform": [ 25 | [ 26 | 0.8207583481515547, 27 | -0.00034143667149916235, 28 | 0.5712754303840931, 29 | 1.52387798135 30 | ], 31 | [ 32 | -0.5712716007618155, 33 | 0.0032195018514857843, 34 | 0.8207547703003994, 35 | 0.494631336551 36 | ], 37 | [ 38 | -0.002119458082718406, 39 | -0.9999947591006804, 40 | 0.00244737994761024, 41 | 1.50932822144 42 | ], 43 | [ 44 | 0.0, 45 | 0.0, 46 | 0.0, 47 | 1.0 48 | ] 49 | ] 50 | }, 51 | "CAM_FRONT": { 52 | "image_size": [ 53 | 1600, 54 | 900 55 | ], 56 | "intrinsic": [ 57 | [ 58 | 1252.8131021185304, 59 | 0.0, 60 | 826.588114781398 61 | ], 62 | [ 63 | 0.0, 64 | 1252.8131021185304, 65 | 469.9846626224581 66 | ], 67 | [ 68 | 0.0, 69 | 0.0, 70 | 1.0 71 | ] 72 | ], 73 | "transform": [ 74 | [ 75 | 0.01026020777540071, 76 | 0.008433448071667293, 77 | 0.9999117986552757, 78 | 1.72200568478 79 | ], 80 | [ 81 | -0.9998725753702897, 82 | 0.012316255772487295, 83 | 0.01015592763520401, 84 | 0.00475453292289 85 | ], 86 | [ 87 | -0.012229519973835201, 88 | -0.9998885871922779, 89 | 0.008558740785910124, 90 | 1.49491291905 91 | ], 92 | [ 93 | 0.0, 94 | 0.0, 95 | 0.0, 96 | 1.0 97 | ] 98 | ] 99 | }, 100 | "CAM_FRONT_RIGHT": { 101 | "image_size": [ 102 | 1600, 103 | 900 104 | ], 105 | "intrinsic": [ 106 | [ 107 | 1256.7485116440405, 108 | 0.0, 109 | 817.7887570959712 110 | ], 111 | [ 112 | 0.0, 113 | 1256.7485116440403, 114 | 451.9541780095127 115 | ], 116 | [ 117 | 0.0, 118 | 0.0, 119 | 1.0 120 | ] 121 | ], 122 | "transform": [ 123 | [ 124 | -0.8439797263539992, 125 | 0.01645551436192705, 126 | 0.5361225956350965, 127 | 1.58082565783 128 | ], 129 | [ 130 | -0.5361413772261798, 131 | 0.003621074712163108, 132 | -0.8441204365752225, 133 | -0.499078711449 134 | ], 135 | [ 136 | -0.015831775940933213, 137 | -0.9998580418564494, 138 | 0.005766368488289819, 139 | 1.51749368405 140 | ], 141 | [ 142 | 0.0, 143 | 0.0, 144 | 0.0, 145 | 1.0 146 | ] 147 | ] 148 | }, 149 | "CAM_BACK_RIGHT": { 150 | "image_size": [ 151 | 1600, 152 | 900 153 | ], 154 | "intrinsic": [ 155 | [ 156 | 1259.5137405846733, 157 | 0.0, 158 | 807.2529053838625 159 | ], 160 | [ 161 | 0.0, 162 | 1259.5137405846733, 163 | 501.19579884916527 164 | ], 165 | [ 166 | 0.0, 167 | 0.0, 168 | 1.0 169 | ] 170 | ], 171 | "transform": [ 172 | [ 173 | -0.9347755391782977, 174 | 0.01587583795544978, 175 | -0.3548839938953783, 176 | 1.0148780988 177 | ], 178 | [ 179 | 0.3550745592741918, 180 | 0.011370495332616137, 181 | -0.9347688319537238, 182 | -0.480568219723 
183 | ], 184 | [ 185 | -0.010805031705694884, 186 | -0.9998093166224763, 187 | -0.01626596707041017, 188 | 1.56239545128 189 | ], 190 | [ 191 | 0.0, 192 | 0.0, 193 | 0.0, 194 | 1.0 195 | ] 196 | ] 197 | }, 198 | "CAM_BACK": { 199 | "image_size": [ 200 | 1600, 201 | 900 202 | ], 203 | "intrinsic": [ 204 | [ 205 | 809.2209905677063, 206 | 0.0, 207 | 829.2196003259838 208 | ], 209 | [ 210 | 0.0, 211 | 809.2209905677063, 212 | 481.77842384512485 213 | ], 214 | [ 215 | 0.0, 216 | 0.0, 217 | 1.0 218 | ] 219 | ], 220 | "transform": [ 221 | [ 222 | 0.002421709860318977, 223 | -0.016753608478469628, 224 | -0.9998567156969553, 225 | 0.0283260309358 226 | ], 227 | [ 228 | 0.9999890666843356, 229 | -0.003959107249965843, 230 | 0.002488369260045864, 231 | 0.00345136761476 232 | ], 233 | [ 234 | -0.004000229136375599, 235 | -0.9998518100562368, 236 | 0.016743837496914438, 237 | 1.57910346144 238 | ], 239 | [ 240 | 0.0, 241 | 0.0, 242 | 0.0, 243 | 1.0 244 | ] 245 | ] 246 | }, 247 | "CAM_BACK_LEFT": { 248 | "image_size": [ 249 | 1600, 250 | 900 251 | ], 252 | "intrinsic": [ 253 | [ 254 | 1256.7414812095406, 255 | 0.0, 256 | 792.1125740759628 257 | ], 258 | [ 259 | 0.0, 260 | 1256.7414812095406, 261 | 492.7757465151356 262 | ], 263 | [ 264 | 0.0, 265 | 0.0, 266 | 1.0 267 | ] 268 | ], 269 | "transform": [ 270 | [ 271 | 0.9477603556843314, 272 | 0.008665721547931132, 273 | -0.31886550999310576, 274 | 1.03569100218 275 | ], 276 | [ 277 | 0.31896113144149074, 278 | -0.01397629983587656, 279 | 0.9476647401230364, 280 | 0.484795032713 281 | ], 282 | [ 283 | 0.0037556387837154315, 284 | -0.9998647750135773, 285 | -0.016010211253286277, 286 | 1.59097014818 287 | ], 288 | [ 289 | 0.0, 290 | 0.0, 291 | 0.0, 292 | 1.0 293 | ] 294 | ] 295 | } 296 | } -------------------------------------------------------------------------------- /src/dwm/utils/carla_control.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import carla 3 | import tkinter 4 | import tkinter.font 5 | import tkinter.ttk 6 | 7 | 8 | class KeyPanel(): 9 | 10 | def __init__( 11 | self, master, title: str, detail: str, style_class: str = "Inactivated" 12 | ): 13 | default_font = tkinter.font.nametofont("TkDefaultFont") 14 | default_font_family = default_font.cget("family") 15 | self.key_panel = tkinter.ttk.Frame( 16 | master, style="{}.TFrame".format(style_class)) 17 | self.label_group = tkinter.ttk.Frame( 18 | self.key_panel, style="{}.TFrame".format(style_class)) 19 | self.title = tkinter.ttk.Label( 20 | self.label_group, text=title, 21 | style="{}.TLabel".format(style_class), 22 | font=(default_font_family, 18), padding=(0, -4, 0, -4)) 23 | self.detail = tkinter.ttk.Label( 24 | self.label_group, text=detail, 25 | style="{}.TLabel".format(style_class), 26 | font=(default_font_family, 10), padding=(0, -2, 0, -2)) 27 | 28 | self.label_group.place(relx=0.5, rely=0.5, anchor="center") 29 | self.title.pack(anchor="center") 30 | self.detail.pack(anchor="center") 31 | 32 | def set_style_class(self, style_class: str): 33 | self.key_panel.configure(style="{}.TFrame".format(style_class)) 34 | self.label_group.configure(style="{}.TFrame".format(style_class)) 35 | self.title.configure(style="{}.TLabel".format(style_class)) 36 | self.detail.configure(style="{}.TLabel".format(style_class)) 37 | 38 | 39 | class KeyboardControlPanel(): 40 | 41 | def __init__(self, master, hero_vehicle=None, style_config=None): 42 | self.master = master 43 | self.hero_vehicle = hero_vehicle 44 | 45 | default_style_config = { 46 | 
"Inactivated.TFrame": { 47 | "background": self.master.cget("background") 48 | }, 49 | "Inactivated.TLabel": { 50 | "background": self.master.cget("background"), 51 | "foreground": "black", 52 | }, 53 | "Activated.TFrame": { 54 | "background": "dimgray" 55 | }, 56 | "Activated.TLabel": { 57 | "background": "dimgray", 58 | "foreground": "white", 59 | } 60 | } 61 | 62 | self.style = tkinter.ttk.Style() 63 | for k, v in (style_config or default_style_config).items(): 64 | self.style.configure(k, **v) 65 | 66 | self.frame = tkinter.ttk.Frame(master, padding=2) 67 | self.label_reverse = KeyPanel(self.frame, title="Q", detail="Reverse") 68 | self.label_up = KeyPanel(self.frame, title="W", detail="Throttle") 69 | self.label_autopilot = KeyPanel( 70 | self.frame, title="E", detail="Auto pilot") 71 | self.label_left = KeyPanel(self.frame, title="A", detail="Left") 72 | self.label_down = KeyPanel(self.frame, title="S", detail="Brake") 73 | self.label_right = KeyPanel(self.frame, title="D", detail="Right") 74 | 75 | self.pressed_key = {} 76 | self.is_auto = False 77 | self.reverse = False 78 | 79 | def setup_layout(self): 80 | for i in range(2): 81 | self.frame.grid_rowconfigure(i, weight=1) 82 | 83 | for i in range(3): 84 | self.frame.grid_columnconfigure(i, weight=1) 85 | 86 | grid_args = { 87 | "padx": 2, 88 | "pady": 2, 89 | "sticky": tkinter.NSEW 90 | } 91 | self.frame.pack(fill=tkinter.BOTH, expand=True) 92 | self.label_reverse.key_panel.grid(column=0, row=0, **grid_args) 93 | self.label_up.key_panel.grid(column=1, row=0, **grid_args) 94 | self.label_autopilot.key_panel.grid(column=2, row=0, **grid_args) 95 | self.label_left.key_panel.grid(column=0, row=1, **grid_args) 96 | self.label_down.key_panel.grid(column=1, row=1, **grid_args) 97 | self.label_right.key_panel.grid(column=2, row=1, **grid_args) 98 | 99 | def update_manual_control(self): 100 | control = carla.VehicleControl() 101 | control.throttle = \ 102 | 0.8 if any([i in self.pressed_key for i in ["w", "Up"]]) else 0 103 | control.steer = ( 104 | -0.8 if any([i in self.pressed_key for i in ["a", "Left"]]) else 0 105 | ) + ( 106 | 0.8 if any([i in self.pressed_key for i in ["d", "Right"]]) else 0 107 | ) 108 | control.brake = \ 109 | 1.0 if any([i in self.pressed_key for i in ["s", "Down"]]) else 0 110 | control.reverse = self.reverse 111 | self.hero_vehicle.apply_control(control) 112 | 113 | def on_key_pressed_event(self, event): 114 | self.pressed_key[event.keysym] = True 115 | if event.keysym in ["w", "Up"]: 116 | self.label_up.set_style_class("Activated") 117 | elif event.keysym in ["a", "Left"]: 118 | self.label_left.set_style_class("Activated") 119 | elif event.keysym in ["d", "Right"]: 120 | self.label_right.set_style_class("Activated") 121 | elif event.keysym in ["s", "Down"]: 122 | self.label_down.set_style_class("Activated") 123 | 124 | if self.hero_vehicle is not None and not self.is_auto: 125 | self.update_manual_control() 126 | 127 | def on_key_released_event(self, event): 128 | if event.keysym == "e": 129 | self.is_auto = not self.is_auto 130 | self.label_autopilot.set_style_class( 131 | "Activated" if self.is_auto else "Inactivated") 132 | if self.hero_vehicle is not None: 133 | self.hero_vehicle.set_autopilot(self.is_auto) 134 | elif event.keysym == "q": 135 | self.reverse = not self.reverse 136 | self.label_reverse.set_style_class( 137 | "Activated" if self.reverse else "Inactivated") 138 | elif event.keysym in ["w", "Up"]: 139 | self.label_up.set_style_class("Inactivated") 140 | elif event.keysym in ["a", "Left"]: 141 | 
self.label_left.set_style_class("Inactivated") 142 | elif event.keysym in ["d", "Right"]: 143 | self.label_right.set_style_class("Inactivated") 144 | elif event.keysym in ["s", "Down"]: 145 | self.label_down.set_style_class("Inactivated") 146 | 147 | if event.keysym in self.pressed_key: 148 | del self.pressed_key[event.keysym] 149 | 150 | if self.hero_vehicle is not None and not self.is_auto: 151 | self.update_manual_control() 152 | 153 | 154 | def create_parser(): 155 | parser = argparse.ArgumentParser( 156 | description="Carla control Client") 157 | parser.add_argument( 158 | "--host", default="127.0.0.1", type=str, 159 | help="The host address of the Carla simulator.") 160 | parser.add_argument( 161 | "-p", "--port", default=2000, type=int, 162 | help="The port of the Carla simulator.") 163 | parser.add_argument( 164 | "--client-timeout", default=10.0, type=float, 165 | help="The timeout of the Carla client.") 166 | return parser 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = create_parser() 171 | args = parser.parse_args() 172 | 173 | client = carla.Client(args.host, args.port, 1) 174 | client.set_timeout(args.client_timeout) 175 | world = client.get_world() 176 | world.wait_for_tick() 177 | 178 | hero_vehicle, = [ 179 | i for i in world.get_actors() 180 | if ( 181 | i.type_id.startswith("vehicle") and 182 | i.attributes.get("role_name") == "hero" 183 | ) 184 | ] 185 | print("Hero vehicle: {}".format(hero_vehicle.id)) 186 | 187 | window_args = { 188 | "title": "Carla Control", 189 | "geometry": "244x124" 190 | } 191 | window = tkinter.Tk() 192 | for k, v in window_args.items(): 193 | getattr(window, k)(v) 194 | 195 | control_panel = KeyboardControlPanel(window, hero_vehicle) 196 | control_panel.setup_layout() 197 | window.bind("", control_panel.on_key_pressed_event) 198 | window.bind("", control_panel.on_key_released_event) 199 | window.mainloop() 200 | -------------------------------------------------------------------------------- /docs/Datasets.md: -------------------------------------------------------------------------------- 1 | 2 | # Datasets 3 | 4 | Currently we support 4 datasets: nuScenes, Waymo Perception, Argoverse 2 Sensor, OpenDV. 5 | 6 | ## nuScenes 7 | 8 | 1. Download the [nuScenes](https://www.nuscenes.org/download) dataset files to `{NUSCENES_TGZ_ROOT}` on your file system. After the dataset is downloaded, there will be some `*.tgz` files under path `{NUSCENES_TGZ_ROOT}`. 9 | 10 | 2. Since the TGZ format does not support random access to content, we recommend converting these files to ZIP format using the following command lines: 11 | 12 | ``` 13 | mkdir -p {NUSCENES_ZIP_ROOT} 14 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval_meta.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval_meta.zip 15 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval01_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval01_blobs.zip 16 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval02_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval02_blobs.zip 17 | ... 18 | python src/dwm/tools/tar2zip.py -i {NUSCENES_TGZ_ROOT}/v1.0-trainval10_blobs.tgz -o {NUSCENES_ZIP_ROOT}/v1.0-trainval10_blobs.zip 19 | ``` 20 | 21 | 3. Now the `{NUSCENES_ZIP_ROOT}` is ready to update the nuScenes file system of your config file, for [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L12). 22 | 23 | 4. Prepare the HD map data. 24 | 25 | 1. 
Download the `nuScenes-map-expansion-v1.3.zip` file from the [nuScenes](https://www.nuscenes.org/download) to `{NUSCENES_ZIP_ROOT}`. 26 | 27 | 2. Add the file into the [config](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L27), so the dataset can load the map data. 28 | 29 | 5. *Optional*. When the 3D box conditions are used for training, the 12hz metadata is recommended. 30 | 31 | 1. Download 12 Hz nuScenes meta from [Corner Case Scene Generation](https://coda-dataset.github.io/w-coda2024/track2/). After the metadata is downloaded, there will be `interp_12Hz.tar` file. 32 | 33 | 2. Extract and repack the 12 Hz metadata to `interp_12Hz_trainval.zip`, then update the [FS](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L15) and [dataset name](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_nusc_a.json#L206) in the config. 34 | 35 | ``` 36 | python -m tarfile -e interp_12Hz.tar 37 | cd data/nuscenes 38 | python -m zipfile -c ../../interp_12Hz_trainval.zip interp_12Hz_trainval/ 39 | cd ../.. 40 | rm -rf data/ 41 | ``` 42 | 43 | 6. *Alternative solution for 5*. In the case of a broken download link, you can also regenerate 12 Hz annotations according to the instructions of [ASAP](https://github.com/JeffWang987/ASAP/blob/main/docs/prepare_data.md) from the origin nuScenes dataset. 44 | 45 | 7. Download the annotation of text prompt and update the config following the section [text description for images](#text-description-for-images) 46 | 47 | ## Waymo 48 | 49 | There are two versions of the Waymo Perception dataset. This project chooses version 1 (>= 1.4.2) because only this version provides HD map annotation, while version 2 does not provide HD map annotation. 50 | 51 | 1. *Optional*. The Waymo Perception 1.x requires protobuffer, if you try to avoid installing waymo_open_dataset and its dependencies, you need to compile the proto files. Install the [proto buffer compiler](https://github.com/protocolbuffers/protobuf/releases/tag/v25.4), then run following commands to compile proto files. After compilation, `import waymo_open_dataset.dataset_pb2` works by adding `externals/waymo-open-dataset/src` to the environmant variable `PYTHONPATH`. 52 | 53 | ``` 54 | cd externals/waymo-open-dataset/src 55 | protoc --proto_path=. --python_out=. waymo_open_dataset/*.proto 56 | protoc --proto_path=. --python_out=. waymo_open_dataset/protos/*.proto 57 | ``` 58 | 59 | 2. Download the [Waymo Perception](https://waymo.com/open/download) dataset (>= 1.4.2 for the annotation of HD map) to `{WAYMO_ROOT}`. After the dataset is downloaded, there will be some `*.tfrecord` files under the path `{WAYMO_ROOT}/training` and `{WAYMO_ROOT}/validation`. 60 | 61 | 3. Then make information JSON files to support inner-scene random access, by 62 | 63 | ``` 64 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt waymo -i {WAYMO_ROOT}/training -o {WAYMO_ROOT}/training.info.json 65 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt waymo -i {WAYMO_ROOT}/validation -o {WAYMO_ROOT}/validation.info.json 66 | ``` 67 | 68 | 4. Now the `{WAYMO_ROOT}` and its information JSON files are ready to update the Waymo dataset of your config file, for [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_waymo.json#L182). 69 | 70 | 5. Download the annotation of text prompt and update the config following the section [text description for images](#text-description-for-images) 71 | 72 | ## Argoverse 73 | 74 | 1. 
Download the [Argoverse 2 Sensor](https://www.argoverse.org/av2.html#download-link) dataset files to `{ARGOVERSE_ROOT}` on your file system. After the dataset is downloaded, there will be some `*.tar` files under path `{ARGOVERSE_ROOT}`. 75 | 76 | 2. Then make information JSON files to accelerate the loading speed, by: 77 | 78 | ``` 79 | PYTHONPATH=src python src/dwm/tools/dataset_make_info_json.py -dt argoverse -i {ARGOVERSE_ROOT} -o {ARGOVERSE_ROOT} 80 | ``` 81 | 82 | 3. Now the `{ARGOVERSE_ROOT}` is ready to update the Argoverse file system of your config file, for [example](../configs/ctsd/single_dataset/ctsd_21_crossview_tirda_bm_argo.json#L184). 83 | 84 | 4. Download the annotation of text prompt and update the config following the section [text description for images](#text-description-for-images) 85 | 86 | ## OpenDV 87 | 88 | 1. Download the [OpenDV](https://github.com/OpenDriveLab/DriveAGI/tree/main/opendv) dataset video files to `{OPENDV_ORIGIN_ROOT}` on your file system, and the meta file to `{OPENDV_JSON_META_PATH}` prepared as JSON format. After the dataset is downloaded, there will be about 2K video files under the path `{OPENDV_ORIGIN_ROOT}` in the format of `.mp4` and `.webp`. 89 | 90 | 2. *Optional.* It is recommended to transcode the original video files for better read and seek performance during training, by: 91 | 92 | ``` 93 | apt update && apt install -y ffmpeg 94 | python src/dwm/tools/transcode_video.py -c src/dwm/tools/transcode_video.json -i {OPENDV_ORIGIN_ROOT} -o {OPENDV_ROOT} 95 | ``` 96 | 97 | 3. Now the `{OPENDV_ORIGIN_ROOT}` (or `{OPENDV_ROOT}`) is ready to update the OpenDV [file system config](../configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json#L31), and `{OPENDV_JSON_META_PATH}` to update the [dataset config](../configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json#L409). 98 | 99 | ## KITTI360 100 | 101 | Register an account and download the [KITTI360](https://www.cvlibs.net/datasets/kitti-360/download.php) dataset. We only require the LiDAR data from KITTI360, so you only need to download the Raw Velodyne Scans, 3D Bounding Boxes, and Vehicle Poses. The scenes `2013_05_28_drive_0000_sync` and `2013_05_28_drive_0002_sync` are used for validation, while all other scenes are used as training data. You may keep the files in their original zip format, as our code processes them automatically. 102 | 103 | ## Text description for images 104 | 105 | We made the image captions for both nuScenes, Waymo, Argoverse, OpenDV datasets by [DriveMLM](https://arxiv.org/abs/2312.09245) model. The caption files are available here. 106 | 107 | | Dataset | Downloads | 108 | | :-: | :-: | 109 | | nuScenes | [mini](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/nuscenes_v1.0-mini_caption_v2.zip?download=true), [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/nuscenes_v1.0-trainval_caption_v2.zip?download=true) | 110 | | Waymo | [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/waymo_caption_v2.zip?download=true) | 111 | | Argoverse | [trainval](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/av2_sensor_caption_v2.zip?download=true) | 112 | | OpenDV | [all](https://huggingface.co/datasets/wzhgba/opendwm-data/resolve/main/opendv_caption.zip?download=true) | 113 | 114 | 1. Download the packages above and unzip them. 115 | 116 | 2. 
You will get some JSON files such as `nuscenes_v1.0-trainval_caption_v2_train.json` (text of image captions), `nuscenes_v1.0-trainval_caption_v2_times_train.json` (indicates the moments of frames selected for image caption annotations in each scenario) 117 | 118 | 3. Update [paths](../configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json#L315) to those files in your config. Please notice that the paths in the "image_description_settings" of all the datasets should be updated to your local downloaded and extracted files. 119 | -------------------------------------------------------------------------------- /src/dwm/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import dwm.common 2 | import torch 3 | import random 4 | from collections import OrderedDict, defaultdict 5 | from typing import Iterator, List, Optional 6 | from torch.utils.data import Dataset, DistributedSampler 7 | 8 | 9 | class VariableVideoBatchSampler(DistributedSampler): 10 | 11 | def __init__( 12 | self, 13 | dataset, 14 | bucket_config: dict, 15 | num_replicas: Optional[int] = None, 16 | rank: Optional[int] = None, 17 | shuffle: bool = True, 18 | seed: int = 0, 19 | drop_last: bool = False, 20 | verbose: bool = False, 21 | num_bucket_build_workers: int = 1, 22 | ) -> None: 23 | super().__init__( 24 | dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last 25 | ) 26 | self.dataset = dataset 27 | self.bucket = bucket_config 28 | 29 | self.res = [k for k in self.bucket.keys()] 30 | self.res_w = [v[0] for v in self.bucket.values()] 31 | 32 | self.res_tbw = {} 33 | 34 | for k, v in self.bucket.items(): 35 | self.res_tbw[k] = {} 36 | self.res_tbw[k]["t_bs"] = [(tri[0], tri[1]) for tri in v[1]] 37 | self.res_tbw[k]["w"] = [tri[2] for tri in v[1]] 38 | 39 | self.verbose = verbose 40 | self.last_micro_batch_access_index = 0 41 | self.approximate_num_batch = None 42 | 43 | self._get_num_batch_cached_bucket_sample_dict = None 44 | self.num_bucket_build_workers = num_bucket_build_workers 45 | 46 | def __iter__(self) -> Iterator[List[int]]: 47 | 48 | if self._get_num_batch_cached_bucket_sample_dict is not None: 49 | bucket_sample_dict = self._get_num_batch_cached_bucket_sample_dict 50 | self._get_num_batch_cached_bucket_sample_dict = None 51 | else: 52 | bucket_sample_dict = self.group_by_bucket() 53 | 54 | g = torch.Generator() 55 | g.manual_seed(self.seed + self.epoch) 56 | bucket_micro_batch_count = OrderedDict() 57 | bucket_last_consumed = OrderedDict() 58 | 59 | # process the samples 60 | for bucket_id, data_list in bucket_sample_dict.items(): 61 | 62 | # handle droplast 63 | bs_per_gpu = int(bucket_id.split("-")[-1]) 64 | remainder = len(data_list) % bs_per_gpu 65 | 66 | if remainder > 0: 67 | if not self.drop_last: 68 | # if there is remainder, we pad to make it divisible 69 | data_list += data_list[: bs_per_gpu - remainder] 70 | else: 71 | # we just drop the remainder to make it divisible 72 | data_list = data_list[:-remainder] 73 | 74 | bucket_sample_dict[bucket_id] = data_list 75 | 76 | # handle shuffle 77 | if self.shuffle: 78 | data_indices = torch.randperm( 79 | len(data_list), generator=g).tolist() 80 | data_list = [data_list[i] for i in data_indices] 81 | bucket_sample_dict[bucket_id] = data_list 82 | 83 | # compute how many micro-batches each bucket has 84 | num_micro_batches = len(data_list) // bs_per_gpu 85 | bucket_micro_batch_count[bucket_id] = num_micro_batches 86 | 87 | # compute the bucket access order 88 | # 
each bucket may have more than one batch of data 89 | # thus bucket_id may appear more than 1 time 90 | bucket_id_access_order = [] 91 | for bucket_id, num_micro_batch in bucket_micro_batch_count.items(): 92 | bucket_id_access_order.extend([bucket_id] * num_micro_batch) 93 | 94 | # randomize the access order 95 | if self.shuffle: 96 | bucket_id_access_order_indices = torch.randperm( 97 | len(bucket_id_access_order), generator=g).tolist() 98 | bucket_id_access_order = [ 99 | bucket_id_access_order[i] for i in bucket_id_access_order_indices] 100 | 101 | # make the number of bucket accesses divisible by dp size 102 | remainder = len(bucket_id_access_order) % self.num_replicas 103 | if remainder > 0: 104 | if self.drop_last: 105 | bucket_id_access_order = bucket_id_access_order[: len(bucket_id_access_order) - remainder] 106 | else: 107 | bucket_id_access_order += bucket_id_access_order[: self.num_replicas - remainder] 108 | 109 | # prepare each batch from its bucket 110 | # according to the predefined bucket access order 111 | num_iters = len(bucket_id_access_order) // self.num_replicas 112 | start_iter_idx = self.last_micro_batch_access_index // self.num_replicas 113 | 114 | # re-compute the micro-batch consumption 115 | # this is useful when resuming from a state dict with a different number of GPUs 116 | self.last_micro_batch_access_index = start_iter_idx * self.num_replicas 117 | for i in range(self.last_micro_batch_access_index): 118 | bucket_id = bucket_id_access_order[i] 119 | bucket_bs = int(bucket_id.split("-")[-1]) 120 | if bucket_id in bucket_last_consumed: 121 | bucket_last_consumed[bucket_id] += bucket_bs 122 | else: 123 | bucket_last_consumed[bucket_id] = bucket_bs 124 | 125 | for i in range(start_iter_idx, num_iters): 126 | bucket_access_list = bucket_id_access_order[ 127 | i * self.num_replicas: (i + 1) * self.num_replicas] 128 | self.last_micro_batch_access_index += self.num_replicas 129 | 130 | # compute the data samples consumed by each access 131 | bucket_access_boundaries = [] 132 | for bucket_id in bucket_access_list: 133 | bucket_bs = int(bucket_id.split("-")[-1]) 134 | last_consumed_index = bucket_last_consumed.get(bucket_id, 0) 135 | bucket_access_boundaries.append( 136 | [last_consumed_index, last_consumed_index + bucket_bs]) 137 | 138 | # update consumption 139 | if bucket_id in bucket_last_consumed: 140 | bucket_last_consumed[bucket_id] += bucket_bs 141 | else: 142 | bucket_last_consumed[bucket_id] = bucket_bs 143 | 144 | # compute the range of data accessed by each GPU 145 | bucket_id = bucket_access_list[self.rank] 146 | boundary = bucket_access_boundaries[self.rank] 147 | cur_micro_batch = bucket_sample_dict[bucket_id][boundary[0]: boundary[1]] 148 | 149 | # encode t, h, w into the sample index 150 | 151 | b_id = bucket_id.split("-") 152 | real_t, real_h, real_w = b_id[-2], b_id[0], b_id[1] 153 | cur_micro_batch = [ 154 | f"{idx}-{real_t}-{real_h}-{real_w}" for idx in cur_micro_batch] 155 | 156 | if len(cur_micro_batch) > 0: 157 | yield cur_micro_batch 158 | 159 | self.reset() 160 | 161 | def __len__(self) -> int: 162 | return self.get_num_batch() // dist.get_world_size() 163 | 164 | def group_by_bucket(self) -> dict: 165 | 166 | bucket_sample_dict = OrderedDict() 167 | for i in range(len(self.dataset)): 168 | 169 | res_i = random.choices(self.res, weights=self.res_w, k=1)[0] 170 | t_bs_i = random.choices( 171 | self.res_tbw[res_i]['t_bs'], weights=self.res_tbw[res_i]['w'], k=1)[0] 172 | 173 | bucket_id = f"{res_i}-{t_bs_i[0]}-{t_bs_i[1]}" 174 | 175 | if 
bucket_id not in bucket_sample_dict: 176 | bucket_sample_dict[bucket_id] = [] 177 | bucket_sample_dict[bucket_id].append(i) 178 | 179 | return bucket_sample_dict 180 | 181 | def get_num_batch(self) -> int: 182 | bucket_sample_dict = self.group_by_bucket() 183 | self._get_num_batch_cached_bucket_sample_dict = bucket_sample_dict 184 | 185 | return self.approximate_num_batch 186 | 187 | def reset(self): 188 | self.last_micro_batch_access_index = 0 189 | 190 | def state_dict(self, num_steps: int) -> dict: 191 | # the last_micro_batch_access_index in the __iter__ is often 192 | # not accurate during multi-workers and data prefetching 193 | # thus, we need the user to pass the actual steps which have been executed 194 | # to calculate the correct last_micro_batch_access_index 195 | return {"seed": self.seed, "epoch": self.epoch, "last_micro_batch_access_index": num_steps * self.num_replicas} 196 | 197 | def load_state_dict(self, state_dict: dict) -> None: 198 | self.__dict__.update(state_dict) 199 | -------------------------------------------------------------------------------- /docs/LiDAR_Generation.md: -------------------------------------------------------------------------------- 1 | # Layout-Condition LiDAR Generation with Masked Generative Transformer 2 | 3 | 4 | ## Introduction 5 | 6 | We propose a pipeline for generating LiDAR data conditioned on layout information using the Mask Generative Image Transformer (MaskGIT) [[1]](#1). Our approach builds upon the models introduced in Copilot4D [[5]](#5) and UltraLiDAR [[4]](#4). 7 | 8 | ## Method 9 | 10 | We first train a Vector Quantized Variational AutoEncoder (VQ-VAE) to tokenize LiDAR data into a 2D latent space. Then, 11 | the LiDAR MaskGIT model incorporates the layout information such as High-definition maps (HDmaps) and 3D object 12 | bounding boxes (3Dboxes) and guides the generation of LiDAR in the latent space. 13 | 14 | ### LiDAR VQ-VAE 15 | 16 | We build our LiDAR VQ-VAE following the approach of UltraLiDAR [[4]](#4) and Copilot4D [[5]](#5). The point cloud $\mathbf{x}$ is fed into a voxelizer $V$ and converted into a Birds-Eye-View (BEV) image. Additionally, we adopt an auxiliary depth rendering branch during decoding. Specifically, given a latent representation $\mathbf z$ of a LiDAR point cloud, the regular decoder transforms the latent back into the original voxel shape, and binary cross entropy loss is used to optimize the network. Furthermore, a depth render network $D$ decodes the latent into a 3D feature voxel, which is used to render the depth of each point. For any point $p$ in $\mathbf x$, we sample a ray from the LiDAR center to $p$ and use the method from [[3]](#3) to calculate the depth value in this direction. For further details, please refer to [[5]](#5). 17 | 18 | 19 |
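To make the voxelizer $V$ concrete, the sketch below rasterizes a LiDAR point cloud into a binary BEV occupancy image. It is a simplified illustration under assumed ranges and resolution, not the repository's actual implementation (see `src/dwm/models/voxelizer.py`), which may keep a height/feature dimension and use different parameters.

```
import torch


def voxelize_bev(points, x_range=(-50.0, 50.0), y_range=(-50.0, 50.0),
                 z_range=(-3.0, 5.0), resolution=0.15625):
    # points: (N, 3) tensor of LiDAR x, y, z coordinates in the ego frame.
    x, y, z = points[:, 0], points[:, 1], points[:, 2]
    keep = ((x >= x_range[0]) & (x < x_range[1])
            & (y >= y_range[0]) & (y < y_range[1])
            & (z >= z_range[0]) & (z < z_range[1]))
    points = points[keep]

    w = int(round((x_range[1] - x_range[0]) / resolution))
    h = int(round((y_range[1] - y_range[0]) / resolution))
    ix = ((points[:, 0] - x_range[0]) / resolution).long().clamp(0, w - 1)
    iy = ((points[:, 1] - y_range[0]) / resolution).long().clamp(0, h - 1)

    bev = torch.zeros(h, w)
    bev[iy, ix] = 1.0   # mark every BEV cell that contains at least one point
    return bev          # (h, w) binary image fed to the VQ-VAE encoder
```

The binary cross-entropy reconstruction loss mentioned above is then computed between the decoder output and such a voxelized occupancy target.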
### LiDAR MaskGIT

Our LiDAR MaskGIT is designed to generate sequences of LiDAR point clouds conditioned on layouts. Specifically, given the LiDAR point clouds of the first $k$ frames, our model predicts the LiDAR point clouds at subsequent timestamps, guided by layout conditions such as 3D bounding boxes (3D boxes) and high-definition maps (HD maps). Additionally, we extend our model to directly generate LiDAR data without reference frames.

We follow the Copilot4D framework to build our LiDAR MaskGIT, as illustrated in Figure [1](#fig-main). The input to LiDAR MaskGIT is the masked quantized latent, where masked regions are filled with mask tokens, and the model outputs a probability distribution over VQ-VAE codebook IDs at each position. The architecture employs a spatio-temporal transformer that interleaves spatial and temporal attention. We also adopt free space suppression to encourage more diverse generation results. Please refer to [[5]](#5) for additional details. Furthermore, our model can generate LiDAR data based on layout conditions, which we describe below (a minimal sampling sketch is given at the end of the Visualization section).

#### Layout Conditioning

Since LiDAR point clouds are encoded into the 2D BEV space, both 3D boxes and HD maps are also projected into BEV images. Similar to [[2]](#2), each instance is mapped into the color space. These two conditional images are concatenated and processed through a lightweight image adapter, generating multi-level features that are added to the latent representation at the corresponding transformer layers.

## Experiment

We conduct our experiments on the nuScenes [[6]](#6) and KITTI360 [[7]](#7) datasets and report the quantitative results in the following table.
| Dataset | IoU | CD | MMD | JSD |
| :-: | :-: | :-: | :-: | :-: |
| nuScenes | 0.055 | 4.438 | - | - |
| KITTI360 | 0.045 | 5.838 | 0.00461 | 0.471 |
| nuScenes Temporal | 0.126 | 3.487 | - | - |
| KITTI360 Temporal | 0.117 | 3.347 | 0.00411 | 0.313 |
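For reference, the Chamfer distance (CD) reported above can be read as the following symmetric nearest-neighbour error between a generated and a ground-truth point cloud. This is a generic formulation for intuition only; the repository's evaluation (`src/dwm/metrics/pc_metrics.py`, built on `chamferdist`) may differ in squaring, averaging, and range filtering, so it will not exactly reproduce the numbers in the table.

```
import torch


def chamfer_distance(p, q):
    # p: (N, 3) generated points, q: (M, 3) ground-truth points.
    d = torch.cdist(p, q)   # (N, M) pairwise Euclidean distances
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()
```

IoU, MMD, and JSD are typically computed on voxelized or BEV representations rather than on raw point sets.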
## Visualization

In this section, we provide some qualitative results of our method. First, install `open3d` in your environment. The visualization code is provided in `src/dwm/utils/lidar_visualizer.py`. You can run the following command to generate the visualization results.

```
python src/dwm/utils/lidar_visualizer.py \
    --data_type nuscenes \
    --lidar_root /path/to/generated/lidar \
    --data_root /path/to/nuscenes/json \
    --output_path /path/to/output/folder
```

### Single Frame Generation

#### NuScenes
96 | 97 | #### KITTI360 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 |
107 | 108 | 109 | ### Autoregressive Generation 110 | #### NuScenes 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 |
| Reference Frame | Frame=2 | Frame=4 | Frame=6 |
| :-: | :-: | :-: | :-: |
120 | 121 | #### KITTI360 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 |
| Reference Frame | Frame=2 | Frame=4 | Frame=6 |
| :-: | :-: | :-: | :-: |
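To make the iterative MaskGIT decoding described above concrete, here is a minimal single-frame sampling loop. Everything in it is illustrative: `model(tokens, cond)` stands for a hypothetical callable returning per-position logits over the VQ-VAE codebook, and the codebook size, mask token id, and cosine unmasking schedule are assumptions rather than the repository's actual choices; the real LiDAR MaskGIT additionally applies free space suppression and temporal attention across frames.

```
import torch


@torch.no_grad()
def maskgit_sample(model, cond, shape=(64, 64), steps=8, mask_id=1024):
    h, w = shape
    tokens = torch.full((1, h * w), mask_id, dtype=torch.long)   # start fully masked
    for step in range(steps):
        logits = model(tokens, cond)                   # (1, h*w, codebook_size), hypothetical API
        conf, pred = logits.softmax(-1).max(-1)        # confidence and code id per position
        conf = conf.masked_fill(tokens != mask_id, -1.0)   # never reconsider committed tokens
        # Cosine schedule: how many positions should remain masked after this step.
        keep_masked = int(h * w * torch.cos(
            torch.tensor((step + 1) / steps * torch.pi / 2)).item())
        n_reveal = int((tokens == mask_id).sum()) - keep_masked
        if n_reveal <= 0:
            continue
        idx = conf.topk(n_reveal, dim=-1).indices      # most confident masked positions
        tokens.scatter_(1, idx, pred.gather(1, idx))   # commit their predicted codes
    return tokens.view(1, h, w)                        # token map to be decoded by the VQ-VAE
```

For the autoregressive results above, such a loop would be run frame by frame, with the tokens of the reference frames (and of previously generated frames) provided as temporal context.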
131 | 132 | ## References 133 | 134 | [1] Huiwen Chang, Han Zhang, Lu Jiang, Ce Liu, and William T Freeman. Maskgit: Masked generative image transformer. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 11315–11325, 2022. 135 | 136 | [2] Rui Chen, Zehuan Wu, Yichen Liu, Yuxin Guo, Jingcheng Ni, Haifeng Xia, and Siyu Xia. Unimlvg: Unified framework for multi-view long video generation with comprehensive control capabilities for autonomous driving. arXiv preprint arXiv:2412.04842, 2024. 137 | 138 | [3] Cheng Sun, Min Sun, and Hwann-Tzong Chen. Improved direct voxel grid optimization for radiance fields reconstruction. arXiv preprint arXiv:2206.05085, 2022. 139 | 140 | [4] Yuwen Xiong, Wei-Chiu Ma, Jingkang Wang, and Raquel Urtasun. Ultralidar: Learning compact representations for lidar completion and generation. arXiv preprint arXiv:2311.01448, 2023. 141 | 142 | [5] Lunjun Zhang, Yuwen Xiong, Ze Yang, Sergio Casas, Rui Hu, and Raquel Urtasun. Copilot4d: Learning unsupervised world models for autonomous driving via discrete diffusion. arXiv preprint arXiv:2311.01017, 2023. 143 | 144 | [6] Holger Caesar, Varun Bankiti, Alex H Lang, Sourabh Vora, Venice Erin Liong, Qiang Xu, Anush Krishnan, Yu Pan, Giancarlo Baldan, 145 | and Oscar Beijbom. nuscenes: A multimodal dataset for autonomous driving. In Proceedings of the IEEE/CVF conference on computer 146 | vision and pattern recognition, pages 11621–11631, 2020. 147 | 148 | [7] Yiyi Liao, Jun Xie, and Andreas Geiger. Kitti-360: A novel dataset and benchmarks for urban scene understanding in 2d and 3d. IEEE 149 | Transactions on Pattern Analysis and Machine Intelligence, 45(3):3292–3310, 2022. 150 | -------------------------------------------------------------------------------- /src/dwm/schedulers/temporal_independent.py: -------------------------------------------------------------------------------- 1 | import diffusers.schedulers 2 | import diffusers.utils.torch_utils 3 | import torch 4 | 5 | 6 | class DDPMScheduler(diffusers.schedulers.DDPMScheduler): 7 | 8 | def add_noise( 9 | self, original_samples: torch.Tensor, noise: torch.Tensor, 10 | timesteps: torch.IntTensor, 11 | ): 12 | while len(timesteps.shape) < len(original_samples.shape): 13 | timesteps = timesteps.unsqueeze(-1) 14 | 15 | # Make sure alphas_cumprod and timestep have same device and dtype as 16 | # original_samples Move the self.alphas_cumprod to device to avoid 17 | # redundant CPU to GPU data movement for the subsequent add_noise calls 18 | self.alphas_cumprod = self.alphas_cumprod\ 19 | .to(device=original_samples.device) 20 | alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) 21 | timesteps = timesteps.to(original_samples.device) 22 | 23 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 24 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 25 | noisy_samples = sqrt_alpha_prod * original_samples + \ 26 | sqrt_one_minus_alpha_prod * noise 27 | return noisy_samples 28 | 29 | def get_velocity( 30 | self, sample: torch.Tensor, noise: torch.Tensor, 31 | timesteps: torch.IntTensor 32 | ): 33 | while len(timesteps.shape) < len(sample.shape): 34 | timesteps = timesteps.unsqueeze(-1) 35 | 36 | # Make sure alphas_cumprod and timestep have same device and dtype as 37 | # sample 38 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) 39 | alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) 40 | timesteps = timesteps.to(sample.device) 41 | 42 | sqrt_alpha_prod = 
alphas_cumprod[timesteps] ** 0.5 43 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 44 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 45 | return velocity 46 | 47 | 48 | class DDIMScheduler(diffusers.schedulers.DDIMScheduler): 49 | 50 | def _get_variance(self, timestep, prev_timestep): 51 | alpha_prod_t = self.alphas_cumprod[timestep] 52 | alpha_prod_t_prev = torch.where( 53 | prev_timestep >= 0, 54 | self.alphas_cumprod[prev_timestep], 55 | torch.ones( 56 | prev_timestep.shape, dtype=self.alphas_cumprod.dtype, 57 | device=self.alphas_cumprod.device) * self.final_alpha_cumprod 58 | ) 59 | beta_prod_t = 1 - alpha_prod_t 60 | beta_prod_t_prev = 1 - alpha_prod_t_prev 61 | 62 | variance = (beta_prod_t_prev / beta_prod_t) * \ 63 | (1 - alpha_prod_t / alpha_prod_t_prev) 64 | 65 | return variance 66 | 67 | def step( 68 | self, model_output: torch.Tensor, timestep: torch.IntTensor, 69 | sample: torch.Tensor, eta: float = 0.0, 70 | use_clipped_model_output: bool = False, generator=None, 71 | variance_noise=None, return_dict: bool = True, 72 | ): 73 | if self.num_inference_steps is None: 74 | raise ValueError( 75 | "Number of inference steps is 'None', you need to run " 76 | "'set_timesteps' after creating the scheduler") 77 | 78 | while len(timestep.shape) < len(sample.shape): 79 | timestep = timestep.unsqueeze(-1) 80 | 81 | # 1. get previous step value (=t-1) 82 | prev_timestep = timestep - self.config.num_train_timesteps // \ 83 | self.num_inference_steps 84 | 85 | # 2. compute alphas, betas 86 | alpha_prod_t = self.alphas_cumprod[timestep].to(device=sample.device) 87 | alpha_prod_t_prev = torch.where( 88 | prev_timestep >= 0, 89 | self.alphas_cumprod[prev_timestep], 90 | torch.ones( 91 | prev_timestep.shape, dtype=self.alphas_cumprod.dtype, 92 | device=self.alphas_cumprod.device) * self.final_alpha_cumprod 93 | ).to(device=sample.device) 94 | 95 | beta_prod_t = 1 - alpha_prod_t 96 | 97 | # 3. compute predicted original sample from predicted noise also called 98 | # "predicted x_0" of formula (12) from 99 | # https://arxiv.org/pdf/2010.02502.pdf 100 | if self.config.prediction_type == "epsilon": 101 | pred_original_sample = \ 102 | (sample - beta_prod_t ** (0.5) * model_output) / \ 103 | alpha_prod_t ** (0.5) 104 | pred_epsilon = model_output 105 | elif self.config.prediction_type == "sample": 106 | pred_original_sample = model_output 107 | pred_epsilon = \ 108 | (sample - alpha_prod_t ** (0.5) * pred_original_sample) / \ 109 | beta_prod_t ** (0.5) 110 | elif self.config.prediction_type == "v_prediction": 111 | pred_original_sample = (alpha_prod_t ** 0.5) * sample - \ 112 | (beta_prod_t ** 0.5) * model_output 113 | pred_epsilon = (alpha_prod_t ** 0.5) * model_output + \ 114 | (beta_prod_t ** 0.5) * sample 115 | else: 116 | raise ValueError( 117 | "prediction_type given as {} must be one of `epsilon`, " 118 | "`sample`, or `v_prediction`" 119 | .format(self.config.prediction_type)) 120 | 121 | # 4. Clip or threshold "predicted x_0" 122 | if self.config.thresholding: 123 | pred_original_sample = self._threshold_sample(pred_original_sample) 124 | elif self.config.clip_sample: 125 | pred_original_sample = pred_original_sample.clamp( 126 | -self.config.clip_sample_range, self.config.clip_sample_range) 127 | 128 | # 5. 
compute variance: "sigma_t(η)" -> see formula (16) 129 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) 130 | variance = self._get_variance(timestep, prev_timestep)\ 131 | .to(sample.device) 132 | std_dev_t = eta * variance ** (0.5) 133 | 134 | if use_clipped_model_output: 135 | # the pred_epsilon is always re-derived from the clipped x_0 in 136 | # Glide 137 | pred_epsilon = \ 138 | (sample - alpha_prod_t ** (0.5) * pred_original_sample) / \ 139 | beta_prod_t ** (0.5) 140 | 141 | # 6. compute "direction pointing to x_t" of formula (12) from 142 | # https://arxiv.org/pdf/2010.02502.pdf 143 | pred_sample_direction = \ 144 | (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon 145 | 146 | # 7. compute x_t without "random noise" of formula (12) from 147 | # https://arxiv.org/pdf/2010.02502.pdf 148 | prev_sample = alpha_prod_t_prev ** (0.5) * \ 149 | pred_original_sample + pred_sample_direction 150 | 151 | if eta > 0: 152 | if variance_noise is not None and generator is not None: 153 | raise ValueError( 154 | "Cannot pass both generator and variance_noise. Please " 155 | "make sure that either `generator` or `variance_noise` " 156 | "stays `None`.") 157 | 158 | if variance_noise is None: 159 | variance_noise = diffusers.utils.torch_utils.randn_tensor( 160 | model_output.shape, generator=generator, 161 | device=model_output.device, dtype=model_output.dtype) 162 | 163 | variance = std_dev_t * variance_noise 164 | prev_sample = prev_sample + variance 165 | 166 | if not return_dict: 167 | return (prev_sample, pred_original_sample) 168 | 169 | return diffusers.schedulers.scheduling_ddim.DDIMSchedulerOutput( 170 | prev_sample=prev_sample, pred_original_sample=pred_original_sample) 171 | 172 | 173 | class FlowMatchEulerDiscreteScheduler( 174 | diffusers.schedulers.FlowMatchEulerDiscreteScheduler 175 | ): 176 | def step_by_indices( 177 | self, model_output: torch.FloatTensor, timestep_indices, 178 | sample: torch.FloatTensor, return_dict: bool = True 179 | ): 180 | if isinstance(timestep_indices, torch.Tensor): 181 | while len(timestep_indices.shape) < model_output.ndim: 182 | timestep_indices = timestep_indices.unsqueeze(-1) 183 | 184 | # Upcast to avoid precision issues when computing prev_sample 185 | sample = sample.to(torch.float32) 186 | 187 | sigma = self.sigmas[timestep_indices] 188 | sigma_next = self.sigmas[timestep_indices + 1] 189 | prev_sample = sample + (sigma_next - sigma) * model_output 190 | 191 | # Cast sample back to model compatible dtype 192 | prev_sample = prev_sample.to(model_output.dtype) 193 | if not return_dict: 194 | return (prev_sample,) 195 | 196 | return diffusers.schedulers.scheduling_flow_match_euler_discrete.\ 197 | FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) 198 | --------------------------------------------------------------------------------
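A minimal usage sketch for the `DDPMScheduler` override defined in this file (assuming `src` is on `PYTHONPATH`, as in the preparation commands elsewhere in this repository): the point of the override is that `timesteps` may carry extra leading dimensions, for example one timestep per frame of a video clip, which the stock `diffusers` scheduler does not broadcast. The tensor shapes below are made up for illustration.

```
import torch

import dwm.schedulers.temporal_independent as temporal_independent

scheduler = temporal_independent.DDPMScheduler(num_train_timesteps=1000)

video = torch.randn(2, 8, 4, 32, 32)           # (batch, frames, latent channels, height, width)
noise = torch.randn_like(video)
t = torch.randint(0, 1000, (2, 8))             # an independent timestep for every frame
noisy = scheduler.add_noise(video, noise, t)   # t is unsqueezed to (2, 8, 1, 1, 1) internally
print(noisy.shape)                             # torch.Size([2, 8, 4, 32, 32])
```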