├── LICENSE
├── README.md
├── args_file.py
├── ckp
│   ├── demo.jpg
│   └── video.gif
├── infer.py
├── infer.sh
├── models
│   ├── blocks.py
│   ├── controlnet1x1.py
│   ├── pipeline_controlnet_1x1_4dunet.py
│   └── unet_2d_condition_multiview.py
├── train.py
├── train.sh
└── utils
    ├── common.py
    ├── dataset_nusmtv.py
    └── gen_mtp.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024 Leheng Li
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SyntheOcc
 2 | 
 3 | > SyntheOcc: Synthesize Geometric-Controlled Street View Images through 3D Semantic MPIs
 4 | > [Leheng Li](https://len-li.github.io), Weichao Qiu, Yingjie Cai, Xu Yan, Qing Lian, Bingbing Liu, Ying-Cong Chen
 5 | 
 6 | Update (24.11.20): added video generation results to the project page.
 7 | 
 8 | SyntheOcc is a project focused on synthesizing image data under geometric control (occupancy voxels). This repository provides tools and scripts to process data, train models, and generate synthetic images on the nuScenes dataset using occupancy control.
 9 | #### [Project Page](https://len-li.github.io/syntheocc-web) | [Paper](https://arxiv.org/abs/2410.00337) | [Video](https://len-li.github.io/syntheocc-web/videos/teaser-occedit.mp4) | [Checkpoint](https://huggingface.co/lilelife/SyntheOcc)
10 | 
11 | 
12 | ![video generation](ckp/video.gif)
13 | 
14 | ## Table of Contents
15 | 
16 | - [SyntheOcc](#syntheocc)
17 |   - [Project Page | Paper | Video | Checkpoint](#project-page--paper--video--checkpoint)
18 |   - [Table of Contents](#table-of-contents)
19 |   - [Installation](#installation)
20 |   - [Prepare Dataset](#prepare-dataset)
21 |   - [Prepare Checkpoint](#prepare-checkpoint)
22 |   - [Train](#train)
23 |   - [Inference](#inference)
24 |   - [Acknowledgment](#acknowledgment)
25 |   - [BibTeX](#bibtex)
26 | 
27 | 
28 | 
29 | 
30 | ## Installation
31 | 
32 | To get started with SyntheOcc, follow these steps:
33 | 
34 | 1. **Clone the repository:**
35 |    ```bash
36 |    git clone https://github.com/Len-Li/SyntheOcc.git
37 |    cd SyntheOcc
38 |    ```
39 | 
40 | 2. **Set up an environment:**
41 |    ```bash
42 |    pip install torch torchvision transformers
43 |    pip install diffusers==0.26.0.dev0
44 |    # Note: we pin an old development version of diffusers; newer releases may not be compatible.
45 |    ```
46 | 
47 | 
48 | 
49 | 
50 | ## Prepare Dataset
51 | 
52 | To use SyntheOcc, follow the steps below:
53 | 
54 | 1. **Download the nuScenes dataset:**
55 |    - Register and download the dataset from the [nuScenes website](https://www.nuscenes.org/nuscenes).
56 |    - Download the [train](https://github.com/JeffWang987/OpenOccupancy/releases/tag/train_pkl)/[val](https://github.com/JeffWang987/OpenOccupancy/releases/tag/val_pkl) pickle files from OpenOccupancy and put them in the `data/nuscenes` folder.
57 | 
58 | 
59 | 
60 |    After preparation, you should see the following directory structure:
61 | 
62 |    ```
63 |    SyntheOcc/
64 |    ├── data/
65 |    │   ├── nuscenes/
66 |    │   │   ├── samples/
67 |    │   │   ├── sweeps/
68 |    │   │   ├── v1.0-trainval/
69 |    │   │   ├── nuscenes_occ_infos_train.pkl
70 |    │   │   ├── nuscenes_occ_infos_val.pkl
71 |    ```
72 | 2. **Download the occupancy annotations from [SurroundOcc](https://github.com/weiyithu/SurroundOcc/blob/main/docs/data.md).**
73 | 
74 |    You need to generate the high-resolution 0.2 m occupancy from the mesh vertices and put it in the `data/nuscenes` folder.
75 | 
76 |    You can also download the 0.5 m occupancy, but its precision is lower than that of the 0.2 m version.
77 | 
78 | 
79 | 3. **Run the script to convert occupancy to 3D semantic multiplane images:**
80 |    ```bash
81 |    torchrun utils/gen_mtp.py
82 |    ```
83 |    It generates the 3D semantic MPIs and saves them in the `data/nuscenes/samples_syntheocc_surocc/` folder.
84 | 
85 | ## Prepare Checkpoint
86 | Our model is based on [stable-diffusion-v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1). Please put its weights at `./SyntheOcc/ckp/`.
87 | 
88 | Our SyntheOcc checkpoint is released on [Hugging Face](https://huggingface.co/lilelife/SyntheOcc). If you want to run inference with our model, please also put it at `./SyntheOcc/ckp/`.
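For reference, a minimal sketch of one way to fetch both checkpoints with `huggingface-cli`; the `--local-dir` targets below are assumptions and must match `MODEL_DIR` in `train.sh`/`infer.sh` and `OUTPUT_DIR` in `infer.sh` (`infer.py` loads the `controlnet` and `unet` subfolders, preferring the newest `checkpoint-*` directory if one exists):

```bash
# Requires huggingface_hub (pip install huggingface_hub).
huggingface-cli download stabilityai/stable-diffusion-2-1 --local-dir ./ckp/stable-diffusion-v2-1
huggingface-cli download lilelife/SyntheOcc --local-dir ./ckp/train_syntheocc
```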
 89 | 
 90 | ## Train
 91 | 
 92 | ```bash
 93 | bash train.sh
 94 | ```
 95 | The details of the script are as follows:
 96 | ```bash
 97 | export WANDB_DISABLED=True
 98 | export HF_HUB_OFFLINE=True
 99 | 
100 | export MODEL_DIR="./ckp/stable-diffusion-v2-1"
101 | 
102 | export EXP_NAME="train_syntheocc"
103 | export OUTPUT_DIR="./ckp/$EXP_NAME"
104 | export SAVE_IMG_DIR="vis_dir/$EXP_NAME/samples"
105 | export DATA_USED="samples_syntheocc_surocc"
106 | 
107 | accelerate launch --gpu_ids 0, --num_processes 1 --main_process_port 3226 train.py \
108 |   --pretrained_model_name_or_path=$MODEL_DIR \
109 |   --output_dir=$OUTPUT_DIR \
110 |   --width=800 \
111 |   --height=448 \
112 |   --learning_rate=2e-5 \
113 |   --num_train_epochs=6 \
114 |   --train_batch_size=1 \
115 |   --mixed_precision="fp16" \
116 |   --num_validation_images=2 \
117 |   --validation_steps=1000 \
118 |   --checkpointing_steps=5000 \
119 |   --checkpoints_total_limit=10 \
120 |   --ctrl_channel=257 \
121 |   --enable_xformers_memory_efficient_attention \
122 |   --report_to='wandb' \
123 |   --use_cbgs=True \
124 |   --mtp_path='samples_syntheocc_surocc' \
125 |   --resume_from_checkpoint="latest"
126 | ```
127 | 
128 | Training takes roughly 1-2 days, depending on the hardware. We use a fixed batch size of 1 and an image resolution of (800, 448), which takes about 25 GB of memory per GPU.
129 | 
130 | ## Inference
131 | 
132 | ```bash
133 | bash infer.sh
134 | ```
135 | You will find the generated images at `$SAVE_IMG_DIR` (by default `vis_dir/$EXP_NAME/samples`, as set in `infer.sh`). An example is shown below:
136 | ![image](./ckp/demo.jpg)
137 | 
138 | 
139 | 
140 | ## Acknowledgment
141 | We express our gratitude to the authors of the following open-source projects:
142 | 
143 | - [SurroundOcc](https://github.com/weiyithu/SurroundOcc) (Occupancy annotation)
144 | - [OpenOccupancy](https://github.com/JeffWang987/OpenOccupancy) (Occupancy annotation)
145 | - [MagicDrive](https://github.com/cure-lab/MagicDrive) (Cross-view and cross-frame attention implementation)
146 | - [Diffusers controlnet example](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) (Diffusion model implementation)
147 | 
148 | 
149 | 
150 | 
151 | 
152 | ## BibTeX
153 | 
154 | ```bibtex
155 | @article{li2024syntheocc,
156 |   title={SyntheOcc: Synthesize Geometric-Controlled Street View Images through 3D Semantic MPIs},
157 |   author={Li, Leheng and Qiu, Weichao and Cai, Yingjie and Yan, Xu and Lian, Qing and Liu, Bingbing and Chen, Ying-Cong},
158 |   journal={arXiv preprint arXiv:2410.00337},
159 |   year={2024}
160 | }
161 | ```
162 | 
163 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
164 | 
165 | 
166 | 
--------------------------------------------------------------------------------
/args_file.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | 
  3 | 
  4 | def parse_args(input_args=None):
  5 |     parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
  6 |     parser.add_argument(
  7 |         "--pretrained_model_name_or_path",
  8 |         type=str,
  9 |         default="/hpc2hdd/home/lli181/long_video/animate-anything/download/AI-ModelScope/stable-diffusion-v2-1",
 10 |         # required=True,
 11 |         help="Path to pretrained model or model identifier from huggingface.co/models.",
 12 |     )
 13 |     parser.add_argument(
 14 |         "--dataroot_path",
 15 |         type=str,
 16 |         default='./data/nuscenes',
 17 |         help="The location of the nuScenes dataset.",
 18 |     )
 19 |     parser.add_argument(
 20 |         "--mtp_path",
 21 |         type=str,
 22 |         default='samples_occmask56',
 23 |         help="Folder name (under the dataroot) of the 3D semantic multiplane images.",
 24 |     )
 25 |     parser.add_argument(
 26 |         "--gen_train_or_val",
 27 |         type=str,
 28 |         default='val',
 29 |         help="Which split (train or val) to generate images for.",
 30 |     )
 31 |     parser.add_argument(
 32 |         "--model_path_infer",
 33 |         type=str,
 34 |         default=None,
 35 |         help="Path to the trained model (checkpoint directory) used for inference.",
 36 |     )
 37 |     parser.add_argument(
 38 |         "--save_img_path",
 39 |         type=str,
 40 |         default=None,
 41 |         help="Directory where the generated images are saved.",
 42 |     )
 43 |     parser.add_argument(
 44 |         "--cam_name_list",
 45 |         type=list,
 46 |         default=['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT'],
 47 |         help="The list of camera names.",
 48 |     )
 49 |     parser.add_argument(
 50 |         "--use_sdxl",
 51 |         type=lambda s: str(s).lower() in ("true", "1", "yes"),  # plain type=bool would parse "False" as True
 52 |         default=False,
 53 |         help="Whether to use SDXL as the base model.",
 54 |     )
 55 |     parser.add_argument(
 56 |         "--use_cbgs",
 57 |         type=lambda s: str(s).lower() in ("true", "1", "yes"),  # plain type=bool would parse "False" as True
 58 |         default=False,
 59 |         help="Whether to use class-balanced grouping and sampling (CBGS) for training.",
 60 |     )
 61 |     parser.add_argument(
 62 |         "--cfg_scale",
 63 |         type=float,
 64 |         default=7.0,
 65 |         help="Classifier-free guidance scale used at inference.",
 66 |     )
 67 |     parser.add_argument(
 68 |         "--curr_gpu",
 69 |         type=int,
 70 |         default=0,
 71 |         help="Index of the current GPU (data shard) used for inference.",
 72 |     )
 73 |     parser.add_argument(
 74 |         "--controlnet_model_name_or_path",
 75 |         type=str,
 76 |         default=None,
 77 |         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
 78 |         " If not specified controlnet weights are initialized from unet.",
 79 |     )
 80 |     parser.add_argument(
 81 |         "--revision",
 82 |         type=str,
 83 |         default=None,
 84 |         required=False,
 85 |         help="Revision of pretrained model identifier from huggingface.co/models.",
 86 |     )
 87 |     parser.add_argument(
 88 |         "--variant",
 89 |         type=str,
 90 |         default=None,
 91 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.'
fp16", 92 | ) 93 | parser.add_argument( 94 | "--tokenizer_name", 95 | type=str, 96 | default=None, 97 | help="Pretrained tokenizer name or path if not the same as model_name", 98 | ) 99 | parser.add_argument( 100 | "--output_dir", 101 | type=str, 102 | default="controlnet-model", 103 | help="The output directory where the model predictions and checkpoints will be written.", 104 | ) 105 | parser.add_argument( 106 | "--cache_dir", 107 | type=str, 108 | default=None, 109 | help="The directory where the downloaded models and datasets will be stored.", 110 | ) 111 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 112 | parser.add_argument( 113 | "--resolution", 114 | type=int, 115 | default=512, 116 | help=( 117 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 118 | " resolution" 119 | ), 120 | ) 121 | parser.add_argument( 122 | "--width", 123 | type=int, 124 | default=800, 125 | ) 126 | parser.add_argument( 127 | "--height", 128 | type=int, 129 | default=448, 130 | ) 131 | parser.add_argument( 132 | "--ctrl_channel", 133 | type=int, 134 | default=257, 135 | ) 136 | parser.add_argument( 137 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 138 | ) 139 | parser.add_argument("--num_train_epochs", type=int, default=1) 140 | parser.add_argument( 141 | "--max_train_steps", 142 | type=int, 143 | default=None, 144 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 145 | ) 146 | parser.add_argument( 147 | "--checkpointing_steps", 148 | type=int, 149 | default=500, 150 | help=( 151 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 152 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 153 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 154 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 155 | "instructions." 156 | ), 157 | ) 158 | parser.add_argument( 159 | "--checkpoints_total_limit", 160 | type=int, 161 | default=3, 162 | help=("Max number of checkpoints to store."), 163 | ) 164 | parser.add_argument( 165 | "--resume_from_checkpoint", 166 | type=str, 167 | default=None, 168 | help=( 169 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 170 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
171 | ), 172 | ) 173 | parser.add_argument( 174 | "--gradient_accumulation_steps", 175 | type=int, 176 | default=1, 177 | help="Number of updates steps to accumulate before performing a backward/update pass.", 178 | ) 179 | parser.add_argument( 180 | "--gradient_checkpointing", 181 | action="store_true", 182 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 183 | ) 184 | parser.add_argument( 185 | "--learning_rate", 186 | type=float, 187 | default=5e-6, 188 | help="Initial learning rate (after the potential warmup period) to use.", 189 | ) 190 | parser.add_argument( 191 | "--scale_lr", 192 | action="store_true", 193 | default=False, 194 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 195 | ) 196 | parser.add_argument( 197 | "--lr_scheduler", 198 | type=str, 199 | default="constant_with_warmup", 200 | # default="constant", 201 | help=( 202 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 203 | ' "constant", "constant_with_warmup"]' 204 | ), 205 | ) 206 | parser.add_argument( 207 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 208 | ) 209 | parser.add_argument( 210 | "--lr_num_cycles", 211 | type=int, 212 | default=1, 213 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 214 | ) 215 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 216 | parser.add_argument( 217 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 218 | ) 219 | parser.add_argument( 220 | "--dataloader_num_workers", 221 | type=int, 222 | # default=0, 223 | # default=12, 224 | default=4, 225 | help=( 226 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 227 | ), 228 | ) 229 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 230 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 231 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 232 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 233 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 234 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 235 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 236 | parser.add_argument( 237 | "--hub_model_id", 238 | type=str, 239 | default=None, 240 | help="The name of the repository to keep in sync with the local `output_dir`.", 241 | ) 242 | parser.add_argument( 243 | "--logging_dir", 244 | type=str, 245 | default="logs", 246 | help=( 247 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 248 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 249 | ), 250 | ) 251 | parser.add_argument( 252 | "--allow_tf32", 253 | action="store_true", 254 | help=( 255 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see" 256 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 257 | ), 258 | ) 259 | parser.add_argument( 260 | "--report_to", 261 | type=str, 262 | default="tensorboard", 263 | help=( 264 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 265 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 266 | ), 267 | ) 268 | parser.add_argument( 269 | "--mixed_precision", 270 | type=str, 271 | default=None, 272 | choices=["no", "fp16", "bf16"], 273 | help=( 274 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 275 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 276 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 277 | ), 278 | ) 279 | parser.add_argument( 280 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 281 | ) 282 | parser.add_argument( 283 | "--set_grads_to_none", 284 | action="store_true", 285 | help=( 286 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 287 | " behaviors, so disable this argument if it causes any problems. More info:" 288 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 289 | ), 290 | ) 291 | parser.add_argument( 292 | "--dataset_name", 293 | type=str, 294 | default=None, 295 | help=( 296 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 297 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 298 | " or to a folder containing files that 🤗 Datasets can understand." 299 | ), 300 | ) 301 | parser.add_argument( 302 | "--dataset_config_name", 303 | type=str, 304 | default=None, 305 | help="The config of the Dataset, leave as None if there's only one config.", 306 | ) 307 | parser.add_argument( 308 | "--train_data_dir", 309 | type=str, 310 | default=None, 311 | help=( 312 | "A folder containing the training data. Folder contents must follow the structure described in" 313 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 314 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 315 | ), 316 | ) 317 | parser.add_argument( 318 | "--pretrained_vae_model_name_or_path", 319 | type=str, 320 | default=None, 321 | help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", 322 | ) 323 | parser.add_argument( 324 | "--max_train_samples", 325 | type=int, 326 | default=None, 327 | help=( 328 | "For debugging purposes or quicker training, truncate the number of training examples to this " 329 | "value if set." 330 | ), 331 | ) 332 | parser.add_argument( 333 | "--proportion_empty_prompts", 334 | type=float, 335 | default=0, 336 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 337 | ) 338 | parser.add_argument( 339 | "--validation_prompt", 340 | type=str, 341 | default=None, 342 | nargs="+", 343 | help=( 344 | "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." 
345 | " Provide either a matching number of `--validation_image`s, a single `--validation_image`" 346 | " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." 347 | ), 348 | ) 349 | parser.add_argument( 350 | "--validation_image", 351 | type=str, 352 | default=None, 353 | nargs="+", 354 | help=( 355 | "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" 356 | " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" 357 | " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" 358 | " `--validation_image` that will be used with all `--validation_prompt`s." 359 | ), 360 | ) 361 | parser.add_argument( 362 | "--num_validation_images", 363 | type=int, 364 | default=2, 365 | help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", 366 | ) 367 | parser.add_argument( 368 | "--validation_steps", 369 | type=int, 370 | default=100, 371 | help=( 372 | "Run validation every X steps. Validation consists of running the prompt" 373 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 374 | " and logging the images." 375 | ), 376 | ) 377 | parser.add_argument( 378 | "--tracker_project_name", 379 | type=str, 380 | default="train_controlnet", 381 | help=( 382 | "The `project_name` argument passed to Accelerator.init_trackers for" 383 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 384 | ), 385 | ) 386 | 387 | if input_args is not None: 388 | args = parser.parse_args(input_args) 389 | else: 390 | args = parser.parse_args() 391 | 392 | 393 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 394 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 395 | 396 | 397 | if args.resolution % 8 != 0: 398 | raise ValueError( 399 | "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." 
400 | ) 401 | 402 | return args 403 | -------------------------------------------------------------------------------- /ckp/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/SyntheOcc/d7dccb5299aee7649343b34825167dbf92b757f4/ckp/demo.jpg -------------------------------------------------------------------------------- /ckp/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EnVision-Research/SyntheOcc/d7dccb5299aee7649343b34825167dbf92b757f4/ckp/video.gif -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import cv2 4 | import os 5 | import tqdm 6 | import pickle 7 | import numpy as np 8 | from torchvision.io import read_image 9 | 10 | import torchvision.transforms.functional as tf 11 | 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | 15 | from models.controlnet1x1 import ControlNetModel1x1 as ControlNetModel 16 | from models.pipeline_controlnet_1x1_4dunet import ( 17 | StableDiffusionControlNetPipeline1x1 as StableDiffusionControlNetPipeline, 18 | ) 19 | 20 | from models.unet_2d_condition_multiview import UNet2DConditionModelMultiview 21 | from diffusers import UniPCMultistepScheduler 22 | 23 | from args_file import parse_args 24 | from transformers import AutoTokenizer 25 | 26 | from utils.dataset_nusmtv import NuScenesDatasetMtvSpar as NuScenesDataset 27 | 28 | 29 | args = parse_args() 30 | 31 | base_model_path = "/hpc2hdd/home/lli181/long_video/animate-anything/download/AI-ModelScope/stable-diffusion-v2-1" 32 | 33 | 34 | ckp_path = "./exp/out_sd21_cbgs_loss/" 35 | 36 | 37 | if args.model_path_infer is not None: 38 | ckp_path = args.model_path_infer 39 | 40 | if "checkpoint" not in ckp_path: 41 | dirs = os.listdir(ckp_path) 42 | dirs = [d for d in dirs if d.startswith("checkpoint")] 43 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 44 | ckp_path = os.path.join(ckp_path, dirs[-1]) if len(dirs) > 0 else ckp_path 45 | 46 | height = 448 47 | width = 800 48 | 49 | 50 | generator = torch.manual_seed(0) 51 | # generator = torch.manual_seed(666) 52 | 53 | num_validation_images = 2 54 | 55 | 56 | tokenizer = AutoTokenizer.from_pretrained( 57 | args.pretrained_model_name_or_path, 58 | subfolder="tokenizer", 59 | revision=args.revision, 60 | use_fast=False, 61 | ) 62 | 63 | val_dataset = NuScenesDataset(args, tokenizer, args.gen_train_or_val) 64 | 65 | CAM_NAMES = val_dataset.CAM_NAMES 66 | 67 | 68 | save_path = "vis_dir/out_sd21_cbgs_loss2_40/samples" 69 | 70 | 71 | if args.model_path_infer is not None: 72 | save_path = args.save_img_path 73 | 74 | 75 | print(ckp_path, save_path) 76 | 77 | 78 | for cam_name in CAM_NAMES: 79 | os.makedirs(os.path.join(save_path, cam_name), exist_ok=True) 80 | 81 | 82 | def run_inference(rank, world_size, pred_results, input_datas, pipe, args): 83 | 84 | pipe.to("cuda") 85 | 86 | all_list = input_datas[rank] 87 | 88 | validation_prompts = [] 89 | validation_prompts.append("show a photorealistic street view image.") 90 | 91 | with torch.no_grad(): 92 | for img_idx in tqdm.tqdm(all_list): 93 | 94 | data_dict = val_dataset.__getitem__(img_idx) 95 | mtv_condition = data_dict["ctrl_img"].to("cuda") 96 | 97 | cfg_scale = torch.tensor([7.5]).to("cuda") 98 | 99 | images_tensor = [] 100 | 101 | for iter_idx in 
range(len(validation_prompts)): 102 | 103 | curr_prompt = [validation_prompts[iter_idx]] * 6 104 | 105 | with torch.autocast("cuda"): 106 | image = pipe( 107 | prompt=curr_prompt, 108 | image=mtv_condition, 109 | num_inference_steps=20, 110 | generator=generator, 111 | height=height, 112 | width=width, 113 | controlnet_conditioning_scale=1.0, 114 | guidance_scale=cfg_scale, 115 | ).images # [0] 116 | 117 | for mtv_idx, img in enumerate(image): 118 | img = img.resize((1600, 900)) 119 | 120 | img_name = data_dict["path_img"][mtv_idx].split("/")[-1] 121 | img.save(f"{save_path}/{CAM_NAMES[mtv_idx]}/{img_name}") 122 | 123 | image = torch.cat([torch.tensor(np.array(ii)) for ii in image], 1) 124 | 125 | images_tensor.append(image) 126 | 127 | # [448, 6, 800, 3] to [448, 4800, 3] 128 | raw_img = ( 129 | data_dict["pixel_values"] 130 | .permute(2, 0, 3, 1) 131 | .reshape(images_tensor[0].shape) 132 | * 255 133 | ) 134 | occ_rgb = ( 135 | data_dict["occ_rgb"].permute(1, 0, 2, 3).reshape(images_tensor[0].shape) 136 | ) 137 | gen_img = torch.cat(images_tensor, 0) 138 | gen_img = torch.cat([occ_rgb, gen_img, raw_img], 0) 139 | 140 | out_path = os.path.join( 141 | f"{save_path}/{img_idx:06d}_{str(cfg_scale.item())}.jpg" 142 | ) 143 | 144 | cv2.imwrite( 145 | out_path, cv2.cvtColor(gen_img.cpu().numpy(), cv2.COLOR_RGB2BGR) 146 | ) 147 | 148 | 149 | if __name__ == "__main__": 150 | os.system("export NCCL_SOCKET_IFNAME=eth1") 151 | 152 | from torch.multiprocessing import Manager 153 | 154 | world_size = 4 155 | # world_size = 8 156 | 157 | all_len = len(val_dataset) 158 | # all_len = 500 159 | 160 | all_list = np.arange(0, all_len, 1) 161 | 162 | all_len_sel = all_list.shape[0] 163 | val_len = all_len_sel // world_size * world_size 164 | 165 | all_list_filter = all_list[:val_len] 166 | 167 | all_list_filter = np.split(all_list_filter, world_size) 168 | 169 | input_datas = {} 170 | for i in range(world_size): 171 | input_datas[i] = list(all_list_filter[i]) 172 | print(len(input_datas[i])) 173 | 174 | input_datas[0] += list(all_list[val_len:]) 175 | 176 | controlnet = ControlNetModel.from_pretrained( 177 | ckp_path, subfolder="controlnet", torch_dtype=torch.float16 178 | ) 179 | unet = UNet2DConditionModelMultiview.from_pretrained( 180 | ckp_path, subfolder="unet", torch_dtype=torch.float16 181 | ) 182 | pipe = StableDiffusionControlNetPipeline.from_pretrained( 183 | base_model_path, unet=unet, controlnet=controlnet, torch_dtype=torch.float16 184 | ) 185 | 186 | # speed up diffusion process with faster scheduler and memory optimization 187 | pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) 188 | 189 | pipe.enable_xformers_memory_efficient_attention() 190 | pipe.set_progress_bar_config(disable=True) 191 | from diffusers.models.attention_processor import AttnProcessor2_0 192 | 193 | pipe.unet.set_attn_processor(AttnProcessor2_0()) 194 | 195 | run_inference(args.curr_gpu, 1, None, input_datas, pipe, args) 196 | 197 | # with Manager() as manager: 198 | # pred_results = manager.list() 199 | # mp.spawn(run_inference, nprocs=world_size, args=(world_size,pred_results,input_datas,pipe,args,), join=True) 200 | -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | 2 | export MODEL_DIR="./ckp/stable-diffusion-v2-1" 3 | 4 | export EXP_NAME="train_syntheocc" 5 | export OUTPUT_DIR="./ckp/$EXP_NAME" 6 | export SAVE_IMG_DIR="vis_dir/$EXP_NAME/samples" 7 | export 
DATA_USED="samples_syntheocc_surocc" 8 | 9 | export TRAIN_OR_VAL="val" 10 | 11 | 12 | CUDA_VISIBLE_DEVICES=0 python infer.py --save_img_path=$SAVE_IMG_DIR --model_path_infer=$OUTPUT_DIR --curr_gpu=0 --mtp_path=$DATA_USED --ctrl_channel=257 --gen_train_or_val=$TRAIN_OR_VAL 13 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint 6 | from einops import rearrange 7 | 8 | from diffusers.models.attention_processor import Attention 9 | from diffusers.models.attention import AdaLayerNorm 10 | from diffusers.models.controlnet import zero_module 11 | 12 | 13 | from diffusers.utils import USE_PEFT_BACKEND 14 | from diffusers.utils.torch_utils import maybe_allow_in_graph 15 | from diffusers.models.attention_processor import Attention 16 | from diffusers.models.lora import LoRACompatibleLinear 17 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 18 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 19 | 20 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 21 | 22 | 23 | 24 | 25 | 26 | def _chunked_feed_forward( 27 | ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None 28 | ): 29 | # "feed_forward_chunk_size" can be used to save memory 30 | if hidden_states.shape[chunk_dim] % chunk_size != 0: 31 | raise ValueError( 32 | f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 33 | ) 34 | 35 | num_chunks = hidden_states.shape[chunk_dim] // chunk_size 36 | if lora_scale is None: 37 | ff_output = torch.cat( 38 | [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 39 | dim=chunk_dim, 40 | ) 41 | else: 42 | # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete 43 | ff_output = torch.cat( 44 | [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 45 | dim=chunk_dim, 46 | ) 47 | 48 | return ff_output 49 | 50 | 51 | @maybe_allow_in_graph 52 | class GatedSelfAttentionDense(nn.Module): 53 | r""" 54 | A gated self-attention dense layer that combines visual features and object features. 55 | 56 | Parameters: 57 | query_dim (`int`): The number of channels in the query. 58 | context_dim (`int`): The number of channels in the context. 59 | n_heads (`int`): The number of heads to use for attention. 60 | d_head (`int`): The number of channels in each head. 
61 | """ 62 | 63 | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): 64 | super().__init__() 65 | 66 | # we need a linear projection since we need cat visual feature and obj feature 67 | self.linear = nn.Linear(context_dim, query_dim) 68 | 69 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 70 | self.ff = FeedForward(query_dim, activation_fn="geglu") 71 | 72 | self.norm1 = nn.LayerNorm(query_dim) 73 | self.norm2 = nn.LayerNorm(query_dim) 74 | 75 | self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) 76 | self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) 77 | 78 | self.enabled = True 79 | 80 | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: 81 | if not self.enabled: 82 | return x 83 | 84 | n_visual = x.shape[1] 85 | objs = self.linear(objs) 86 | 87 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] 88 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 89 | 90 | return x 91 | 92 | 93 | @maybe_allow_in_graph 94 | class BasicTransformerBlock(nn.Module): 95 | r""" 96 | A basic Transformer block. 97 | 98 | Parameters: 99 | dim (`int`): The number of channels in the input and output. 100 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 101 | attention_head_dim (`int`): The number of channels in each head. 102 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 103 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 104 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 105 | num_embeds_ada_norm (: 106 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 107 | attention_bias (: 108 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 109 | only_cross_attention (`bool`, *optional*): 110 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 111 | double_self_attention (`bool`, *optional*): 112 | Whether to use two self-attention layers. In this case no cross attention layers are used. 113 | upcast_attention (`bool`, *optional*): 114 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 115 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 116 | Whether to use learnable elementwise affine parameters for normalization. 117 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 118 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 119 | final_dropout (`bool` *optional*, defaults to False): 120 | Whether to apply a final dropout after the last feed-forward layer. 121 | attention_type (`str`, *optional*, defaults to `"default"`): 122 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 123 | positional_embeddings (`str`, *optional*, defaults to `None`): 124 | The type of positional embeddings to apply to. 125 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 126 | The maximum number of positional embeddings to apply. 
127 | """ 128 | 129 | def __init__( 130 | self, 131 | dim: int, 132 | num_attention_heads: int, 133 | attention_head_dim: int, 134 | dropout=0.0, 135 | cross_attention_dim: Optional[int] = None, 136 | activation_fn: str = "geglu", 137 | num_embeds_ada_norm: Optional[int] = None, 138 | attention_bias: bool = False, 139 | only_cross_attention: bool = False, 140 | double_self_attention: bool = False, 141 | upcast_attention: bool = False, 142 | norm_elementwise_affine: bool = True, 143 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen' 144 | norm_eps: float = 1e-5, 145 | final_dropout: bool = False, 146 | attention_type: str = "default", 147 | positional_embeddings: Optional[str] = None, 148 | num_positional_embeddings: Optional[int] = None, 149 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 150 | ada_norm_bias: Optional[int] = None, 151 | ff_inner_dim: Optional[int] = None, 152 | ff_bias: bool = True, 153 | attention_out_bias: bool = True, 154 | ): 155 | super().__init__() 156 | self._args = {k: v for k, v in locals().items() if k != "self" and not k.startswith("_")} 157 | 158 | self.only_cross_attention = only_cross_attention 159 | 160 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 161 | raise ValueError( 162 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 163 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 164 | ) 165 | 166 | self.norm_type = norm_type 167 | self.num_embeds_ada_norm = num_embeds_ada_norm 168 | 169 | if positional_embeddings and (num_positional_embeddings is None): 170 | raise ValueError( 171 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 172 | ) 173 | 174 | if positional_embeddings == "sinusoidal": 175 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 176 | else: 177 | self.pos_embed = None 178 | 179 | # Define 3 blocks. Each block has its own normalization layer. 180 | # 1. Self-Attn 181 | if norm_type == "ada_norm": 182 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 183 | elif norm_type == "ada_norm_zero": 184 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 185 | elif norm_type == "ada_norm_continuous": 186 | self.norm1 = AdaLayerNormContinuous( 187 | dim, 188 | ada_norm_continous_conditioning_embedding_dim, 189 | norm_elementwise_affine, 190 | norm_eps, 191 | ada_norm_bias, 192 | "rms_norm", 193 | ) 194 | else: 195 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 196 | 197 | self.attn1 = Attention( 198 | query_dim=dim, 199 | heads=num_attention_heads, 200 | dim_head=attention_head_dim, 201 | dropout=dropout, 202 | bias=attention_bias, 203 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 204 | upcast_attention=upcast_attention, 205 | out_bias=attention_out_bias, 206 | ) 207 | 208 | # 2. Cross-Attn 209 | if cross_attention_dim is not None or double_self_attention: 210 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 211 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 212 | # the second cross attention block. 
213 | if norm_type == "ada_norm": 214 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 215 | elif norm_type == "ada_norm_continuous": 216 | self.norm2 = AdaLayerNormContinuous( 217 | dim, 218 | ada_norm_continous_conditioning_embedding_dim, 219 | norm_elementwise_affine, 220 | norm_eps, 221 | ada_norm_bias, 222 | "rms_norm", 223 | ) 224 | else: 225 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 226 | 227 | self.attn2 = Attention( 228 | query_dim=dim, 229 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 230 | heads=num_attention_heads, 231 | dim_head=attention_head_dim, 232 | dropout=dropout, 233 | bias=attention_bias, 234 | upcast_attention=upcast_attention, 235 | out_bias=attention_out_bias, 236 | ) # is self-attn if encoder_hidden_states is none 237 | else: 238 | self.norm2 = None 239 | self.attn2 = None 240 | 241 | # 3. Feed-forward 242 | if norm_type == "ada_norm_continuous": 243 | self.norm3 = AdaLayerNormContinuous( 244 | dim, 245 | ada_norm_continous_conditioning_embedding_dim, 246 | norm_elementwise_affine, 247 | norm_eps, 248 | ada_norm_bias, 249 | "layer_norm", 250 | ) 251 | 252 | elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]: 253 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 254 | elif norm_type == "layer_norm_i2vgen": 255 | self.norm3 = None 256 | 257 | self.ff = FeedForward( 258 | dim, 259 | dropout=dropout, 260 | activation_fn=activation_fn, 261 | final_dropout=final_dropout, 262 | inner_dim=ff_inner_dim, 263 | bias=ff_bias, 264 | ) 265 | 266 | # 4. Fuser 267 | if attention_type == "gated" or attention_type == "gated-text-image": 268 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 269 | 270 | # 5. Scale-shift for PixArt-Alpha. 271 | if norm_type == "ada_norm_single": 272 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 273 | 274 | # let chunk size default to None 275 | self._chunk_size = None 276 | self._chunk_dim = 0 277 | 278 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 279 | # Sets chunk feed-forward 280 | self._chunk_size = chunk_size 281 | self._chunk_dim = dim 282 | 283 | def forward( 284 | self, 285 | hidden_states: torch.FloatTensor, 286 | attention_mask: Optional[torch.FloatTensor] = None, 287 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 288 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 289 | timestep: Optional[torch.LongTensor] = None, 290 | cross_attention_kwargs: Dict[str, Any] = None, 291 | class_labels: Optional[torch.LongTensor] = None, 292 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 293 | ) -> torch.FloatTensor: 294 | # Notice that normalization is always applied before the real computation in the following blocks. 295 | # 0. 
Self-Attention 296 | batch_size = hidden_states.shape[0] 297 | 298 | if self.norm_type == "ada_norm": 299 | norm_hidden_states = self.norm1(hidden_states, timestep) 300 | elif self.norm_type == "ada_norm_zero": 301 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 302 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 303 | ) 304 | elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: 305 | norm_hidden_states = self.norm1(hidden_states) 306 | elif self.norm_type == "ada_norm_continuous": 307 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 308 | elif self.norm_type == "ada_norm_single": 309 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 310 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 311 | ).chunk(6, dim=1) 312 | norm_hidden_states = self.norm1(hidden_states) 313 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 314 | norm_hidden_states = norm_hidden_states.squeeze(1) 315 | else: 316 | raise ValueError("Incorrect norm used") 317 | 318 | if self.pos_embed is not None: 319 | norm_hidden_states = self.pos_embed(norm_hidden_states) 320 | 321 | # 1. Retrieve lora scale. 322 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 323 | 324 | # 2. Prepare GLIGEN inputs 325 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 326 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 327 | 328 | attn_output = self.attn1( 329 | norm_hidden_states, 330 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 331 | attention_mask=attention_mask, 332 | **cross_attention_kwargs, 333 | ) 334 | if self.norm_type == "ada_norm_zero": 335 | attn_output = gate_msa.unsqueeze(1) * attn_output 336 | elif self.norm_type == "ada_norm_single": 337 | attn_output = gate_msa * attn_output 338 | 339 | hidden_states = attn_output + hidden_states 340 | if hidden_states.ndim == 4: 341 | hidden_states = hidden_states.squeeze(1) 342 | 343 | # 2.5 GLIGEN Control 344 | if gligen_kwargs is not None: 345 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 346 | 347 | # 3. Cross-Attention 348 | if self.attn2 is not None: 349 | if self.norm_type == "ada_norm": 350 | norm_hidden_states = self.norm2(hidden_states, timestep) 351 | elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: 352 | norm_hidden_states = self.norm2(hidden_states) 353 | elif self.norm_type == "ada_norm_single": 354 | # For PixArt norm2 isn't applied here: 355 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 356 | norm_hidden_states = hidden_states 357 | elif self.norm_type == "ada_norm_continuous": 358 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 359 | else: 360 | raise ValueError("Incorrect norm") 361 | 362 | if self.pos_embed is not None and self.norm_type != "ada_norm_single": 363 | norm_hidden_states = self.pos_embed(norm_hidden_states) 364 | 365 | attn_output = self.attn2( 366 | norm_hidden_states, 367 | encoder_hidden_states=encoder_hidden_states, 368 | attention_mask=encoder_attention_mask, 369 | **cross_attention_kwargs, 370 | ) 371 | hidden_states = attn_output + hidden_states 372 | 373 | # 4. 
Feed-forward 374 | # i2vgen doesn't have this norm 🤷‍♂️ 375 | if self.norm_type == "ada_norm_continuous": 376 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 377 | elif not self.norm_type == "ada_norm_single": 378 | norm_hidden_states = self.norm3(hidden_states) 379 | 380 | if self.norm_type == "ada_norm_zero": 381 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 382 | 383 | if self.norm_type == "ada_norm_single": 384 | norm_hidden_states = self.norm2(hidden_states) 385 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 386 | 387 | if self._chunk_size is not None: 388 | # "feed_forward_chunk_size" can be used to save memory 389 | ff_output = _chunked_feed_forward( 390 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 391 | ) 392 | else: 393 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 394 | 395 | if self.norm_type == "ada_norm_zero": 396 | ff_output = gate_mlp.unsqueeze(1) * ff_output 397 | elif self.norm_type == "ada_norm_single": 398 | ff_output = gate_mlp * ff_output 399 | 400 | hidden_states = ff_output + hidden_states 401 | if hidden_states.ndim == 4: 402 | hidden_states = hidden_states.squeeze(1) 403 | 404 | return hidden_states 405 | 406 | 407 | 408 | class FeedForward(nn.Module): 409 | r""" 410 | A feed-forward layer. 411 | 412 | Parameters: 413 | dim (`int`): The number of channels in the input. 414 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 415 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 416 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 417 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 418 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 419 | bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 420 | """ 421 | 422 | def __init__( 423 | self, 424 | dim: int, 425 | dim_out: Optional[int] = None, 426 | mult: int = 4, 427 | dropout: float = 0.0, 428 | activation_fn: str = "geglu", 429 | final_dropout: bool = False, 430 | inner_dim=None, 431 | bias: bool = True, 432 | ): 433 | super().__init__() 434 | if inner_dim is None: 435 | inner_dim = int(dim * mult) 436 | dim_out = dim_out if dim_out is not None else dim 437 | linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear 438 | 439 | if activation_fn == "gelu": 440 | act_fn = GELU(dim, inner_dim, bias=bias) 441 | if activation_fn == "gelu-approximate": 442 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 443 | elif activation_fn == "geglu": 444 | act_fn = GEGLU(dim, inner_dim, bias=bias) 445 | elif activation_fn == "geglu-approximate": 446 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 447 | 448 | self.net = nn.ModuleList([]) 449 | # project in 450 | self.net.append(act_fn) 451 | # project dropout 452 | self.net.append(nn.Dropout(dropout)) 453 | # project out 454 | self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) 455 | # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout 456 | if final_dropout: 457 | self.net.append(nn.Dropout(dropout)) 458 | 459 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 460 | compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) 461 | for module in self.net: 462 | if isinstance(module, compatible_cls): 463 | hidden_states = module(hidden_states, scale) 464 | else: 465 | hidden_states = module(hidden_states) 466 | return hidden_states 467 | 468 | 469 | def _ensure_kv_is_int(view_pair: dict): 470 | """yaml key can be int, while json cannot. We convert here. 471 | """ 472 | new_dict = {} 473 | for k, v in view_pair.items(): 474 | new_value = [int(vi) for vi in v] 475 | new_dict[int(k)] = new_value 476 | return new_dict 477 | 478 | 479 | class GatedConnector(nn.Module): 480 | def __init__(self, dim) -> None: 481 | super().__init__() 482 | data = torch.zeros(dim) 483 | self.alpha = nn.parameter.Parameter(data) 484 | 485 | def forward(self, inx): 486 | # as long as last dim of input == dim, pytorch can auto-broad 487 | return F.tanh(self.alpha) * inx 488 | 489 | 490 | class BasicMultiviewTransformerBlock(BasicTransformerBlock): 491 | 492 | def __init__( 493 | self, 494 | dim: int, 495 | num_attention_heads: int, 496 | attention_head_dim: int, 497 | dropout=0.0, 498 | cross_attention_dim: Optional[int] = None, 499 | activation_fn: str = "geglu", 500 | num_embeds_ada_norm: Optional[int] = None, 501 | attention_bias: bool = False, 502 | only_cross_attention: bool = False, 503 | double_self_attention: bool = False, 504 | upcast_attention: bool = False, 505 | norm_elementwise_affine: bool = True, 506 | norm_type: str = "layer_norm", 507 | final_dropout: bool = False, 508 | # multi_view 509 | neighboring_view_pair: Optional[Dict[int, List[int]]] = None, 510 | neighboring_attn_type: Optional[str] = "add", 511 | zero_module_type="zero_linear", 512 | **kargs, 513 | ): 514 | super().__init__( 515 | dim, num_attention_heads, attention_head_dim, dropout, 516 | cross_attention_dim, activation_fn, num_embeds_ada_norm, 517 | attention_bias, only_cross_attention, double_self_attention, 518 | upcast_attention, norm_elementwise_affine, norm_type, final_dropout) 519 | 520 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 521 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 522 | 523 | self.neighboring_view_pair = _ensure_kv_is_int(neighboring_view_pair) 524 | self.neighboring_attn_type = neighboring_attn_type 525 | # multiview attention 526 | self.norm4 = ( 527 | AdaLayerNorm(dim, num_embeds_ada_norm) 528 | if self.use_ada_layer_norm 529 | else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 530 | ) 531 | self.attn4 = Attention( 532 | query_dim=dim, 533 | cross_attention_dim=dim, 534 | heads=num_attention_heads, 535 | dim_head=attention_head_dim, 536 | dropout=dropout, 537 | bias=attention_bias, 538 | upcast_attention=upcast_attention, 539 | ) 540 | if zero_module_type == "zero_linear": 541 | # NOTE: zero_module cannot apply to successive layers. 542 | self.connector = zero_module(nn.Linear(dim, dim)) 543 | elif zero_module_type == "gated": 544 | self.connector = GatedConnector(dim) 545 | elif zero_module_type == "none": 546 | # TODO: if this block is in controlnet, we may not need zero here. 
547 | self.connector = lambda x: x 548 | else: 549 | raise TypeError(f"Unknown zero module type: {zero_module_type}") 550 | 551 | @property 552 | def new_module(self): 553 | ret = { 554 | "norm4": self.norm4, 555 | "attn4": self.attn4, 556 | } 557 | if isinstance(self.connector, nn.Module): 558 | ret["connector"] = self.connector 559 | return ret 560 | 561 | @property 562 | def n_cam(self): 563 | return len(self.neighboring_view_pair) 564 | 565 | def _construct_attn_input(self, norm_hidden_states): 566 | B = len(norm_hidden_states) 567 | # reshape, key for origin view, value for ref view 568 | hidden_states_in1 = [] 569 | hidden_states_in2 = [] 570 | cam_order = [] 571 | if self.neighboring_attn_type == "add": 572 | for key, values in self.neighboring_view_pair.items(): 573 | for value in values: 574 | hidden_states_in1.append(norm_hidden_states[:, key]) 575 | hidden_states_in2.append(norm_hidden_states[:, value]) 576 | cam_order += [key] * B 577 | # N*2*B, H*W, head*dim 578 | hidden_states_in1 = torch.cat(hidden_states_in1, dim=0) 579 | hidden_states_in2 = torch.cat(hidden_states_in2, dim=0) 580 | cam_order = torch.LongTensor(cam_order) 581 | elif self.neighboring_attn_type == "concat": 582 | for key, values in self.neighboring_view_pair.items(): 583 | hidden_states_in1.append(norm_hidden_states[:, key]) 584 | hidden_states_in2.append(torch.cat([ 585 | norm_hidden_states[:, value] for value in values 586 | ], dim=1)) 587 | cam_order += [key] * B 588 | # N*B, H*W, head*dim 589 | hidden_states_in1 = torch.cat(hidden_states_in1, dim=0) 590 | # N*B, 2*H*W, head*dim 591 | hidden_states_in2 = torch.cat(hidden_states_in2, dim=0) 592 | cam_order = torch.LongTensor(cam_order) 593 | elif self.neighboring_attn_type == "self": 594 | hidden_states_in1 = rearrange( 595 | norm_hidden_states, "b n l ... -> b (n l) ...") 596 | hidden_states_in2 = None 597 | cam_order = None 598 | else: 599 | raise NotImplementedError( 600 | f"Unknown type: {self.neighboring_attn_type}") 601 | return hidden_states_in1, hidden_states_in2, cam_order 602 | 603 | def forward( 604 | self, 605 | hidden_states, 606 | attention_mask=None, 607 | encoder_hidden_states=None, 608 | encoder_attention_mask=None, 609 | timestep=None, 610 | cross_attention_kwargs=None, 611 | class_labels=None, 612 | ): 613 | # Notice that normalization is always applied before the real computation in the following blocks. 614 | # 1. Self-Attention 615 | if self.use_ada_layer_norm: 616 | norm_hidden_states = self.norm1(hidden_states, timestep) 617 | elif self.use_ada_layer_norm_zero: 618 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 619 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 620 | ) 621 | else: 622 | norm_hidden_states = self.norm1(hidden_states) 623 | 624 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 625 | attn_output = self.attn1( 626 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states 627 | if self.only_cross_attention else None, 628 | attention_mask=attention_mask, **cross_attention_kwargs,) 629 | if self.use_ada_layer_norm_zero: 630 | attn_output = gate_msa.unsqueeze(1) * attn_output 631 | hidden_states = attn_output + hidden_states 632 | 633 | # 2. 
Cross-Attention 634 | if self.attn2 is not None: 635 | norm_hidden_states = ( 636 | self.norm2(hidden_states, timestep) 637 | if self.use_ada_layer_norm else self.norm2(hidden_states)) 638 | # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly 639 | # prepare attention mask here 640 | 641 | attn_output = self.attn2( 642 | norm_hidden_states, 643 | encoder_hidden_states=encoder_hidden_states, 644 | attention_mask=encoder_attention_mask, 645 | **cross_attention_kwargs, 646 | ) 647 | hidden_states = attn_output + hidden_states 648 | 649 | # multi-view cross attention 650 | norm_hidden_states = ( 651 | self.norm4(hidden_states, timestep) if self.use_ada_layer_norm else 652 | self.norm4(hidden_states) 653 | ) 654 | # batch dim first, cam dim second 655 | norm_hidden_states = rearrange( 656 | norm_hidden_states, '(b n) ... -> b n ...', n=self.n_cam) 657 | B = len(norm_hidden_states) 658 | # key is query in attention; value is key-value in attention 659 | hidden_states_in1, hidden_states_in2, cam_order = self._construct_attn_input( 660 | norm_hidden_states, ) 661 | # attention 662 | attn_raw_output = self.attn4( 663 | hidden_states_in1, 664 | encoder_hidden_states=hidden_states_in2, 665 | **cross_attention_kwargs, 666 | ) 667 | # import ipdb; ipdb.set_trace() 668 | 669 | # final output 670 | if self.neighboring_attn_type == "self": 671 | attn_output = rearrange( 672 | attn_raw_output, 'b (n l) ... -> b n l ...', n=self.n_cam) 673 | else: 674 | attn_output = torch.zeros_like(norm_hidden_states) 675 | for cam_i in range(self.n_cam): 676 | attn_out_mv = rearrange(attn_raw_output[cam_order == cam_i], 677 | '(n b) ... -> b n ...', b=B) 678 | attn_output[:, cam_i] = torch.sum(attn_out_mv, dim=1) 679 | attn_output = rearrange(attn_output, 'b n ... -> (b n) ...') 680 | # apply zero init connector (one layer) 681 | attn_output = self.connector(attn_output) 682 | # short-cut 683 | hidden_states = attn_output + hidden_states 684 | 685 | # 3. Feed-forward 686 | norm_hidden_states = self.norm3(hidden_states) 687 | 688 | if self.use_ada_layer_norm_zero: 689 | norm_hidden_states = norm_hidden_states * ( 690 | 1 + scale_mlp[:, None]) + shift_mlp[:, None] 691 | 692 | ff_output = self.ff(norm_hidden_states) 693 | 694 | if self.use_ada_layer_norm_zero: 695 | ff_output = gate_mlp.unsqueeze(1) * ff_output 696 | 697 | hidden_states = ff_output + hidden_states 698 | 699 | return hidden_states 700 | 701 | -------------------------------------------------------------------------------- /models/controlnet1x1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
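# NOTE: this file adapts diffusers' ControlNetModel. Judging from the classes below,
# ControlNetConditioningEmbeddingNodown encodes the conditioning input with 1x1
# convolutions and no spatial downsampling (unlike the stock strided-conv encoder),
# so the multi-channel semantic-MPI condition (see --ctrl_channel in args_file.py)
# can be fed to ControlNetModel1x1 directly at latent resolution.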
14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import FromOriginalControlNetMixin 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.models.attention_processor import ( 25 | ADDED_KV_ATTENTION_PROCESSORS, 26 | CROSS_ATTENTION_PROCESSORS, 27 | AttentionProcessor, 28 | AttnAddedKVProcessor, 29 | AttnProcessor, 30 | ) 31 | from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps 32 | from diffusers.models.modeling_utils import ModelMixin 33 | from diffusers.models.unets.unet_2d_blocks import ( 34 | CrossAttnDownBlock2D, 35 | DownBlock2D, 36 | UNetMidBlock2D, 37 | UNetMidBlock2DCrossAttn, 38 | get_down_block, 39 | ) 40 | from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel 41 | 42 | 43 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 44 | 45 | 46 | @dataclass 47 | class ControlNetOutput(BaseOutput): 48 | """ 49 | The output of [`ControlNetModel`]. 50 | 51 | Args: 52 | down_block_res_samples (`tuple[torch.Tensor]`): 53 | A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should 54 | be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be 55 | used to condition the original UNet's downsampling activations. 56 | mid_down_block_re_sample (`torch.Tensor`): 57 | The activation of the midde block (the lowest sample resolution). Each tensor should be of shape 58 | `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. 59 | Output can be used to condition the original UNet's middle block activation. 60 | """ 61 | 62 | down_block_res_samples: Tuple[torch.Tensor] 63 | mid_block_res_sample: torch.Tensor 64 | 65 | 66 | class ControlNetConditioningEmbedding(nn.Module): 67 | """ 68 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 69 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 70 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 71 | convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 72 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 73 | model) to encode image-space conditions ... into feature maps ..." 74 | """ 75 | 76 | def __init__( 77 | self, 78 | conditioning_embedding_channels: int, 79 | conditioning_channels: int = 3, 80 | block_out_channels: Tuple[int, ...] 
= (16, 32, 96, 256), 81 | ): 82 | super().__init__() 83 | 84 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 85 | 86 | self.blocks = nn.ModuleList([]) 87 | 88 | for i in range(len(block_out_channels) - 1): 89 | channel_in = block_out_channels[i] 90 | channel_out = block_out_channels[i + 1] 91 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 92 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 93 | 94 | self.conv_out = zero_module( 95 | nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 96 | ) 97 | 98 | def forward(self, conditioning): 99 | embedding = self.conv_in(conditioning) 100 | embedding = F.silu(embedding) 101 | 102 | for block in self.blocks: 103 | embedding = block(embedding) 104 | embedding = F.silu(embedding) 105 | 106 | embedding = self.conv_out(embedding) 107 | 108 | return embedding 109 | 110 | 111 | class ControlNetConditioningEmbeddingNodown(nn.Module): 112 | """ 113 | Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN 114 | [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized 115 | training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the 116 | convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides 117 | (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full 118 | model) to encode image-space conditions ... into feature maps ..." 119 | """ 120 | 121 | def __init__( 122 | self, 123 | conditioning_embedding_channels: int, 124 | conditioning_channels: int = 3, 125 | block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), 126 | ): 127 | super().__init__() 128 | 129 | # self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=1, padding=0) 130 | 131 | # self.blocks = nn.ModuleList([]) 132 | 133 | # for i in range(len(block_out_channels) - 1): 134 | # channel_in = block_out_channels[i] 135 | # channel_out = block_out_channels[i + 1] 136 | # self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0)) 137 | # self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, stride=1)) 138 | 139 | # self.conv_out = zero_module( 140 | # nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=1, padding=0) 141 | # ) 142 | 143 | self.position_encoder = nn.Sequential( 144 | nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0), 145 | nn.ReLU(), 146 | # zero_module(nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0)), 147 | nn.Conv2d(conditioning_embedding_channels, conditioning_embedding_channels, kernel_size=1, stride=1, padding=0), 148 | ) 149 | 150 | def forward(self, conditioning): 151 | # embedding = self.conv_in(conditioning) 152 | # embedding = F.silu(embedding) 153 | 154 | # for block in self.blocks: 155 | # embedding = block(embedding) 156 | # embedding = F.silu(embedding) 157 | 158 | # embedding = self.conv_out(embedding) 159 | embedding = self.position_encoder(conditioning) 160 | 161 | return embedding 162 | 163 | 164 | class ControlNetModel1x1(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): 165 | """ 166 | A ControlNet model. 
167 | 168 | Args: 169 | in_channels (`int`, defaults to 4): 170 | The number of channels in the input sample. 171 | flip_sin_to_cos (`bool`, defaults to `True`): 172 | Whether to flip the sin to cos in the time embedding. 173 | freq_shift (`int`, defaults to 0): 174 | The frequency shift to apply to the time embedding. 175 | down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 176 | The tuple of downsample blocks to use. 177 | only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): 178 | block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): 179 | The tuple of output channels for each block. 180 | layers_per_block (`int`, defaults to 2): 181 | The number of layers per block. 182 | downsample_padding (`int`, defaults to 1): 183 | The padding to use for the downsampling convolution. 184 | mid_block_scale_factor (`float`, defaults to 1): 185 | The scale factor to use for the mid block. 186 | act_fn (`str`, defaults to "silu"): 187 | The activation function to use. 188 | norm_num_groups (`int`, *optional*, defaults to 32): 189 | The number of groups to use for the normalization. If None, normalization and activation layers is skipped 190 | in post-processing. 191 | norm_eps (`float`, defaults to 1e-5): 192 | The epsilon to use for the normalization. 193 | cross_attention_dim (`int`, defaults to 1280): 194 | The dimension of the cross attention features. 195 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): 196 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 197 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 198 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 199 | encoder_hid_dim (`int`, *optional*, defaults to None): 200 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 201 | dimension to `cross_attention_dim`. 202 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 203 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 204 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 205 | attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): 206 | The dimension of the attention heads. 207 | use_linear_projection (`bool`, defaults to `False`): 208 | class_embed_type (`str`, *optional*, defaults to `None`): 209 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, 210 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 211 | addition_embed_type (`str`, *optional*, defaults to `None`): 212 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 213 | "text". "text" will use the `TextTimeEmbedding` layer. 214 | num_class_embeds (`int`, *optional*, defaults to 0): 215 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 216 | class conditioning with `class_embed_type` equal to `None`. 217 | upcast_attention (`bool`, defaults to `False`): 218 | resnet_time_scale_shift (`str`, defaults to `"default"`): 219 | Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. 
220 | projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): 221 | The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when 222 | `class_embed_type="projection"`. 223 | controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): 224 | The channel order of conditional image. Will convert to `rgb` if it's `bgr`. 225 | conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): 226 | The tuple of output channel for each block in the `conditioning_embedding` layer. 227 | global_pool_conditions (`bool`, defaults to `False`): 228 | TODO(Patrick) - unused parameter. 229 | addition_embed_type_num_heads (`int`, defaults to 64): 230 | The number of heads to use for the `TextTimeEmbedding` layer. 231 | """ 232 | 233 | _supports_gradient_checkpointing = True 234 | 235 | @register_to_config 236 | def __init__( 237 | self, 238 | in_channels: int = 4, 239 | conditioning_channels: int = 3, 240 | flip_sin_to_cos: bool = True, 241 | freq_shift: int = 0, 242 | down_block_types: Tuple[str, ...] = ( 243 | "CrossAttnDownBlock2D", 244 | "CrossAttnDownBlock2D", 245 | "CrossAttnDownBlock2D", 246 | "DownBlock2D", 247 | ), 248 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 249 | only_cross_attention: Union[bool, Tuple[bool]] = False, 250 | block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), 251 | layers_per_block: int = 2, 252 | downsample_padding: int = 1, 253 | mid_block_scale_factor: float = 1, 254 | act_fn: str = "silu", 255 | norm_num_groups: Optional[int] = 32, 256 | norm_eps: float = 1e-5, 257 | cross_attention_dim: int = 1280, 258 | transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, 259 | encoder_hid_dim: Optional[int] = None, 260 | encoder_hid_dim_type: Optional[str] = None, 261 | attention_head_dim: Union[int, Tuple[int, ...]] = 8, 262 | num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, 263 | use_linear_projection: bool = False, 264 | class_embed_type: Optional[str] = None, 265 | addition_embed_type: Optional[str] = None, 266 | addition_time_embed_dim: Optional[int] = None, 267 | num_class_embeds: Optional[int] = None, 268 | upcast_attention: bool = False, 269 | resnet_time_scale_shift: str = "default", 270 | projection_class_embeddings_input_dim: Optional[int] = None, 271 | controlnet_conditioning_channel_order: str = "rgb", 272 | conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), 273 | global_pool_conditions: bool = False, 274 | addition_embed_type_num_heads: int = 64, 275 | ): 276 | super().__init__() 277 | 278 | # If `num_attention_heads` is not defined (which is the case for most models) 279 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 280 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 281 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 282 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 283 | # which is why we correct for the naming here. 284 | num_attention_heads = num_attention_heads or attention_head_dim 285 | 286 | # Check inputs 287 | if len(block_out_channels) != len(down_block_types): 288 | raise ValueError( 289 | f"Must provide the same number of `block_out_channels` as `down_block_types`. 
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 290 | ) 291 | 292 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 293 | raise ValueError( 294 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 295 | ) 296 | 297 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 298 | raise ValueError( 299 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 300 | ) 301 | 302 | if isinstance(transformer_layers_per_block, int): 303 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 304 | 305 | # input 306 | conv_in_kernel = 3 307 | conv_in_padding = (conv_in_kernel - 1) // 2 308 | self.conv_in = nn.Conv2d( 309 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 310 | ) 311 | 312 | # time 313 | time_embed_dim = block_out_channels[0] * 4 314 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 315 | timestep_input_dim = block_out_channels[0] 316 | self.time_embedding = TimestepEmbedding( 317 | timestep_input_dim, 318 | time_embed_dim, 319 | act_fn=act_fn, 320 | ) 321 | 322 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 323 | encoder_hid_dim_type = "text_proj" 324 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 325 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") 326 | 327 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 328 | raise ValueError( 329 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 330 | ) 331 | 332 | if encoder_hid_dim_type == "text_proj": 333 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 334 | elif encoder_hid_dim_type == "text_image_proj": 335 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 336 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 337 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 338 | self.encoder_hid_proj = TextImageProjection( 339 | text_embed_dim=encoder_hid_dim, 340 | image_embed_dim=cross_attention_dim, 341 | cross_attention_dim=cross_attention_dim, 342 | ) 343 | 344 | elif encoder_hid_dim_type is not None: 345 | raise ValueError( 346 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
347 | ) 348 | else: 349 | self.encoder_hid_proj = None 350 | 351 | # class embedding 352 | if class_embed_type is None and num_class_embeds is not None: 353 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 354 | elif class_embed_type == "timestep": 355 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 356 | elif class_embed_type == "identity": 357 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 358 | elif class_embed_type == "projection": 359 | if projection_class_embeddings_input_dim is None: 360 | raise ValueError( 361 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 362 | ) 363 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 364 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 365 | # 2. it projects from an arbitrary input dimension. 366 | # 367 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 368 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 369 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 370 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 371 | else: 372 | self.class_embedding = None 373 | 374 | if addition_embed_type == "text": 375 | if encoder_hid_dim is not None: 376 | text_time_embedding_from_dim = encoder_hid_dim 377 | else: 378 | text_time_embedding_from_dim = cross_attention_dim 379 | 380 | self.add_embedding = TextTimeEmbedding( 381 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads 382 | ) 383 | elif addition_embed_type == "text_image": 384 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much 385 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 386 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 387 | self.add_embedding = TextImageTimeEmbedding( 388 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim 389 | ) 390 | elif addition_embed_type == "text_time": 391 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 392 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 393 | 394 | elif addition_embed_type is not None: 395 | raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") 396 | 397 | # control net conditioning embedding 398 | # self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 399 | self.controlnet_cond_embedding = ControlNetConditioningEmbeddingNodown( 400 | conditioning_embedding_channels=block_out_channels[0], 401 | block_out_channels=conditioning_embedding_out_channels, 402 | conditioning_channels=conditioning_channels, 403 | ) 404 | 405 | self.down_blocks = nn.ModuleList([]) 406 | self.controlnet_down_blocks = nn.ModuleList([]) 407 | 408 | if isinstance(only_cross_attention, bool): 409 | only_cross_attention = [only_cross_attention] * len(down_block_types) 410 | 411 | if isinstance(attention_head_dim, int): 412 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 413 | 414 | if isinstance(num_attention_heads, int): 415 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 416 | 417 | # down 418 | output_channel = block_out_channels[0] 419 | 420 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 421 | controlnet_block = zero_module(controlnet_block) 422 | self.controlnet_down_blocks.append(controlnet_block) 423 | 424 | for i, down_block_type in enumerate(down_block_types): 425 | input_channel = output_channel 426 | output_channel = block_out_channels[i] 427 | is_final_block = i == len(block_out_channels) - 1 428 | 429 | down_block = get_down_block( 430 | down_block_type, 431 | num_layers=layers_per_block, 432 | transformer_layers_per_block=transformer_layers_per_block[i], 433 | in_channels=input_channel, 434 | out_channels=output_channel, 435 | temb_channels=time_embed_dim, 436 | add_downsample=not is_final_block, 437 | resnet_eps=norm_eps, 438 | resnet_act_fn=act_fn, 439 | resnet_groups=norm_num_groups, 440 | cross_attention_dim=cross_attention_dim, 441 | num_attention_heads=num_attention_heads[i], 442 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 443 | downsample_padding=downsample_padding, 444 | use_linear_projection=use_linear_projection, 445 | only_cross_attention=only_cross_attention[i], 446 | upcast_attention=upcast_attention, 447 | resnet_time_scale_shift=resnet_time_scale_shift, 448 | ) 449 | self.down_blocks.append(down_block) 450 | 451 | for _ in range(layers_per_block): 452 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 453 | controlnet_block = zero_module(controlnet_block) 454 | self.controlnet_down_blocks.append(controlnet_block) 455 | 456 | if not is_final_block: 457 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 458 | controlnet_block = zero_module(controlnet_block) 459 | self.controlnet_down_blocks.append(controlnet_block) 460 | 461 | # mid 462 | mid_block_channel = 
block_out_channels[-1] 463 | 464 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 465 | controlnet_block = zero_module(controlnet_block) 466 | self.controlnet_mid_block = controlnet_block 467 | 468 | if mid_block_type == "UNetMidBlock2DCrossAttn": 469 | self.mid_block = UNetMidBlock2DCrossAttn( 470 | transformer_layers_per_block=transformer_layers_per_block[-1], 471 | in_channels=mid_block_channel, 472 | temb_channels=time_embed_dim, 473 | resnet_eps=norm_eps, 474 | resnet_act_fn=act_fn, 475 | output_scale_factor=mid_block_scale_factor, 476 | resnet_time_scale_shift=resnet_time_scale_shift, 477 | cross_attention_dim=cross_attention_dim, 478 | num_attention_heads=num_attention_heads[-1], 479 | resnet_groups=norm_num_groups, 480 | use_linear_projection=use_linear_projection, 481 | upcast_attention=upcast_attention, 482 | ) 483 | elif mid_block_type == "UNetMidBlock2D": 484 | self.mid_block = UNetMidBlock2D( 485 | in_channels=block_out_channels[-1], 486 | temb_channels=time_embed_dim, 487 | num_layers=0, 488 | resnet_eps=norm_eps, 489 | resnet_act_fn=act_fn, 490 | output_scale_factor=mid_block_scale_factor, 491 | resnet_groups=norm_num_groups, 492 | resnet_time_scale_shift=resnet_time_scale_shift, 493 | add_attention=False, 494 | ) 495 | else: 496 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 497 | 498 | @classmethod 499 | def from_unet( 500 | cls, 501 | unet: UNet2DConditionModel, 502 | controlnet_conditioning_channel_order: str = "rgb", 503 | conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), 504 | load_weights_from_unet: bool = True, 505 | conditioning_channels: int = 3, 506 | ): 507 | r""" 508 | Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. 509 | 510 | Parameters: 511 | unet (`UNet2DConditionModel`): 512 | The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied 513 | where applicable. 
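An illustrative way to call `from_unet` (editorial sketch; the checkpoint path and the conditioning channel count below are placeholders, not values confirmed by the repository):

```python
# Sketch only: wrap a pretrained Stable Diffusion 2.1 UNet into ControlNetModel1x1.
from diffusers import UNet2DConditionModel
from models.controlnet1x1 import ControlNetModel1x1

unet = UNet2DConditionModel.from_pretrained(
    "ckp/stable-diffusion-2-1",      # placeholder path; see the README's checkpoint layout
    subfolder="unet",
)
controlnet = ControlNetModel1x1.from_unet(
    unet,
    conditioning_channels=256,       # placeholder channel count for the encoded condition
    load_weights_from_unet=True,     # copies conv_in, time embedding, down and mid blocks
)
```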
514 | """ 515 | transformer_layers_per_block = ( 516 | unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 517 | ) 518 | encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None 519 | encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None 520 | addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None 521 | addition_time_embed_dim = ( 522 | unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None 523 | ) 524 | 525 | controlnet = cls( 526 | encoder_hid_dim=encoder_hid_dim, 527 | encoder_hid_dim_type=encoder_hid_dim_type, 528 | addition_embed_type=addition_embed_type, 529 | addition_time_embed_dim=addition_time_embed_dim, 530 | transformer_layers_per_block=transformer_layers_per_block, 531 | in_channels=unet.config.in_channels, 532 | flip_sin_to_cos=unet.config.flip_sin_to_cos, 533 | freq_shift=unet.config.freq_shift, 534 | down_block_types=unet.config.down_block_types, 535 | only_cross_attention=unet.config.only_cross_attention, 536 | block_out_channels=unet.config.block_out_channels, 537 | layers_per_block=unet.config.layers_per_block, 538 | downsample_padding=unet.config.downsample_padding, 539 | mid_block_scale_factor=unet.config.mid_block_scale_factor, 540 | act_fn=unet.config.act_fn, 541 | norm_num_groups=unet.config.norm_num_groups, 542 | norm_eps=unet.config.norm_eps, 543 | cross_attention_dim=unet.config.cross_attention_dim, 544 | attention_head_dim=unet.config.attention_head_dim, 545 | num_attention_heads=unet.config.num_attention_heads, 546 | use_linear_projection=unet.config.use_linear_projection, 547 | class_embed_type=unet.config.class_embed_type, 548 | num_class_embeds=unet.config.num_class_embeds, 549 | upcast_attention=unet.config.upcast_attention, 550 | resnet_time_scale_shift=unet.config.resnet_time_scale_shift, 551 | projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, 552 | mid_block_type=unet.config.mid_block_type, 553 | controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, 554 | conditioning_embedding_out_channels=conditioning_embedding_out_channels, 555 | conditioning_channels=conditioning_channels, 556 | ) 557 | 558 | if load_weights_from_unet: 559 | controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) 560 | controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) 561 | controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) 562 | 563 | if controlnet.class_embedding: 564 | controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) 565 | 566 | controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) 567 | controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) 568 | 569 | return controlnet 570 | 571 | @property 572 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 573 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 574 | r""" 575 | Returns: 576 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 577 | indexed by its weight name. 
578 | """ 579 | # set recursively 580 | processors = {} 581 | 582 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 583 | if hasattr(module, "get_processor"): 584 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 585 | 586 | for sub_name, child in module.named_children(): 587 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 588 | 589 | return processors 590 | 591 | for name, module in self.named_children(): 592 | fn_recursive_add_processors(name, module, processors) 593 | 594 | return processors 595 | 596 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 597 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 598 | r""" 599 | Sets the attention processor to use to compute attention. 600 | 601 | Parameters: 602 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 603 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 604 | for **all** `Attention` layers. 605 | 606 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 607 | processor. This is strongly recommended when setting trainable attention processors. 608 | 609 | """ 610 | count = len(self.attn_processors.keys()) 611 | 612 | if isinstance(processor, dict) and len(processor) != count: 613 | raise ValueError( 614 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 615 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 616 | ) 617 | 618 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 619 | if hasattr(module, "set_processor"): 620 | if not isinstance(processor, dict): 621 | module.set_processor(processor) 622 | else: 623 | module.set_processor(processor.pop(f"{name}.processor")) 624 | 625 | for sub_name, child in module.named_children(): 626 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 627 | 628 | for name, module in self.named_children(): 629 | fn_recursive_attn_processor(name, module, processor) 630 | 631 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor 632 | def set_default_attn_processor(self): 633 | """ 634 | Disables custom attention processors and sets the default attention implementation. 635 | """ 636 | if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 637 | processor = AttnAddedKVProcessor() 638 | elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): 639 | processor = AttnProcessor() 640 | else: 641 | raise ValueError( 642 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 643 | ) 644 | 645 | self.set_attn_processor(processor) 646 | 647 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice 648 | def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: 649 | r""" 650 | Enable sliced attention computation. 651 | 652 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 653 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 
654 | 655 | Args: 656 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 657 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 658 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 659 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 660 | must be a multiple of `slice_size`. 661 | """ 662 | sliceable_head_dims = [] 663 | 664 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 665 | if hasattr(module, "set_attention_slice"): 666 | sliceable_head_dims.append(module.sliceable_head_dim) 667 | 668 | for child in module.children(): 669 | fn_recursive_retrieve_sliceable_dims(child) 670 | 671 | # retrieve number of attention layers 672 | for module in self.children(): 673 | fn_recursive_retrieve_sliceable_dims(module) 674 | 675 | num_sliceable_layers = len(sliceable_head_dims) 676 | 677 | if slice_size == "auto": 678 | # half the attention head size is usually a good trade-off between 679 | # speed and memory 680 | slice_size = [dim // 2 for dim in sliceable_head_dims] 681 | elif slice_size == "max": 682 | # make smallest slice possible 683 | slice_size = num_sliceable_layers * [1] 684 | 685 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 686 | 687 | if len(slice_size) != len(sliceable_head_dims): 688 | raise ValueError( 689 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 690 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 691 | ) 692 | 693 | for i in range(len(slice_size)): 694 | size = slice_size[i] 695 | dim = sliceable_head_dims[i] 696 | if size is not None and size > dim: 697 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 698 | 699 | # Recursively walk through all the children. 700 | # Any children which exposes the set_attention_slice method 701 | # gets the message 702 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 703 | if hasattr(module, "set_attention_slice"): 704 | module.set_attention_slice(slice_size.pop()) 705 | 706 | for child in module.children(): 707 | fn_recursive_set_attention_slice(child, slice_size) 708 | 709 | reversed_slice_size = list(reversed(slice_size)) 710 | for module in self.children(): 711 | fn_recursive_set_attention_slice(module, reversed_slice_size) 712 | 713 | def _set_gradient_checkpointing(self, module, value: bool = False) -> None: 714 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): 715 | module.gradient_checkpointing = value 716 | 717 | def forward( 718 | self, 719 | sample: torch.FloatTensor, 720 | timestep: Union[torch.Tensor, float, int], 721 | encoder_hidden_states: torch.Tensor, 722 | controlnet_cond: torch.FloatTensor, 723 | conditioning_scale: float = 1.0, 724 | class_labels: Optional[torch.Tensor] = None, 725 | timestep_cond: Optional[torch.Tensor] = None, 726 | attention_mask: Optional[torch.Tensor] = None, 727 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 728 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 729 | guess_mode: bool = False, 730 | return_dict: bool = True, 731 | ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: 732 | """ 733 | The [`ControlNetModel`] forward method. 
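A self-contained sketch of one forward pass (editorial; random weights and made-up shapes). Because `ControlNetConditioningEmbeddingNodown` uses only 1x1 convolutions with stride 1, the conditioning appears to be expected at latent resolution so it can be added directly to the `conv_in` output:

```python
import torch
from models.controlnet1x1 import ControlNetModel1x1

controlnet = ControlNetModel1x1(conditioning_channels=256)   # other settings left at defaults
latents = torch.randn(1, 4, 56, 100)     # noisy latents for a 448x800 frame (H/8, W/8)
cond    = torch.randn(1, 256, 56, 100)   # encoded condition, already at latent resolution
prompt  = torch.randn(1, 77, 1280)       # stand-in text embeddings (default cross_attention_dim)

out = controlnet(
    latents,
    timestep=10,
    encoder_hidden_states=prompt,
    controlnet_cond=cond,
    conditioning_scale=1.0,
)
print(len(out.down_block_res_samples), out.mid_block_res_sample.shape)
```

The returned residuals are the tensors that the multiview UNet later in this dump accepts through its `down_block_additional_residuals` and `mid_block_additional_residual` arguments.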
734 | 735 | Args: 736 | sample (`torch.FloatTensor`): 737 | The noisy input tensor. 738 | timestep (`Union[torch.Tensor, float, int]`): 739 | The number of timesteps to denoise an input. 740 | encoder_hidden_states (`torch.Tensor`): 741 | The encoder hidden states. 742 | controlnet_cond (`torch.FloatTensor`): 743 | The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. 744 | conditioning_scale (`float`, defaults to `1.0`): 745 | The scale factor for ControlNet outputs. 746 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 747 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 748 | timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): 749 | Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the 750 | timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep 751 | embeddings. 752 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 753 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 754 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 755 | negative values to the attention scores corresponding to "discard" tokens. 756 | added_cond_kwargs (`dict`): 757 | Additional conditions for the Stable Diffusion XL UNet. 758 | cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): 759 | A kwargs dictionary that if specified is passed along to the `AttnProcessor`. 760 | guess_mode (`bool`, defaults to `False`): 761 | In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if 762 | you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. 763 | return_dict (`bool`, defaults to `True`): 764 | Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. 765 | 766 | Returns: 767 | [`~models.controlnet.ControlNetOutput`] **or** `tuple`: 768 | If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is 769 | returned where the first element is the sample tensor. 770 | """ 771 | # check channel order 772 | channel_order = self.config.controlnet_conditioning_channel_order 773 | 774 | if channel_order == "rgb": 775 | # in rgb order by default 776 | ... 777 | elif channel_order == "bgr": 778 | controlnet_cond = torch.flip(controlnet_cond, dims=[1]) 779 | else: 780 | raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") 781 | 782 | # prepare attention_mask 783 | if attention_mask is not None: 784 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 785 | attention_mask = attention_mask.unsqueeze(1) 786 | 787 | # 1. time 788 | timesteps = timestep 789 | if not torch.is_tensor(timesteps): 790 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 791 | # This would be a good case for the `match` statement (Python 3.10+) 792 | is_mps = sample.device.type == "mps" 793 | if isinstance(timestep, float): 794 | dtype = torch.float32 if is_mps else torch.float64 795 | else: 796 | dtype = torch.int32 if is_mps else torch.int64 797 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 798 | elif len(timesteps.shape) == 0: 799 | timesteps = timesteps[None].to(sample.device) 800 | 801 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 802 | timesteps = timesteps.expand(sample.shape[0]) 803 | 804 | t_emb = self.time_proj(timesteps) 805 | 806 | # timesteps does not contain any weights and will always return f32 tensors 807 | # but time_embedding might actually be running in fp16. so we need to cast here. 808 | # there might be better ways to encapsulate this. 809 | t_emb = t_emb.to(dtype=sample.dtype) 810 | 811 | emb = self.time_embedding(t_emb, timestep_cond) 812 | aug_emb = None 813 | 814 | if self.class_embedding is not None: 815 | if class_labels is None: 816 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 817 | 818 | if self.config.class_embed_type == "timestep": 819 | class_labels = self.time_proj(class_labels) 820 | 821 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 822 | emb = emb + class_emb 823 | 824 | if self.config.addition_embed_type is not None: 825 | if self.config.addition_embed_type == "text": 826 | aug_emb = self.add_embedding(encoder_hidden_states) 827 | 828 | elif self.config.addition_embed_type == "text_time": 829 | if "text_embeds" not in added_cond_kwargs: 830 | raise ValueError( 831 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 832 | ) 833 | text_embeds = added_cond_kwargs.get("text_embeds") 834 | if "time_ids" not in added_cond_kwargs: 835 | raise ValueError( 836 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 837 | ) 838 | time_ids = added_cond_kwargs.get("time_ids") 839 | time_embeds = self.add_time_proj(time_ids.flatten()) 840 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 841 | 842 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 843 | add_embeds = add_embeds.to(emb.dtype) 844 | aug_emb = self.add_embedding(add_embeds) 845 | 846 | emb = emb + aug_emb if aug_emb is not None else emb 847 | # print(666, controlnet_cond.shape, sample.shape) 848 | # 2. pre-process 849 | sample = self.conv_in(sample) 850 | 851 | # controlnet_cond = torch.zeros([1, 256, 56, 100]).to(sample) 852 | 853 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 854 | # print(controlnet_cond.shape, sample.shape) 855 | # import ipdb; ipdb.set_trace() 856 | # # ipdb> controlnet_cond.shape 857 | # # torch.Size([6, 320, 56, 100]) 858 | # # ipdb> sample.shape 859 | # # torch.Size([12, 320, 56, 100]) 860 | # controlnet_cond = torch.cat([controlnet_cond] * 2) 861 | 862 | sample = sample + controlnet_cond 863 | # 666 torch.Size([1, 256, 448, 800]) torch.Size([1, 4, 56, 100]) 864 | # torch.Size([1, 320, 56, 100]) torch.Size([1, 320, 56, 100]) 865 | # 3. 
down 866 | down_block_res_samples = (sample,) 867 | for downsample_block in self.down_blocks: 868 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 869 | sample, res_samples = downsample_block( 870 | hidden_states=sample, 871 | temb=emb, 872 | encoder_hidden_states=encoder_hidden_states, 873 | attention_mask=attention_mask, 874 | cross_attention_kwargs=cross_attention_kwargs, 875 | ) 876 | else: 877 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 878 | 879 | down_block_res_samples += res_samples 880 | 881 | # 4. mid 882 | if self.mid_block is not None: 883 | if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: 884 | sample = self.mid_block( 885 | sample, 886 | emb, 887 | encoder_hidden_states=encoder_hidden_states, 888 | attention_mask=attention_mask, 889 | cross_attention_kwargs=cross_attention_kwargs, 890 | ) 891 | else: 892 | sample = self.mid_block(sample, emb) 893 | 894 | # 5. Control net blocks 895 | 896 | controlnet_down_block_res_samples = () 897 | 898 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 899 | down_block_res_sample = controlnet_block(down_block_res_sample) 900 | controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) 901 | 902 | down_block_res_samples = controlnet_down_block_res_samples 903 | 904 | mid_block_res_sample = self.controlnet_mid_block(sample) 905 | 906 | # 6. scaling 907 | if guess_mode and not self.config.global_pool_conditions: 908 | scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 909 | scales = scales * conditioning_scale 910 | down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] 911 | mid_block_res_sample = mid_block_res_sample * scales[-1] # last one 912 | else: 913 | down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] 914 | mid_block_res_sample = mid_block_res_sample * conditioning_scale 915 | 916 | if self.config.global_pool_conditions: 917 | down_block_res_samples = [ 918 | torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples 919 | ] 920 | mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) 921 | 922 | if not return_dict: 923 | return (down_block_res_samples, mid_block_res_sample) 924 | 925 | return ControlNetOutput( 926 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample 927 | ) 928 | 929 | 930 | def zero_module(module): 931 | for p in module.parameters(): 932 | nn.init.zeros_(p) 933 | return module 934 | -------------------------------------------------------------------------------- /models/unet_2d_condition_multiview.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from functools import partial 16 | from typing import Any, Dict, List, Optional, Tuple, Union 17 | import logging 18 | import copy 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.utils.checkpoint 24 | 25 | from diffusers.configuration_utils import register_to_config 26 | from diffusers.models.unet_2d_condition import ( 27 | UNet2DConditionModel, 28 | UNet2DConditionOutput, 29 | ) 30 | from diffusers.models.unet_2d_blocks import ( 31 | CrossAttnDownBlock2D, 32 | CrossAttnUpBlock2D, 33 | DownBlock2D, 34 | UpBlock2D, 35 | ) 36 | from diffusers.models.attention import BasicTransformerBlock 37 | 38 | # from ..misc.common import _get_module, _set_module 39 | from .blocks import ( 40 | BasicMultiviewTransformerBlock, 41 | ) 42 | 43 | # take from torch.ao.quantization.fuse_modules 44 | # Generalization of getattr 45 | def _get_module(model, submodule_key): 46 | tokens = submodule_key.split('.') 47 | cur_mod = model 48 | for s in tokens: 49 | cur_mod = getattr(cur_mod, s) 50 | return cur_mod 51 | 52 | 53 | # Generalization of setattr 54 | def _set_module(model, submodule_key, module): 55 | tokens = submodule_key.split('.') 56 | sub_tokens = tokens[:-1] 57 | cur_mod = model 58 | for s in sub_tokens: 59 | cur_mod = getattr(cur_mod, s) 60 | 61 | setattr(cur_mod, tokens[-1], module) 62 | 63 | 64 | class UNet2DConditionModelMultiview(UNet2DConditionModel): 65 | r""" 66 | UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep 67 | and returns sample shaped output. 68 | 69 | This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library 70 | implements for all the models (such as downloading or saving, etc.) 71 | 72 | Parameters: 73 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 74 | Height and width of input/output sample. 75 | in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. 76 | out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. 77 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 78 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 79 | Whether to flip the sin to cos in the time embedding. 80 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 81 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 82 | The tuple of downsample blocks to use. 83 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 84 | The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the 85 | mid block layer if `None`. 86 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): 87 | The tuple of upsample blocks to use. 88 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 89 | Whether to include self-attention in the basic transformer blocks, see 90 | [`~models.attention.BasicTransformerBlock`]. 
91 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 92 | The tuple of output channels for each block. 93 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 94 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 95 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 96 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 97 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 98 | If `None`, it will skip the normalization and activation layers in post-processing 99 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 100 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 101 | The dimension of the cross attention features. 102 | encoder_hid_dim (`int`, *optional*, defaults to None): 103 | If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`. 104 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 105 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 106 | for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. 107 | class_embed_type (`str`, *optional*, defaults to None): 108 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 109 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 110 | addition_embed_type (`str`, *optional*, defaults to None): 111 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 112 | "text". "text" will use the `TextTimeEmbedding` layer. 113 | num_class_embeds (`int`, *optional*, defaults to None): 114 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 115 | class conditioning with `class_embed_type` equal to `None`. 116 | time_embedding_type (`str`, *optional*, default to `positional`): 117 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 118 | time_embedding_dim (`int`, *optional*, default to `None`): 119 | An optional override for the dimension of the projected time embedding. 120 | time_embedding_act_fn (`str`, *optional*, default to `None`): 121 | Optional activation function to use on the time embeddings only one time before they as passed to the rest 122 | of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. 123 | timestep_post_act (`str, *optional*, default to `None`): 124 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 125 | time_cond_proj_dim (`int`, *optional*, default to `None`): 126 | The dimension of `cond_proj` layer in timestep embedding. 127 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 128 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 129 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when 130 | using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 
131 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time 132 | embeddings with the class embeddings. 133 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): 134 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If 135 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the 136 | `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will 137 | default to `False`. 138 | """ 139 | 140 | _supports_gradient_checkpointing = True 141 | _WARN_ONCE = 0 142 | 143 | @register_to_config 144 | def __init__( 145 | self, 146 | sample_size: Optional[int] = None, 147 | in_channels: int = 4, 148 | out_channels: int = 4, 149 | center_input_sample: bool = False, 150 | flip_sin_to_cos: bool = True, 151 | freq_shift: int = 0, 152 | down_block_types: Tuple[str] = ( 153 | "CrossAttnDownBlock2D", 154 | "CrossAttnDownBlock2D", 155 | "CrossAttnDownBlock2D", 156 | "DownBlock2D", 157 | ), 158 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 159 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 160 | only_cross_attention: Union[bool, Tuple[bool]] = False, 161 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 162 | layers_per_block: Union[int, Tuple[int]] = 2, 163 | downsample_padding: int = 1, 164 | mid_block_scale_factor: float = 1, 165 | dropout: float = 0.0, 166 | act_fn: str = "silu", 167 | norm_num_groups: Optional[int] = 32, 168 | norm_eps: float = 1e-5, 169 | cross_attention_dim: Union[int, Tuple[int]] = 1280, 170 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 171 | reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, 172 | encoder_hid_dim: Optional[int] = None, 173 | encoder_hid_dim_type: Optional[str] = None, 174 | attention_head_dim: Union[int, Tuple[int]] = 8, 175 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 176 | dual_cross_attention: bool = False, 177 | use_linear_projection: bool = False, 178 | class_embed_type: Optional[str] = None, 179 | addition_embed_type: Optional[str] = None, 180 | addition_time_embed_dim: Optional[int] = None, 181 | num_class_embeds: Optional[int] = None, 182 | upcast_attention: bool = False, 183 | resnet_time_scale_shift: str = "default", 184 | resnet_skip_time_act: bool = False, 185 | resnet_out_scale_factor: int = 1.0, 186 | time_embedding_type: str = "positional", 187 | time_embedding_dim: Optional[int] = None, 188 | time_embedding_act_fn: Optional[str] = None, 189 | timestep_post_act: Optional[str] = None, 190 | time_cond_proj_dim: Optional[int] = None, 191 | conv_in_kernel: int = 3, 192 | conv_out_kernel: int = 3, 193 | projection_class_embeddings_input_dim: Optional[int] = None, 194 | class_embeddings_concat: bool = False, 195 | mid_block_only_cross_attention: Optional[bool] = None, 196 | cross_attention_norm: Optional[str] = None, 197 | attention_type: str = "default", 198 | addition_embed_type_num_heads=64, 199 | # parameter added, we should keep all above (do not use kwargs) 200 | trainable_state="only_new", 201 | neighboring_view_pair: Optional[dict] = None, 202 | neighboring_attn_type: str = "add", 203 | zero_module_type: str = "zero_linear", 204 | crossview_attn_type: str = "basic", 205 | img_size: Optional[Tuple[int, int]] = None, 206 | ): 207 | super().__init__( 208 | sample_size=sample_size, 
in_channels=in_channels, 209 | out_channels=out_channels, center_input_sample=center_input_sample, 210 | flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift, 211 | down_block_types=down_block_types, mid_block_type=mid_block_type, 212 | up_block_types=up_block_types, 213 | only_cross_attention=only_cross_attention, 214 | block_out_channels=block_out_channels, 215 | layers_per_block=layers_per_block, 216 | downsample_padding=downsample_padding, 217 | mid_block_scale_factor=mid_block_scale_factor, act_fn=act_fn, 218 | norm_num_groups=norm_num_groups, norm_eps=norm_eps, 219 | cross_attention_dim=cross_attention_dim, 220 | encoder_hid_dim=encoder_hid_dim, 221 | encoder_hid_dim_type=encoder_hid_dim_type, 222 | attention_head_dim=attention_head_dim, 223 | dual_cross_attention=dual_cross_attention, 224 | use_linear_projection=use_linear_projection, 225 | class_embed_type=class_embed_type, 226 | addition_embed_type=addition_embed_type, 227 | num_class_embeds=num_class_embeds, 228 | upcast_attention=upcast_attention, 229 | resnet_time_scale_shift=resnet_time_scale_shift, 230 | resnet_skip_time_act=resnet_skip_time_act, 231 | resnet_out_scale_factor=resnet_out_scale_factor, 232 | time_embedding_type=time_embedding_type, 233 | time_embedding_dim=time_embedding_dim, 234 | time_embedding_act_fn=time_embedding_act_fn, 235 | timestep_post_act=timestep_post_act, 236 | time_cond_proj_dim=time_cond_proj_dim, 237 | conv_in_kernel=conv_in_kernel, conv_out_kernel=conv_out_kernel, 238 | projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, 239 | class_embeddings_concat=class_embeddings_concat, 240 | mid_block_only_cross_attention=mid_block_only_cross_attention, 241 | cross_attention_norm=cross_attention_norm, 242 | addition_embed_type_num_heads=addition_embed_type_num_heads,) 243 | 244 | self.crossview_attn_type = crossview_attn_type 245 | self.img_size = [int(s) for s in img_size] \ 246 | if img_size is not None else None 247 | self._new_module = {} 248 | for name, mod in list(self.named_modules()): 249 | if isinstance(mod, BasicTransformerBlock): 250 | if crossview_attn_type == "basic": 251 | # import ipdb; ipdb.set_trace() 252 | 253 | # print(mod._args) 254 | # print("-"*100) 255 | _set_module(self, name, BasicMultiviewTransformerBlock( 256 | **mod._args, 257 | neighboring_view_pair=neighboring_view_pair, 258 | neighboring_attn_type=neighboring_attn_type, 259 | zero_module_type=zero_module_type, 260 | )) 261 | else: 262 | raise TypeError(f"Unknown attn type: {crossview_attn_type}") 263 | for k, v in _get_module(self, name).new_module.items(): 264 | self._new_module[f"{name}.{k}"] = v 265 | self.trainable_state = trainable_state 266 | 267 | @property 268 | def trainable_module(self) -> Dict[str, nn.Module]: 269 | if self.trainable_state == "all": 270 | return {self.__class__: self} 271 | elif self.trainable_state == "only_new": 272 | return self._new_module 273 | else: 274 | raise ValueError(f"Unknown trainable_state: {self.trainable_state}") 275 | 276 | @property 277 | def trainable_parameters(self) -> List[nn.Parameter]: 278 | params = [] 279 | for mod in self.trainable_module.values(): 280 | for param in mod.parameters(): 281 | params.append(param) 282 | return params 283 | 284 | def train(self, mode=True): 285 | if not isinstance(mode, bool): 286 | raise ValueError("training mode is expected to be boolean") 287 | # first, set all to false 288 | super().train(False) 289 | if mode: 290 | # ensure gradient_checkpointing is usable, set training = True 291 | for mod in self.modules(): 
292 | if getattr(mod, "gradient_checkpointing", False): 293 | mod.training = True 294 | # then, for some modules, we set according to `mode` 295 | self.training = False 296 | for mod in self.trainable_module.values(): 297 | if mod is self: 298 | super().train(mode) 299 | else: 300 | mod.train(mode) 301 | return self 302 | 303 | def enable_gradient_checkpointing(self, flag=None): 304 | """ 305 | Activates gradient checkpointing for the current model. 306 | 307 | Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint 308 | activations". 309 | """ 310 | # self.apply(partial(self._set_gradient_checkpointing, value=True)) 311 | mod_idx = -1 312 | for module in self.modules(): 313 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): 314 | mod_idx += 1 315 | if flag is not None and not flag[mod_idx]: 316 | logging.debug( 317 | f"[UNet2DConditionModelMultiview] " 318 | f"gradient_checkpointing skip [{module.__class__}]") 319 | continue 320 | logging.debug(f"[UNet2DConditionModelMultiview] set " 321 | f"[{module.__class__}] to gradient_checkpointing") 322 | module.gradient_checkpointing = True 323 | 324 | @classmethod 325 | def from_unet_2d_condition( 326 | cls, 327 | unet: UNet2DConditionModel, 328 | load_weights_from_unet: bool = True, 329 | # multivew 330 | **kwargs, 331 | ): 332 | r""" 333 | Instantiate Multiview unet class from UNet2DConditionModel. 334 | 335 | Parameters: 336 | unet (`UNet2DConditionModel`): 337 | UNet model which weights are copied to the ControlNet. Note that all configuration options are also 338 | copied where applicable. 339 | """ 340 | 341 | unet_2d_condition_multiview = cls( 342 | **unet.config, 343 | # multivew 344 | **kwargs, 345 | ) 346 | 347 | if load_weights_from_unet: 348 | missing_keys, unexpected_keys = unet_2d_condition_multiview.load_state_dict( 349 | unet.state_dict(), strict=False) 350 | logging.info( 351 | f"[UNet2DConditionModelMultiview] load pretrained with " 352 | f"missing_keys: {missing_keys}; " 353 | f"unexpected_keys: {unexpected_keys}") 354 | 355 | return unet_2d_condition_multiview 356 | 357 | def forward( 358 | self, 359 | sample: torch.FloatTensor, 360 | timestep: Union[torch.Tensor, float, int], 361 | encoder_hidden_states: torch.Tensor, 362 | class_labels: Optional[torch.Tensor] = None, 363 | timestep_cond: Optional[torch.Tensor] = None, 364 | attention_mask: Optional[torch.Tensor] = None, 365 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 366 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 367 | mid_block_additional_residual: Optional[torch.Tensor] = None, 368 | return_dict: bool = True, 369 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 370 | ) -> Union[UNet2DConditionOutput, Tuple]: 371 | r""" 372 | Args: 373 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 374 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 375 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 376 | return_dict (`bool`, *optional*, defaults to `True`): 377 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
378 | cross_attention_kwargs (`dict`, *optional*): 379 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 380 | `self.processor` in 381 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 382 | 383 | Returns: 384 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 385 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 386 | returning a tuple, the first element is the sample tensor. 387 | """ 388 | # TODO: actually, we do not change logic in forward 389 | 390 | # By default samples have to be AT least a multiple of the overall upsampling factor. 391 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 392 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 393 | # on the fly if necessary. 394 | default_overall_up_factor = 2**self.num_upsamplers 395 | 396 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 397 | forward_upsample_size = False 398 | upsample_size = None 399 | 400 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 401 | if self._WARN_ONCE == 0: 402 | logging.warning( 403 | "[UNet2DConditionModelMultiview] Forward upsample size to force interpolation output size.") 404 | self._WARN_ONCE = 1 405 | forward_upsample_size = True 406 | 407 | # prepare attention_mask 408 | if attention_mask is not None: 409 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 410 | attention_mask = attention_mask.unsqueeze(1) 411 | 412 | # 0. center input if necessary 413 | if self.config.center_input_sample: 414 | sample = 2 * sample - 1.0 415 | 416 | # 1. time 417 | timesteps = timestep 418 | if not torch.is_tensor(timesteps): 419 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 420 | # This would be a good case for the `match` statement (Python 3.10+) 421 | is_mps = sample.device.type == "mps" 422 | if isinstance(timestep, float): 423 | dtype = torch.float32 if is_mps else torch.float64 424 | else: 425 | dtype = torch.int32 if is_mps else torch.int64 426 | timesteps = torch.tensor( 427 | [timesteps], 428 | dtype=dtype, device=sample.device) 429 | elif len(timesteps.shape) == 0: 430 | timesteps = timesteps[None].to(sample.device) 431 | 432 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 433 | timesteps = timesteps.expand(sample.shape[0]) 434 | 435 | t_emb = self.time_proj(timesteps) 436 | 437 | # `Timesteps` does not contain any weights and will always return f32 tensors 438 | # but time_embedding might actually be running in fp16. so we need to cast here. 439 | # there might be better ways to encapsulate this. 440 | t_emb = t_emb.to(dtype=self.dtype) 441 | 442 | emb = self.time_embedding(t_emb, timestep_cond) 443 | 444 | if self.class_embedding is not None: 445 | if class_labels is None: 446 | raise ValueError( 447 | "class_labels should be provided when num_class_embeds > 0") 448 | 449 | if self.config.class_embed_type == "timestep": 450 | class_labels = self.time_proj(class_labels) 451 | 452 | # `Timesteps` does not contain any weights and will always return f32 tensors 453 | # there might be better ways to encapsulate this. 
454 | class_labels = class_labels.to(dtype=sample.dtype) 455 | 456 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 457 | 458 | if self.config.class_embeddings_concat: 459 | emb = torch.cat([emb, class_emb], dim=-1) 460 | else: 461 | emb = emb + class_emb 462 | 463 | if self.config.addition_embed_type == "text": 464 | aug_emb = self.add_embedding(encoder_hidden_states) 465 | emb = emb + aug_emb 466 | 467 | if self.time_embed_act is not None: 468 | emb = self.time_embed_act(emb) 469 | 470 | if self.encoder_hid_proj is not None: 471 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 472 | 473 | # 2. pre-process 474 | sample = self.conv_in(sample) 475 | 476 | # 3. down 477 | down_block_res_samples = (sample,) 478 | for downsample_block in self.down_blocks: 479 | if self.crossview_attn_type == 'epipolar': 480 | cross_attention_kwargs['out_size'] = sample.shape[-2:] 481 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 482 | sample, res_samples = downsample_block( 483 | hidden_states=sample, 484 | temb=emb, 485 | encoder_hidden_states=encoder_hidden_states, 486 | attention_mask=attention_mask, 487 | cross_attention_kwargs=copy.deepcopy(cross_attention_kwargs), 488 | ) 489 | else: 490 | sample, res_samples = downsample_block( 491 | hidden_states=sample, temb=emb) 492 | 493 | down_block_res_samples += res_samples 494 | 495 | if down_block_additional_residuals is not None: 496 | new_down_block_res_samples = () 497 | 498 | for down_block_res_sample, down_block_additional_residual in zip( 499 | down_block_res_samples, down_block_additional_residuals 500 | ): 501 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 502 | new_down_block_res_samples += (down_block_res_sample,) 503 | 504 | down_block_res_samples = new_down_block_res_samples 505 | 506 | # 4. mid 507 | if self.mid_block is not None: 508 | if self.crossview_attn_type == 'epipolar': 509 | cross_attention_kwargs['out_size'] = sample.shape[-2:] 510 | sample = self.mid_block( 511 | sample, 512 | emb, 513 | encoder_hidden_states=encoder_hidden_states, 514 | attention_mask=attention_mask, 515 | cross_attention_kwargs=copy.deepcopy(cross_attention_kwargs), 516 | ) 517 | 518 | if mid_block_additional_residual is not None: 519 | sample = sample + mid_block_additional_residual 520 | 521 | # 5. 
up 522 | for i, upsample_block in enumerate(self.up_blocks): 523 | is_final_block = i == len(self.up_blocks) - 1 524 | 525 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 526 | down_block_res_samples = down_block_res_samples[: -len( 527 | upsample_block.resnets)] 528 | 529 | # if we have not reached the final block and need to forward the 530 | # upsample size, we do it here 531 | if not is_final_block and forward_upsample_size: 532 | upsample_size = down_block_res_samples[-1].shape[2:] 533 | 534 | # import ipdb; ipdb.set_trace() 535 | # print(sample.shape) 536 | # print([x.shape for x in res_samples]) 537 | 538 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 539 | if self.crossview_attn_type == 'epipolar': 540 | cross_attention_kwargs['out_size'] = sample.shape[-2:] 541 | sample = upsample_block( 542 | hidden_states=sample, temb=emb, 543 | res_hidden_states_tuple=res_samples, 544 | encoder_hidden_states=encoder_hidden_states, 545 | cross_attention_kwargs=copy.deepcopy(cross_attention_kwargs), 546 | upsample_size=upsample_size, attention_mask=attention_mask,) 547 | else: 548 | sample = upsample_block( 549 | hidden_states=sample, temb=emb, 550 | res_hidden_states_tuple=res_samples, 551 | upsample_size=upsample_size) 552 | 553 | # 6. post-process 554 | if self.conv_norm_out: 555 | sample = self.conv_norm_out(sample) 556 | sample = self.conv_act(sample) 557 | sample = self.conv_out(sample) 558 | 559 | if not return_dict: 560 | return (sample,) 561 | 562 | return UNet2DConditionOutput(sample=sample) 563 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import cv2 5 | import random 6 | import shutil 7 | from pathlib import Path 8 | import pickle 9 | 10 | import accelerate 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import torchvision.transforms.functional as tf 15 | 16 | import torch.utils.checkpoint 17 | import transformers 18 | from accelerate import Accelerator 19 | from accelerate.logging import get_logger 20 | from accelerate.utils import ProjectConfiguration, set_seed 21 | from huggingface_hub import create_repo, upload_folder 22 | from packaging import version 23 | from PIL import Image 24 | from tqdm.auto import tqdm 25 | from transformers import AutoTokenizer, PretrainedConfig 26 | from torchvision import utils 27 | 28 | import diffusers 29 | from diffusers import ( 30 | AutoencoderKL, 31 | # ControlNetModel, 32 | DDPMScheduler, 33 | # StableDiffusionControlNetPipeline, 34 | UNet2DConditionModel, 35 | UniPCMultistepScheduler, 36 | ) 37 | 38 | from models.controlnet1x1 import ControlNetModel1x1 as ControlNetModel 39 | from models.pipeline_controlnet_1x1_4dunet import StableDiffusionControlNetPipeline1x1 as StableDiffusionControlNetPipeline 40 | 41 | from models.unet_2d_condition_multiview import UNet2DConditionModelMultiview 42 | 43 | from diffusers.optimization import get_scheduler 44 | from diffusers.utils import check_min_version, is_wandb_available 45 | from diffusers.utils.import_utils import is_xformers_available 46 | from diffusers.utils.torch_utils import is_compiled_module 47 | 48 | from utils.common import convert_outputs_to_fp16, unet_param 49 | from utils.dataset_nusmtv import NuScenesDatasetMtvSpar as NuScenesDataset 50 | 51 | 52 | 53 | 54 | if is_wandb_available(): 55 | import wandb 56 | 57 | # 
Will error if the minimal version of diffusers is not installed. Remove at your own risks. 58 | check_min_version("0.26.0.dev0") 59 | 60 | logger = get_logger(__name__) 61 | 62 | 63 | 64 | def image_grid(imgs, rows, cols): 65 | assert len(imgs) == rows * cols 66 | 67 | w, h = imgs[0].size 68 | grid = Image.new("RGB", size=(cols * w, rows * h)) 69 | 70 | for i, img in enumerate(imgs): 71 | grid.paste(img, box=(i % cols * w, i // cols * h)) 72 | return grid 73 | 74 | 75 | def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, val_dataset): 76 | logger.info("Running validation... ") 77 | 78 | controlnet = accelerator.unwrap_model(controlnet) 79 | 80 | pipeline = StableDiffusionControlNetPipeline.from_pretrained( 81 | args.pretrained_model_name_or_path, 82 | vae=vae, 83 | text_encoder=text_encoder, 84 | tokenizer=tokenizer, 85 | unet=unet, 86 | controlnet=controlnet, 87 | safety_checker=None, 88 | revision=args.revision, 89 | variant=args.variant, 90 | torch_dtype=weight_dtype, 91 | ) 92 | pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) 93 | pipeline = pipeline.to(accelerator.device) 94 | pipeline.set_progress_bar_config(disable=True) 95 | 96 | if args.enable_xformers_memory_efficient_attention: 97 | pipeline.enable_xformers_memory_efficient_attention() 98 | 99 | if args.seed is None: 100 | generator = None 101 | else: 102 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 103 | 104 | 105 | for img_idx in [0, 678, 1879, ]: 106 | data_dict = val_dataset.__getitem__(img_idx) 107 | mtv_condition = data_dict['ctrl_img'] 108 | 109 | validation_prompts = ['show a photorealistic street view image'] * 6 110 | images_tensor = [] 111 | 112 | for _ in range(args.num_validation_images): 113 | with torch.autocast("cuda"): 114 | image = pipeline( 115 | prompt=validation_prompts, image=mtv_condition, num_inference_steps=20, generator=generator, height=args.height, width=args.width, controlnet_conditioning_scale=1.0, guidance_scale=args.cfg_scale 116 | ).images#[0] 117 | image = torch.cat([torch.tensor(np.array(ii)) for ii in image], 1) 118 | 119 | images_tensor.append(image) 120 | 121 | # [448, 6, 800, 3] to [448, 4800, 3] 122 | raw_img = data_dict['pixel_values'].permute(2,0,3,1).reshape(images_tensor[0].shape) * 255 123 | # print(data_dict['occ_rgb'].shape) 124 | occ_rgb = data_dict['occ_rgb'].permute(1,0,2,3).reshape(images_tensor[0].shape) 125 | gen_img = torch.cat(images_tensor, 0) 126 | gen_img = torch.cat([occ_rgb, gen_img, raw_img], 0) 127 | 128 | out_path = os.path.join( 129 | args.output_dir, 130 | f"step_{step:06d}_{img_idx:04d}.jpg", 131 | ) 132 | 133 | cv2.imwrite(out_path, cv2.cvtColor(gen_img.cpu().numpy(), cv2.COLOR_RGB2BGR)) 134 | 135 | return None 136 | 137 | 138 | 139 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 140 | text_encoder_config = PretrainedConfig.from_pretrained( 141 | pretrained_model_name_or_path, 142 | subfolder="text_encoder", 143 | revision=revision, 144 | ) 145 | model_class = text_encoder_config.architectures[0] 146 | 147 | if model_class == "CLIPTextModel": 148 | from transformers import CLIPTextModel 149 | 150 | return CLIPTextModel 151 | elif model_class == "RobertaSeriesModelWithTransformation": 152 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 153 | 154 | return RobertaSeriesModelWithTransformation 155 | else: 156 | raise 
ValueError(f"{model_class} is not supported.") 157 | 158 | 159 | def save_model_card(repo_id: str, image_logs=None, base_model=None, repo_folder=None): 160 | img_str = "" 161 | if image_logs is not None: 162 | img_str = "You can find some example images below.\n" 163 | for i, log in enumerate(image_logs): 164 | images = log["images"] 165 | validation_prompt = log["validation_prompt"] 166 | validation_image = log["validation_image"] 167 | validation_image.save(os.path.join(repo_folder, "image_control.png")) 168 | img_str += f"prompt: {validation_prompt}\n" 169 | images = [validation_image] + images 170 | image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) 171 | img_str += f"![images_{i}](./images_{i}.png)\n" 172 | 173 | yaml = f""" 174 | --- 175 | license: creativeml-openrail-m 176 | base_model: {base_model} 177 | tags: 178 | - stable-diffusion 179 | - stable-diffusion-diffusers 180 | - text-to-image 181 | - diffusers 182 | - controlnet 183 | inference: true 184 | --- 185 | """ 186 | model_card = f""" 187 | # controlnet-{repo_id} 188 | 189 | These are controlnet weights trained on {base_model} with a new type of conditioning. 190 | {img_str} 191 | """ 192 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 193 | f.write(yaml + model_card) 194 | 195 | 196 | def main(args): 197 | logging_dir = Path(args.output_dir, args.logging_dir) 198 | 199 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 200 | 201 | accelerator = Accelerator( 202 | gradient_accumulation_steps=args.gradient_accumulation_steps, 203 | mixed_precision=args.mixed_precision, 204 | log_with=args.report_to, 205 | project_config=accelerator_project_config, 206 | ) 207 | 208 | # Make one log on every process with the configuration for debugging. 209 | logging.basicConfig( 210 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 211 | datefmt="%m/%d/%Y %H:%M:%S", 212 | level=logging.INFO, 213 | ) 214 | logger.info(accelerator.state, main_process_only=False) 215 | if accelerator.is_local_main_process: 216 | transformers.utils.logging.set_verbosity_warning() 217 | diffusers.utils.logging.set_verbosity_info() 218 | else: 219 | transformers.utils.logging.set_verbosity_error() 220 | diffusers.utils.logging.set_verbosity_error() 221 | 222 | # If passed along, set the training seed now.
223 | if args.seed is not None: 224 | set_seed(args.seed) 225 | 226 | # Handle the repository creation 227 | if accelerator.is_main_process: 228 | if args.output_dir is not None: 229 | os.makedirs(args.output_dir, exist_ok=True) 230 | 231 | if args.push_to_hub: 232 | repo_id = create_repo( 233 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 234 | ).repo_id 235 | 236 | # Load the tokenizer 237 | if args.tokenizer_name: 238 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 239 | elif args.pretrained_model_name_or_path: 240 | tokenizer = AutoTokenizer.from_pretrained( 241 | args.pretrained_model_name_or_path, 242 | subfolder="tokenizer", 243 | revision=args.revision, 244 | use_fast=False, 245 | ) 246 | 247 | # import correct text encoder class 248 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 249 | 250 | # Load scheduler and models 251 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 252 | text_encoder = text_encoder_cls.from_pretrained( 253 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 254 | ) 255 | vae = AutoencoderKL.from_pretrained( 256 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 257 | ) 258 | unet = UNet2DConditionModel.from_pretrained( 259 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 260 | ) 261 | 262 | if args.controlnet_model_name_or_path: 263 | logger.info("Loading existing controlnet weights") 264 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) 265 | else: 266 | logger.info("Initializing controlnet weights from unet") 267 | controlnet = ControlNetModel.from_unet(unet, conditioning_channels=args.ctrl_channel) 268 | # controlnet = ControlNetModel.from_unet(unet) 269 | 270 | unet = UNet2DConditionModelMultiview.from_unet_2d_condition(unet, **unet_param) 271 | 272 | # Resuming unet 273 | if args.resume_from_checkpoint == "latest": 274 | # Get the most recent checkpoint 275 | dirs = os.listdir(args.output_dir) 276 | dirs = [d for d in dirs if d.startswith("checkpoint")] 277 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 278 | path = dirs[-1] if len(dirs) > 0 else None 279 | 280 | if path is None: 281 | accelerator.print( 282 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
283 | ) 284 | args.resume_from_checkpoint = None 285 | initial_global_step = 0 286 | else: 287 | accelerator.print(f"Resuming unet from checkpoint {path}") 288 | unet = unet.from_pretrained(os.path.join(args.output_dir, path), subfolder="unet") 289 | 290 | # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) 291 | def unwrap_model(model): 292 | model = accelerator.unwrap_model(model) 293 | model = model._orig_mod if is_compiled_module(model) else model 294 | return model 295 | 296 | # `accelerate` 0.16.0 will have better support for customized saving 297 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 298 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 299 | def save_model_hook(models, weights, output_dir): 300 | if accelerator.is_main_process: 301 | i = len(weights) - 1 302 | 303 | while len(weights) > 0: 304 | weights.pop() 305 | model = models[i] 306 | 307 | sub_dir = "controlnet" 308 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 309 | 310 | i -= 1 311 | 312 | def load_model_hook(models, input_dir): 313 | while len(models) > 0: 314 | # pop models so that they are not loaded again 315 | model = models.pop() 316 | 317 | # load diffusers style into model 318 | load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") 319 | model.register_to_config(**load_model.config) 320 | model.load_state_dict(load_model.state_dict()) 321 | 322 | del load_model 323 | 324 | accelerator.register_save_state_pre_hook(save_model_hook) 325 | accelerator.register_load_state_pre_hook(load_model_hook) 326 | 327 | vae.requires_grad_(False) 328 | unet.requires_grad_(False) 329 | text_encoder.requires_grad_(False) 330 | controlnet.train() 331 | 332 | if args.enable_xformers_memory_efficient_attention: 333 | if is_xformers_available(): 334 | import xformers 335 | 336 | xformers_version = version.parse(xformers.__version__) 337 | if xformers_version == version.parse("0.0.16"): 338 | logger.warn( 339 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 340 | ) 341 | unet.enable_xformers_memory_efficient_attention() 342 | controlnet.enable_xformers_memory_efficient_attention() 343 | else: 344 | raise ValueError("xformers is not available. Make sure it is installed correctly") 345 | 346 | if args.gradient_checkpointing: 347 | controlnet.enable_gradient_checkpointing() 348 | 349 | # Check that all trainable models are in full precision 350 | low_precision_error_string = ( 351 | " Please make sure to always have all model weights in full float32 precision when starting training - even if" 352 | " doing mixed precision training, copy of the weights should still be float32." 353 | ) 354 | 355 | if unwrap_model(controlnet).dtype != torch.float32: 356 | raise ValueError( 357 | f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. 
{low_precision_error_string}" 358 | ) 359 | 360 | # Enable TF32 for faster training on Ampere GPUs, 361 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 362 | if args.allow_tf32: 363 | torch.backends.cuda.matmul.allow_tf32 = True 364 | 365 | if args.scale_lr: 366 | args.learning_rate = ( 367 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 368 | ) 369 | 370 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 371 | if args.use_8bit_adam: 372 | try: 373 | import bitsandbytes as bnb 374 | except ImportError: 375 | raise ImportError( 376 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 377 | ) 378 | 379 | optimizer_class = bnb.optim.AdamW8bit 380 | else: 381 | optimizer_class = torch.optim.AdamW 382 | 383 | # Optimizer creation 384 | params_to_optimize = list(controlnet.parameters()) 385 | unet.to(accelerator.device, dtype=torch.float16) 386 | 387 | for name, mod in unet.trainable_module.items(): 388 | mod.requires_grad_(True) 389 | 390 | for name, mod in unet.trainable_module.items(): 391 | logging.debug(f"[MultiviewRunner] set {name} to fp32") 392 | mod.to(dtype=torch.float32) 393 | mod._original_forward = mod.forward 394 | # autocast intermediate is necessary since others are fp16 395 | mod.forward = torch.cuda.amp.autocast( 396 | dtype=torch.float16)(mod.forward) 397 | # we ensure output is always fp16 398 | mod.forward = convert_outputs_to_fp16(mod.forward) 399 | 400 | unet_params = unet.trainable_parameters 401 | params_to_optimize += unet_params 402 | 403 | 404 | optimizer = optimizer_class( 405 | params_to_optimize, 406 | lr=args.learning_rate, 407 | betas=(args.adam_beta1, args.adam_beta2), 408 | weight_decay=args.adam_weight_decay, 409 | eps=args.adam_epsilon, 410 | ) 411 | 412 | # train_dataset = NuScenesDatasetMtvCBGS(args, tokenizer, 'train') 413 | train_dataset = NuScenesDataset(args, tokenizer, 'train') 414 | val_dataset = NuScenesDataset(args, tokenizer, 'val') 415 | 416 | train_dataloader = torch.utils.data.DataLoader( 417 | train_dataset, 418 | shuffle=True, 419 | batch_size=args.train_batch_size, 420 | num_workers=args.dataloader_num_workers, 421 | ) 422 | 423 | # Scheduler and math around the number of training steps. 424 | overrode_max_train_steps = False 425 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 426 | if args.max_train_steps is None: 427 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 428 | overrode_max_train_steps = True 429 | 430 | lr_scheduler = get_scheduler( 431 | args.lr_scheduler, 432 | optimizer=optimizer, 433 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 434 | num_training_steps=args.max_train_steps * accelerator.num_processes, 435 | num_cycles=args.lr_num_cycles, 436 | power=args.lr_power, 437 | ) 438 | 439 | # Prepare everything with our `accelerator`. 440 | controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 441 | controlnet, optimizer, train_dataloader, lr_scheduler 442 | ) 443 | 444 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 445 | # as these models are only used for inference, keeping weights in full precision is not required. 
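The `unet.trainable_module` handling a few lines above is the one non-standard piece of mixed precision here: the whole UNet is moved to fp16, but the newly added cross-view modules are switched back to fp32 (so they are optimized in full precision, as the error message above demands), their `forward` is wrapped in an fp16 autocast decorator, and their outputs are cast back to fp16 so they compose with the surrounding half-precision graph. Below is a minimal sketch of that pattern on a toy module; it is illustrative only, assumes a CUDA device, and `TinyBlock` is not part of the repository.

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Stand-in for one of the newly added cross-view attention modules."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


block = TinyBlock().cuda().to(dtype=torch.float32)   # master weights stay fp32
block.requires_grad_(True)

# Run the fp32 submodule under fp16 autocast so it accepts fp16 activations...
block.forward = torch.cuda.amp.autocast(dtype=torch.float16)(block.forward)
# ...and force the result back to fp16 for the surrounding half-precision graph
# (the repository uses its ConvertOutputsToFp16 wrapper from utils/common.py for this step).
wrapped = block.forward
block.forward = lambda *a, **kw: wrapped(*a, **kw).half()

x = torch.randn(2, 8, device="cuda", dtype=torch.float16)
print(block.forward(x).dtype)   # torch.float16, while block.proj.weight stays float32
```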
446 | weight_dtype = torch.float32 447 | if accelerator.mixed_precision == "fp16": 448 | weight_dtype = torch.float16 449 | elif accelerator.mixed_precision == "bf16": 450 | weight_dtype = torch.bfloat16 451 | 452 | # Move vae, unet and text_encoder to device and cast to weight_dtype 453 | vae.to(accelerator.device, dtype=weight_dtype) 454 | # unet.to(accelerator.device, dtype=weight_dtype) 455 | text_encoder.to(accelerator.device, dtype=weight_dtype) 456 | 457 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 458 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 459 | if overrode_max_train_steps: 460 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 461 | # Afterwards we recalculate our number of training epochs 462 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 463 | 464 | # We need to initialize the trackers we use, and also store our configuration. 465 | # The trackers initializes automatically on the main process. 466 | if accelerator.is_main_process: 467 | tracker_config = dict(vars(args)) 468 | 469 | # tensorboard cannot handle list types for config 470 | tracker_config.pop("validation_prompt") 471 | tracker_config.pop("validation_image") 472 | 473 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 474 | 475 | # Train! 476 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 477 | 478 | logger.info("***** Running training *****") 479 | logger.info(f" Num examples = {len(train_dataset)}") 480 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 481 | logger.info(f" Num Epochs = {args.num_train_epochs}") 482 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 483 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 484 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 485 | logger.info(f" Total optimization steps = {args.max_train_steps}") 486 | global_step = 0 487 | first_epoch = 0 488 | 489 | # Potentially load in the weights and states from a previous save 490 | if args.resume_from_checkpoint: 491 | if args.resume_from_checkpoint != "latest": 492 | path = os.path.basename(args.resume_from_checkpoint) 493 | else: 494 | # Get the most recent checkpoint 495 | dirs = os.listdir(args.output_dir) 496 | dirs = [d for d in dirs if d.startswith("checkpoint")] 497 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 498 | path = dirs[-1] if len(dirs) > 0 else None 499 | 500 | if path is None: 501 | accelerator.print( 502 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
503 | ) 504 | args.resume_from_checkpoint = None 505 | initial_global_step = 0 506 | else: 507 | accelerator.print(f"Resuming from checkpoint {path}") 508 | # import ipdb; ipdb.set_trace() 509 | accelerator.load_state(os.path.join(args.output_dir, path)) 510 | # unet = unet.from_pretrained(os.path.join(args.output_dir, path), subfolder="unet") 511 | 512 | global_step = int(path.split("-")[1]) 513 | 514 | initial_global_step = global_step 515 | first_epoch = global_step // num_update_steps_per_epoch 516 | else: 517 | initial_global_step = 0 518 | 519 | progress_bar = tqdm( 520 | range(0, args.max_train_steps), 521 | initial=initial_global_step, 522 | desc="Steps", 523 | # Only show the progress bar once on each machine. 524 | disable=not accelerator.is_local_main_process, 525 | ) 526 | 527 | image_logs = None 528 | for epoch in range(first_epoch, args.num_train_epochs): 529 | for step, batch in enumerate(train_dataloader): 530 | with accelerator.accumulate(controlnet): 531 | 532 | batch["pixel_values"] = batch["pixel_values"][0] 533 | batch["input_ids"] = batch["input_ids"][0] 534 | ctrl_img = batch['ctrl_img'][0] 535 | weight_mask = batch['weight_mask'][0] 536 | 537 | batch["conditioning_pixel_values"] = ctrl_img 538 | 539 | # Convert images to latent space 540 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 541 | latents = latents * vae.config.scaling_factor 542 | 543 | # Sample noise that we'll add to the latents 544 | noise = torch.randn_like(latents) 545 | bsz = latents.shape[0] 546 | # Sample a random timestep for each image 547 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 548 | timesteps = timesteps.long() 549 | 550 | # Add noise to the latents according to the noise magnitude at each timestep 551 | # (this is the forward diffusion process) 552 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 553 | 554 | # Get the text embedding for conditioning 555 | encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] 556 | 557 | controlnet_image = batch["conditioning_pixel_values"].to(device=latents.device, dtype=weight_dtype) 558 | 559 | down_block_res_samples, mid_block_res_sample = controlnet( 560 | noisy_latents, 561 | timesteps, 562 | encoder_hidden_states=encoder_hidden_states, 563 | controlnet_cond=controlnet_image, 564 | return_dict=False, 565 | ) 566 | 567 | model_pred = unet( 568 | noisy_latents, 569 | timesteps, 570 | encoder_hidden_states=encoder_hidden_states, 571 | down_block_additional_residuals=[ 572 | sample.to(dtype=weight_dtype) for sample in down_block_res_samples 573 | ], 574 | mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), 575 | return_dict=False, 576 | )[0] 577 | 578 | # Get the target for loss depending on the prediction type 579 | if noise_scheduler.config.prediction_type == "epsilon": 580 | target = noise 581 | elif noise_scheduler.config.prediction_type == "v_prediction": 582 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 583 | else: 584 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 585 | # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 586 | 587 | def weighted_mse_loss(input, target, weight): 588 | return torch.mean(weight * (input - target) ** 2) 589 | 590 | weight_mask = weight_mask.to(dtype=weight_dtype) 591 | weight_mask, depth_map = weight_mask[:, 0, ...], weight_mask[:, 1, ...] 
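`weight_mask` packs two maps per pixel: channel 0 is a foreground flag (2 for fused-semantic classes 1-10, 1 otherwise, as set in `dataset_nusmtv.py`) and channel 1 is an 8-bit depth map. The cosine remappings on the next few lines turn these into per-pixel MSE weights: the depth term spans roughly [1.1, 2.1], the training-progress term ramps from about 1.2 to 2.2, so foreground pixels end up weighted between roughly 1.3 and 4.6 while background pixels keep weight 1. A standalone check of those ranges (the 1000-step horizon below is illustrative):

```python
import math

def depth_term(d):                     # d: 8-bit depth-map value in [0, 255]
    return -0.5 * (1 + math.cos(math.pi * d / 255)) + 2.1

def progress_term(step, max_steps):    # ramps up over training
    return -0.5 * (1 + math.cos(math.pi * step / max_steps)) + 2.2

print(depth_term(0), depth_term(255))                      # ~1.1 ... 2.1
print(progress_term(0, 1000), progress_term(1000, 1000))   # ~1.2 ... 2.2
# Foreground weight = depth_term * progress_term, i.e. roughly 1.32 ... 4.62;
# background pixels keep weight 1 in the weighted MSE that follows.
```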
592 | 593 | depth_map = -0.5 * (1 + torch.cos(torch.pi * depth_map / 255)) + 2.1 594 | fore = -0.5 * (1 + np.cos(np.pi * global_step / args.max_train_steps)) + 2.2 595 | 596 | curr_mask = weight_mask==2 597 | 598 | weight_mask[curr_mask] = depth_map[curr_mask] * fore 599 | 600 | loss = weighted_mse_loss(model_pred.float(), target.float(), weight_mask.unsqueeze(1).repeat(1,target.shape[1],1,1)) 601 | 602 | accelerator.backward(loss) 603 | if accelerator.sync_gradients: 604 | params_to_clip = controlnet.parameters() 605 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 606 | optimizer.step() 607 | lr_scheduler.step() 608 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 609 | 610 | # Checks if the accelerator has performed an optimization step behind the scenes 611 | if accelerator.sync_gradients: 612 | progress_bar.update(1) 613 | global_step += 1 614 | 615 | if accelerator.is_main_process: 616 | if global_step % args.checkpointing_steps == 0: 617 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 618 | if args.checkpoints_total_limit is not None: 619 | checkpoints = os.listdir(args.output_dir) 620 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 621 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 622 | 623 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 624 | if len(checkpoints) >= args.checkpoints_total_limit: 625 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 626 | removing_checkpoints = checkpoints[0:num_to_remove] 627 | 628 | logger.info( 629 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 630 | ) 631 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 632 | 633 | for removing_checkpoint in removing_checkpoints: 634 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 635 | shutil.rmtree(removing_checkpoint) 636 | 637 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 638 | accelerator.save_state(save_path) 639 | 640 | # save unet 641 | accelerator.unwrap_model(unet).save_pretrained(os.path.join(save_path, 'unet')) 642 | 643 | logger.info(f"Saved state to {save_path}") 644 | 645 | if global_step % args.validation_steps == 0: 646 | image_logs = log_validation( 647 | vae, 648 | text_encoder, 649 | tokenizer, 650 | unet, 651 | controlnet, 652 | args, 653 | accelerator, 654 | weight_dtype, 655 | global_step, 656 | val_dataset 657 | ) 658 | 659 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 660 | progress_bar.set_postfix(**logs) 661 | accelerator.log(logs, step=global_step) 662 | 663 | if global_step >= args.max_train_steps: 664 | break 665 | 666 | # Create the pipeline using using the trained modules and save it. 
667 | accelerator.wait_for_everyone() 668 | if accelerator.is_main_process: 669 | controlnet = unwrap_model(controlnet) 670 | controlnet.save_pretrained(os.path.join(args.output_dir, f"checkpoint-{global_step}", 'controlnet')) 671 | unwrap_model(unet).save_pretrained(os.path.join(args.output_dir, f"checkpoint-{global_step}", 'unet')) 672 | 673 | if args.push_to_hub: 674 | save_model_card( 675 | repo_id, 676 | image_logs=image_logs, 677 | base_model=args.pretrained_model_name_or_path, 678 | repo_folder=args.output_dir, 679 | ) 680 | upload_folder( 681 | repo_id=repo_id, 682 | folder_path=args.output_dir, 683 | commit_message="End of training", 684 | ignore_patterns=["step_*", "epoch_*"], 685 | ) 686 | 687 | accelerator.end_training() 688 | 689 | 690 | if __name__ == "__main__": 691 | from args_file import parse_args 692 | args = parse_args() 693 | main(args) 694 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | export WANDB_DISABLED=True 2 | export HF_HUB_OFFLINE=True 3 | 4 | export MODEL_DIR="./ckp/stable-diffusion-v2-1" 5 | 6 | 7 | 8 | 9 | export EXP_NAME="train_syntheocc" 10 | export OUTPUT_DIR="./ckp/$EXP_NAME" 11 | export SAVE_IMG_DIR="vis_dir/$EXP_NAME/samples" 12 | export DATA_USED="samples_syntheocc_surocc" 13 | 14 | 15 | 16 | # accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --main_process_port 3226 train.py \ 17 | accelerate launch --gpu_ids 0, --num_processes 1 --main_process_port 3226 train.py \ 18 | --pretrained_model_name_or_path=$MODEL_DIR \ 19 | --output_dir=$OUTPUT_DIR \ 20 | --width=800 \ 21 | --height=448 \ 22 | --learning_rate=2e-5 \ 23 | --num_train_epochs=6 \ 24 | --train_batch_size=1 \ 25 | --mixed_precision="fp16" \ 26 | --num_validation_images=2 \ 27 | --validation_steps=1000 \ 28 | --checkpointing_steps=5000 \ 29 | --checkpoints_total_limit=10 \ 30 | --ctrl_channel=257 \ 31 | --enable_xformers_memory_efficient_attention \ 32 | --report_to='wandb' \ 33 | --use_cbgs=True \ 34 | --mtp_path='samples_syntheocc_surocc' \ 35 | --resume_from_checkpoint="latest" 36 | 37 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import importlib 3 | from functools import update_wrapper 4 | 5 | import torch 6 | import accelerate 7 | from accelerate.state import AcceleratorState 8 | from accelerate.utils import recursively_apply 9 | 10 | 11 | neighboring_view_pair = { 12 | 0: [5, 1], 13 | 1: [0, 2], 14 | 2: [1, 3], 15 | 3: [2, 4], 16 | 4: [3, 5], 17 | 5: [4, 0], 18 | } 19 | 20 | 21 | unet_param = { 22 | 'trainable_state': 'only_new', 23 | 'neighboring_view_pair': neighboring_view_pair, 24 | 'neighboring_attn_type': 'add', 25 | 'zero_module_type': 'zero_linear', 26 | 'crossview_attn_type': 'basic', 27 | 'img_size': [448, 800], 28 | } 29 | 30 | 31 | 32 | def load_module(name): 33 | p, m = name.rsplit(".", 1) 34 | mod = importlib.import_module(p) 35 | model_cls = getattr(mod, m) 36 | return model_cls 37 | 38 | 39 | def move_to(obj, device, filter=lambda x: True): 40 | if torch.is_tensor(obj): 41 | if filter(obj): 42 | return obj.to(device) 43 | else: 44 | return obj 45 | elif isinstance(obj, dict): 46 | res = {} 47 | for k, v in obj.items(): 48 | res[k] = move_to(v, device, filter) 49 | return res 50 | elif isinstance(obj, list): 51 | res = [] 52 | for v in obj: 53 | res.append(move_to(v, device, 
filter)) 54 | return res 55 | elif obj is None: 56 | return obj 57 | else: 58 | raise TypeError(f"Invalid type {obj.__class__} for move_to.") 59 | 60 | 61 | # taken from torch.ao.quantization.fuse_modules 62 | # Generalization of getattr 63 | def _get_module(model, submodule_key): 64 | tokens = submodule_key.split('.') 65 | cur_mod = model 66 | for s in tokens: 67 | cur_mod = getattr(cur_mod, s) 68 | return cur_mod 69 | 70 | 71 | # Generalization of setattr 72 | def _set_module(model, submodule_key, module): 73 | tokens = submodule_key.split('.') 74 | sub_tokens = tokens[:-1] 75 | cur_mod = model 76 | for s in sub_tokens: 77 | cur_mod = getattr(cur_mod, s) 78 | 79 | setattr(cur_mod, tokens[-1], module) 80 | 81 | 82 | def convert_to_fp16(tensor): 83 | """ 84 | Recursively converts the elements of a nested list/tuple/dictionary of tensors from FP32 precision to FP16. 85 | 86 | Args: 87 | tensor (nested list/tuple/dictionary of `torch.Tensor`): 88 | The data to convert from FP32 to FP16. 89 | 90 | Returns: 91 | The same data structure as `tensor` with all tensors that were in FP32 precision converted to FP16. 92 | """ 93 | 94 | def _convert_to_fp16(tensor): 95 | return tensor.half() 96 | 97 | def _is_fp32_tensor(tensor): 98 | return hasattr(tensor, "dtype") and ( 99 | tensor.dtype == torch.float32 100 | ) 101 | 102 | return recursively_apply(_convert_to_fp16, tensor, 103 | test_type=_is_fp32_tensor) 104 | 105 | 106 | class ConvertOutputsToFp16: 107 | """ 108 | Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP32 109 | precision will be converted back to FP16. 110 | 111 | Args: 112 | model_forward (`Callable`): 113 | The function whose outputs we want to treat. 114 | 115 | Returns: 116 | The same function as `model_forward` but with converted outputs. 117 | """ 118 | 119 | def __init__(self, model_forward): 120 | self.model_forward = model_forward 121 | update_wrapper(self, model_forward) 122 | 123 | def __call__(self, *args, **kwargs): 124 | return convert_to_fp16(self.model_forward(*args, **kwargs)) 125 | 126 | def __getstate__(self): 127 | raise pickle.PicklingError( 128 | "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it."
129 | ) 130 | 131 | 132 | convert_outputs_to_fp16 = ConvertOutputsToFp16 133 | 134 | 135 | def deepspeed_zero_init_disabled_context_manager(): 136 | """ 137 | returns either a context list that includes one that will disable zero.Init or an empty context list 138 | """ 139 | deepspeed_plugin = AcceleratorState().deepspeed_plugin \ 140 | if accelerate.state.is_initialized() else None 141 | if deepspeed_plugin is None: 142 | return [] 143 | 144 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 145 | -------------------------------------------------------------------------------- /utils/dataset_nusmtv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import os 4 | import cv2 5 | import json 6 | import torch 7 | import random 8 | import pickle 9 | import numpy as np 10 | from torchvision import transforms 11 | from PIL import Image 12 | import torchvision.transforms.functional as tf 13 | from torchvision.io import read_image 14 | from scipy import sparse 15 | 16 | colors_map = torch.tensor( 17 | [ 18 | [0, 0, 0, 255], # unknown 19 | [255, 158, 0, 255], # 1 car orange 20 | [255, 99, 71, 255], # 2 truck Tomato 21 | [255, 140, 0, 255], # 3 trailer Darkorange 22 | [255, 69, 0, 255], # 4 bus Orangered 23 | [233, 150, 70, 255], # 5 construction_vehicle Darksalmon 24 | [220, 20, 60, 255], # 6 bicycle Crimson 25 | [255, 61, 99, 255], # 7 motorcycle Red 26 | [0, 0, 230, 255], # 8 pedestrian Blue 27 | [47, 79, 79, 255], # 9 traffic_cone Darkslategrey 28 | [112, 128, 144, 255], # 10 barrier Slategrey 29 | [0, 207, 191, 255], # 11 driveable_surface nuTonomy green 30 | [175, 0, 75, 255], # 12 other_flat 31 | [75, 0, 75, 255], # 13 sidewalk 32 | [112, 180, 60, 255], # 14 terrain 33 | [222, 184, 135, 255], # 15 manmade Burlywood 34 | [0, 175, 0, 255], # 16 vegetation Green 35 | ] 36 | ).type(torch.uint8) 37 | 38 | 39 | class NuScenesDatasetMtvSpar(torch.utils.data.Dataset): 40 | def __init__(self, args, tokenizer, trainorval): 41 | self.args = args 42 | self.trainorval = trainorval 43 | 44 | dataroot = args.dataroot_path 45 | self.save_name = args.mtp_path 46 | 47 | if trainorval == 'train': 48 | data_file = os.path.join(dataroot, 'nuscenes_occ_infos_train.pkl') 49 | elif trainorval == 'val': 50 | data_file = os.path.join(dataroot, 'nuscenes_occ_infos_val.pkl') 51 | 52 | 53 | with open(data_file, "rb") as file: 54 | nus_pkl = pickle.load(file) 55 | 56 | self.dataset = nus_pkl['infos']#[:500] 57 | 58 | self.length = len(self.dataset) 59 | print(f"data scale: {self.length}") 60 | # random.shuffle(self.dataset) 61 | 62 | transforms_list = [ 63 | transforms.Resize((self.args.height, self.args.width), interpolation=transforms.InterpolationMode.BILINEAR), 64 | transforms.ToTensor(), 65 | ] 66 | 67 | if trainorval == 'train': 68 | transforms_list.append(transforms.Normalize([0.5], [0.5])) 69 | 70 | self.image_transforms = transforms.Compose(transforms_list) 71 | 72 | 73 | self.prompt = ['show a photorealistic street view image'] 74 | # self.CAM_NAMES = ['CAM_FRONT', 'CAM_FRONT_LEFT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT'] 75 | self.CAM_NAMES = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT', 'CAM_BACK', 'CAM_BACK_LEFT'] 76 | 77 | 78 | if self.args.use_sdxl: 79 | self.tokenizer, self.text_encoders = tokenizer 80 | self.input_ids = self.tokenize_captions_sdxl(self.prompt * 6) 81 | else: 82 | self.tokenizer = tokenizer 83 | self.input_ids = 
self.tokenize_captions(self.prompt) 84 | 85 | if self.args.use_cbgs and self.trainorval == 'train': 86 | self.CLASSES = ('car', 'truck', 'trailer', 'bus', 'construction_vehicle', 87 | 'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 88 | 'barrier', 'background') 89 | self.cat2id = {name: i for i, name in enumerate(self.CLASSES)} 90 | self.use_valid_flag = True 91 | 92 | self.sample_indices = self._get_sample_indices() 93 | self.length = len(self.sample_indices) 94 | print(f"cbgs data scale: {self.length}") 95 | 96 | 97 | self.weight_dtype = torch.float16 98 | 99 | 100 | 101 | 102 | 103 | def tokenize_captions(self, examples, is_train=True): 104 | captions = [] 105 | for caption in examples: 106 | if random.random() < self.args.proportion_empty_prompts: 107 | captions.append("") 108 | elif isinstance(caption, str): 109 | captions.append(caption) 110 | elif isinstance(caption, (list, np.ndarray)): 111 | # take a random caption if there are multiple 112 | captions.append(random.choice(caption) if is_train else caption[0]) 113 | else: 114 | raise ValueError( 115 | f"Caption column should contain either strings or lists of strings." 116 | ) 117 | inputs = self.tokenizer( 118 | captions, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 119 | ) 120 | return inputs.input_ids 121 | 122 | def tokenize_captions_sdxl(self, prompt_batch, is_train=True): 123 | 124 | original_size = (self.args.width, self.args.height) 125 | target_size = (self.args.width, self.args.height) 126 | crops_coords_top_left = (0, 0) 127 | 128 | prompt_embeds, pooled_prompt_embeds = self.encode_prompt( 129 | prompt_batch, self.text_encoders, self.tokenizer, self.args.proportion_empty_prompts, is_train 130 | ) 131 | add_text_embeds = pooled_prompt_embeds 132 | 133 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 134 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 135 | add_time_ids = torch.tensor([add_time_ids]) 136 | 137 | add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) 138 | 139 | return { 140 | "prompt_ids": prompt_embeds, 141 | "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, 142 | } 143 | 144 | 145 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 146 | def encode_prompt(self, prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): 147 | prompt_embeds_list = [] 148 | 149 | captions = [] 150 | for caption in prompt_batch: 151 | if random.random() < proportion_empty_prompts: 152 | captions.append("") 153 | elif isinstance(caption, str): 154 | captions.append(caption) 155 | elif isinstance(caption, (list, np.ndarray)): 156 | # take a random caption if there are multiple 157 | captions.append(random.choice(caption) if is_train else caption[0]) 158 | 159 | with torch.no_grad(): 160 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 161 | text_inputs = tokenizer( 162 | captions, 163 | padding="max_length", 164 | max_length=tokenizer.model_max_length, 165 | truncation=True, 166 | return_tensors="pt", 167 | ) 168 | text_input_ids = text_inputs.input_ids 169 | prompt_embeds = text_encoder( 170 | text_input_ids.to(text_encoder.device), 171 | output_hidden_states=True, 172 | ) 173 | 174 | # We are only ALWAYS interested in the pooled output of the final text encoder 175 | pooled_prompt_embeds = prompt_embeds[0] 176 | prompt_embeds = prompt_embeds.hidden_states[-2] 177 | bs_embed, seq_len, _ = prompt_embeds.shape 178 | 
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 179 | prompt_embeds_list.append(prompt_embeds) 180 | 181 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 182 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 183 | return prompt_embeds, pooled_prompt_embeds 184 | 185 | 186 | 187 | def __len__(self): 188 | 'Denotes the total number of samples' 189 | if self.args.use_cbgs and self.trainorval == 'train': 190 | return len(self.sample_indices) 191 | return self.length 192 | 193 | def __getitem__(self, idx): 194 | if self.args.use_cbgs and self.trainorval == 'train': 195 | idx = self.sample_indices[idx] 196 | return self.get_data_dict(idx) 197 | 198 | def get_data_dict(self, index): 199 | 200 | curr_info = self.dataset[index] 201 | 202 | mtv_img = [] 203 | mtv_path_img = [] 204 | mtv_condition = [] 205 | mtv_prompt = [] 206 | mtv_weight_mask = [] 207 | mtv_occ_rgb = [] 208 | 209 | for cam_id in range(6): 210 | 211 | path_img = curr_info['cams'][self.CAM_NAMES[cam_id]]['data_path'] 212 | 213 | img = Image.open(path_img).convert("RGB") 214 | img = self.image_transforms(img)[None] 215 | mtv_img.append(img) 216 | 217 | pth_path = path_img.replace('jpg', 'pth') 218 | all_path = pth_path.replace('samples', self.save_name) 219 | 220 | 221 | in_cha = self.args.ctrl_channel - 1 222 | ctrl_img_path = all_path[:-4] + f'_mtp{in_cha}.npz' 223 | ctrl_img = sparse.load_npz(ctrl_img_path) 224 | ctrl_img = ctrl_img.toarray().reshape((in_cha, self.args.height//8, self.args.width//8)) 225 | ctrl_img = torch.tensor(ctrl_img)[None]#.to (device='cuda', dtype=self.weight_dtype) #/ 16 226 | 227 | 228 | 229 | fuse_path = all_path[:-4] + '_fuseweight.png' 230 | fuse_img = torch.tensor(cv2.imread(fuse_path, cv2.IMREAD_GRAYSCALE))#[..., 0] 231 | # fuse_img_down = tf.resize(fuse_img[None], (self.args.height//8, self.args.width//8))[0] 232 | fuse_img_down = fuse_img 233 | 234 | ctrl_img = torch.cat([ctrl_img, fuse_img_down[None, None]], 1) 235 | 236 | mtv_condition.append(ctrl_img) 237 | mtv_path_img.append(path_img) 238 | 239 | if not self.args.use_sdxl: 240 | # input_ids = self.tokenize_captions(self.prompt) 241 | input_ids = self.input_ids 242 | mtv_prompt.append(input_ids[None]) 243 | 244 | if self.trainorval == 'val': 245 | occrgb_path = all_path[:-4] + '_occrgb.png' 246 | # occrgb_path = all_path[:-4] + '_occrgb.jpg' 247 | occ_rgb = read_image(occrgb_path).permute(1,2,0) # 3hw to hw3 248 | occ_rgb = tf.resize(occ_rgb.permute(2,0,1), (self.args.height, self.args.width)).permute(1,2,0) 249 | mtv_occ_rgb.append(occ_rgb[None]) # hw3 250 | 251 | elif self.trainorval == 'train': 252 | mask = (fuse_img_down >= 1) & (fuse_img_down <= 10) 253 | weight_mask = torch.ones_like(mask) * 1 254 | weight_mask[mask] = 2 255 | 256 | depth_path = all_path[:-4] + '_depthmap.png' 257 | depth_map = torch.tensor(cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)) 258 | # depth_map = tf.resize(depth_map[None], (self.args.height//8, self.args.width//8))[0] 259 | weight_mask = torch.cat([weight_mask[None], depth_map[None]], 0) 260 | mtv_weight_mask.append(weight_mask[None]) 261 | 262 | 263 | mtv_img = torch.cat(mtv_img, 0) 264 | mtv_condition = torch.cat(mtv_condition, 0) 265 | 266 | 267 | data_dict = { 268 | "pixel_values": mtv_img, 269 | "path_img": mtv_path_img, 270 | "ctrl_img": mtv_condition, 271 | } 272 | 273 | if self.args.use_sdxl: 274 | data_dict.update(self.input_ids) 275 | else: 276 | mtv_prompt = torch.cat(mtv_prompt, 0) 277 | data_dict.update({"input_ids": mtv_prompt}) 278 | 279 | 280 | if self.trainorval 
== 'val': 281 | mtv_occ_rgb = torch.cat(mtv_occ_rgb, 0) 282 | data_dict["occ_rgb"] = mtv_occ_rgb 283 | elif self.trainorval == 'train': 284 | mtv_weight_mask = torch.cat(mtv_weight_mask, 0) 285 | data_dict["weight_mask"] = mtv_weight_mask 286 | 287 | return data_dict 288 | 289 | 290 | def get_cat_ids(self, idx): 291 | """Get category distribution of single scene. 292 | 293 | Args: 294 | idx (int): Index of the data_info. 295 | 296 | Returns: 297 | dict[list]: for each category, if the current scene 298 | contains such boxes, store a list containing idx, 299 | otherwise, store empty list. 300 | """ 301 | info = self.dataset[idx] 302 | if self.use_valid_flag: 303 | mask = info['valid_flag'] 304 | gt_names = set(info['gt_names'][mask]) 305 | else: 306 | gt_names = set(info['gt_names']) 307 | 308 | cat_ids = [] 309 | fore_flag = 0 310 | for name in gt_names: 311 | if name in self.CLASSES: 312 | cat_ids.append(self.cat2id[name]) 313 | fore_flag = 1 314 | if fore_flag == 0: 315 | # model background as two objects 316 | for _ in range (120): 317 | cat_ids.append(self.cat2id['background']) 318 | return cat_ids 319 | 320 | 321 | 322 | def _get_sample_indices(self): 323 | """Load annotations from ann_file. 324 | 325 | Args: 326 | ann_file (str): Path of the annotation file. 327 | 328 | Returns: 329 | list[dict]: List of annotations after class sampling. 330 | """ 331 | class_sample_idxs = {cat_id: [] for cat_id in self.cat2id.values()} 332 | for idx in range(len(self.dataset)): 333 | sample_cat_ids = self.get_cat_ids(idx) 334 | for cat_id in sample_cat_ids: 335 | class_sample_idxs[cat_id].append(idx) 336 | duplicated_samples = sum( 337 | [len(v) for _, v in class_sample_idxs.items()]) 338 | class_distribution = { 339 | k: len(v) / duplicated_samples 340 | for k, v in class_sample_idxs.items() 341 | } 342 | # print(class_sample_idxs, class_distribution) 343 | for key, value in class_sample_idxs.items(): 344 | print(key, len(value)) 345 | 346 | sample_indices = [] 347 | 348 | frac = 1.0 / len(self.CLASSES) 349 | ratios = [frac / v for v in class_distribution.values()] 350 | for cls_inds, ratio in zip(list(class_sample_idxs.values()), ratios): 351 | sample_indices += np.random.choice(cls_inds, 352 | int(len(cls_inds) * 353 | ratio)).tolist() 354 | return sample_indices 355 | -------------------------------------------------------------------------------- /utils/gen_mtp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import os 5 | import tqdm 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | 11 | import pickle 12 | from scipy import sparse 13 | 14 | 15 | data_file = "./data/nuscenes/nuscenes_occ_infos_val.pkl" 16 | # data_file = './data/nuscenes/nuscenes_occ_infos_train.pkl' 17 | 18 | with open(data_file, "rb") as file: 19 | nus_pkl = pickle.load(file) 20 | 21 | 22 | 23 | dataroot = "./data/nuscenes" 24 | save_name = "samples_syntheocc_surocc" 25 | gt_path = os.path.join(dataroot, save_name) 26 | os.makedirs(gt_path, exist_ok=True) 27 | 28 | 29 | CAM_NAMES = [ 30 | "CAM_FRONT_LEFT", 31 | "CAM_FRONT", 32 | "CAM_FRONT_RIGHT", 33 | "CAM_BACK_RIGHT", 34 | "CAM_BACK", 35 | "CAM_BACK_LEFT", 36 | ] 37 | for j in CAM_NAMES: 38 | os.makedirs(gt_path + "/" + j, exist_ok=True) 39 | 40 | 41 | def process_func(idx, rank): 42 | 43 | info = nus_pkl["infos"][idx] 44 | 45 | curr_name = info["lidar_path"].split("/")[-1] 46 | occ_path = 
f"data/nuscenes/dense_voxels_with_semantic_z-5/{curr_name}.npy" 47 | 48 | occ = np.load(occ_path)[:, [2, 1, 0, 3]] 49 | point_cloud_range = [-50, -50, -5.0, 50, 50, 3.0] 50 | 51 | num_classes = 16 52 | occupancy_size = [0.5, 0.5, 0.5] 53 | grid_size = [200, 200, 16] 54 | 55 | occupancy_size = [0.2, 0.2, 0.2] 56 | grid_size = [500, 500, 40] 57 | 58 | pc_range = torch.tensor(point_cloud_range) 59 | voxel_size = (pc_range[3:] - pc_range[:3]) / torch.tensor(grid_size) 60 | 61 | raw_w = 1600 62 | raw_h = 900 63 | 64 | img_w = 100 # target reso 65 | img_h = 56 66 | 67 | # img_w = 800 68 | # img_h = 448 69 | 70 | mtp_num = 96 71 | 72 | f = 0.0055 73 | 74 | def voxel2world(voxel): 75 | return voxel * voxel_size[None, :] + pc_range[:3][None, :] 76 | 77 | def world2voxel(wolrd): 78 | return (wolrd - pc_range[:3][None, :]) / voxel_size[None, :] 79 | 80 | colors_map = torch.tensor( 81 | [ 82 | [0, 0, 0, 255], # unknown 83 | [255, 158, 0, 255], # 1 car orange 84 | [255, 99, 71, 255], # 2 truck Tomato 85 | [255, 140, 0, 255], # 3 trailer Darkorange 86 | [255, 69, 0, 255], # 4 bus Orangered 87 | [233, 150, 70, 255], # 5 construction_vehicle Darksalmon 88 | [220, 20, 60, 255], # 6 bicycle Crimson 89 | [255, 61, 99, 255], # 7 motorcycle Red 90 | [0, 0, 230, 255], # 8 pedestrian Blue 91 | [47, 79, 79, 255], # 9 traffic_cone Darkslategrey 92 | [112, 128, 144, 255], # 10 barrier Slategrey 93 | [0, 207, 191, 255], # 11 driveable_surface nuTonomy green 94 | [175, 0, 75, 255], # 12 other_flat 95 | [75, 0, 75, 255], # 13 sidewalk 96 | [112, 180, 60, 255], # 14 terrain 97 | [222, 184, 135, 255], # 15 manmade Burlywood 98 | [0, 175, 0, 255], # 16 vegetation Green 99 | ] 100 | ).type(torch.uint8) 101 | 102 | c, r = np.meshgrid(np.arange(img_w), np.arange(img_h)) 103 | uv = np.stack([c, r]) 104 | uv = torch.tensor(uv) 105 | 106 | depth = ( 107 | torch.arange(0.2, 51.4, 0.2)[..., None][..., None] 108 | .repeat(1, img_h, 1) 109 | .repeat(1, 1, img_w) 110 | ) 111 | 112 | image_paths = [] 113 | lidar2img_rts = [] 114 | lidar2cam_rts = [] 115 | cam_intrinsics = [] 116 | cam_positions = [] 117 | focal_positions = [] 118 | for cam_type, cam_info in info["cams"].items(): 119 | image_paths.append(cam_info["data_path"]) 120 | cam_info["sensor2lidar_rotation"] = torch.tensor( 121 | cam_info["sensor2lidar_rotation"] 122 | ) 123 | cam_info["sensor2lidar_translation"] = torch.tensor( 124 | cam_info["sensor2lidar_translation"] 125 | ) 126 | cam_info["cam_intrinsic"] = torch.tensor(cam_info["cam_intrinsic"]) 127 | # obtain lidar to image transformation matrix 128 | lidar2cam_r = torch.linalg.inv(cam_info["sensor2lidar_rotation"]) 129 | lidar2cam_t = cam_info["sensor2lidar_translation"] @ lidar2cam_r.T 130 | lidar2cam_rt = torch.eye(4) 131 | lidar2cam_rt[:3, :3] = lidar2cam_r.T 132 | lidar2cam_rt[3, :3] = -lidar2cam_t 133 | intrinsic = cam_info["cam_intrinsic"] 134 | viewpad = torch.eye(4) 135 | viewpad[: intrinsic.shape[0], : intrinsic.shape[1]] = intrinsic 136 | lidar2img_rt = viewpad @ lidar2cam_rt.T 137 | lidar2img_rts.append(lidar2img_rt) 138 | 139 | cam_intrinsics.append(viewpad) 140 | lidar2cam_rts.append(lidar2cam_rt.T) 141 | 142 | cam_position = torch.linalg.inv(lidar2cam_rt.T) @ torch.tensor( 143 | [0.0, 0.0, 0.0, 1.0] 144 | ).reshape([4, 1]) 145 | cam_positions.append(cam_position.flatten()[:3]) 146 | 147 | focal_position = torch.linalg.inv(lidar2cam_rt.T) @ torch.tensor( 148 | [0.0, 0.0, f, 1.0] 149 | ).reshape([4, 1]) 150 | 151 | focal_positions.append(focal_position.flatten()[:3]) 152 | 153 | occ = torch.tensor(occ) 154 
| 155 | dense_vox = torch.zeros(grid_size).type(torch.uint8) 156 | occ_tr = occ[..., [2, 1, 0, 3]] 157 | 158 | dense_vox[occ_tr[:, 0], occ_tr[:, 1], occ_tr[:, 2]] = occ_tr[:, 3].type(torch.uint8) 159 | 160 | for cam_i in range(len(cam_intrinsics)): 161 | 162 | all_pcl = [] 163 | all_col = [] 164 | all_img_fov = [] 165 | 166 | final_img = torch.zeros((img_h, img_w, 3)).type(torch.uint8) 167 | 168 | fuse_img = torch.zeros( 169 | ( 170 | img_h, 171 | img_w, 172 | ) 173 | ).type(torch.uint8) 174 | depth_map = torch.zeros((img_h, img_w, 1)).type(torch.uint8) 175 | 176 | curr_tr = lidar2cam_rts[cam_i] 177 | cam_in = cam_intrinsics[cam_i] 178 | c_u = cam_in[0, 2] / (raw_w / img_w) 179 | c_v = cam_in[1, 2] / (raw_h / img_h) 180 | f_u = cam_in[0, 0] / (raw_w / img_w) 181 | f_v = cam_in[1, 1] / (raw_h / img_h) 182 | 183 | b_x = cam_in[0, 3] / (-f_u) # relative 184 | b_y = cam_in[1, 3] / (-f_v) 185 | 186 | dep_num = depth.shape[0] 187 | mtp_vis = [] 188 | for _ in range(mtp_num): 189 | mtp_vis.append( 190 | torch.zeros( 191 | ( 192 | img_h, 193 | img_w, 194 | ) 195 | ).type(torch.uint8) 196 | ) 197 | 198 | for dep_i in range(dep_num): 199 | # for dep_i in tqdm.tqdm(range(depth.shape[0])): 200 | dep_i = dep_num - 1 - dep_i 201 | 202 | uv_depth = ( 203 | torch.cat([uv, depth[dep_i : dep_i + 1]], 0) 204 | .reshape((3, -1)) 205 | .transpose(1, 0) 206 | ) 207 | n = uv_depth.shape[0] 208 | x = ((uv_depth[:, 0] - c_u) * uv_depth[:, 2]) / f_u + b_x 209 | y = ((uv_depth[:, 1] - c_v) * uv_depth[:, 2]) / f_v + b_y 210 | pts_3d_rect = torch.zeros((n, 3)) 211 | pts_3d_rect[:, 0] = x 212 | pts_3d_rect[:, 1] = y 213 | pts_3d_rect[:, 2] = uv_depth[:, 2] 214 | 215 | new_pcl = torch.cat([pts_3d_rect, torch.ones_like(pts_3d_rect[:, :1])], 1) 216 | 217 | new_pcl = torch.einsum("mn, an -> am", torch.linalg.inv(curr_tr), new_pcl) 218 | # new_pcl = torch.einsum("mn, an -> am", curr_tr, new_pcl) 219 | 220 | # new_pcl[:, :3] -= 0.1 221 | new_pcl[:, :3] -= occupancy_size[0] / 2 222 | 223 | new_pcl = world2voxel(new_pcl[:, :3]) 224 | new_pcl = torch.round(new_pcl, decimals=0).type(torch.int32) 225 | 226 | pts_index = torch.zeros((new_pcl.shape[0])).type(torch.uint8) 227 | 228 | valid_flag = ( 229 | ((new_pcl[:, 0] < grid_size[0]) & (new_pcl[:, 0] >= 0)) 230 | & ((new_pcl[:, 1] < grid_size[1]) & (new_pcl[:, 1] >= 0)) 231 | & ((new_pcl[:, 2] < grid_size[2]) & (new_pcl[:, 2] >= 0)) 232 | ) 233 | 234 | if valid_flag.max() > 0: 235 | pts_index[valid_flag] = dense_vox[ 236 | new_pcl[valid_flag][:, 0], 237 | new_pcl[valid_flag][:, 1], 238 | new_pcl[valid_flag][:, 2], 239 | ] 240 | 241 | col_pcl = torch.index_select(colors_map, 0, pts_index.type(torch.int32)) 242 | 243 | img_fov = col_pcl[:, :3].reshape((img_h, img_w, 3)) 244 | # cv2.imwrite(f"./exp/mtp/{dep_i:06d}.jpg", img_fov.cpu().numpy()[..., [2,1,0]]) 245 | pts_index = pts_index.reshape( 246 | ( 247 | img_h, 248 | img_w, 249 | ) 250 | ) 251 | img_flag = pts_index[..., None].repeat(1, 1, 3) 252 | final_img[img_flag != 0] = img_fov[img_flag != 0] 253 | 254 | all_img_fov.append(pts_index[None]) 255 | 256 | mtp_idx = int(dep_i // (dep_num / mtp_num)) 257 | mtp_vis[mtp_idx][pts_index != 0] = pts_index[pts_index != 0] 258 | fuse_img[pts_index != 0] = pts_index[pts_index != 0] 259 | 260 | depth_map[pts_index != 0] = dep_i 261 | 262 | save_path = image_paths[cam_i] 263 | 264 | if "samples" in save_path: 265 | save_path = save_path.replace("samples", save_name) 266 | if "sweeps" in save_path: 267 | save_path = save_path.replace( 268 | "sweeps", save_name.replace("samples", "sweeps") 269 
| ) 270 | 271 | final_img = final_img[..., [2, 1, 0]].cpu().numpy() 272 | cv2.imwrite(save_path[:-4] + "_occrgb.png", final_img) 273 | 274 | # rgb_img = cv2.imread(image_paths[cam_i]) 275 | # rgb_img = cv2.resize(rgb_img, (img_w, img_h)) 276 | # final_img = np.concatenate([final_img, rgb_img], 0) 277 | # # raw_occ_rgb = cv2.imread(save_path[:-4].replace(save_name, 'samples_syntheocc') + '_occrgb.jpg') 278 | # # final_img = np.concatenate([raw_occ_rgb, final_img], 0) 279 | # cv2.imwrite('output.jpg', final_img) 280 | 281 | if 1: 282 | all_img_fov = torch.cat(all_img_fov, 0).type(torch.uint8).flip(0) 283 | mtp_96 = torch.cat([x[None] for x in mtp_vis], 0).type(torch.uint8).flip(0) 284 | 285 | mtp_96_path = save_path[:-4] + "_mtp96.npz" 286 | mtp_256_path = save_path[:-4] + "_mtp256.npz" 287 | 288 | sparse_mat = mtp_96.cpu().numpy().reshape((-1, mtp_96.shape[-1])) 289 | # allmatrix_sp = sparse.coo_matrix(sparse_mat) # compress the matrix in row-major order 290 | allmatrix_sp = sparse.csr_matrix(sparse_mat) # compress the matrix in row-major (CSR) order 291 | sparse.save_npz(mtp_96_path, allmatrix_sp) # save the sparse matrix 292 | 293 | sparse_mat = all_img_fov.cpu().numpy().reshape((-1, all_img_fov.shape[-1])) 294 | # allmatrix_sp = sparse.coo_matrix(sparse_mat) # compress the matrix in row-major order 295 | allmatrix_sp = sparse.csr_matrix(sparse_mat) # compress the matrix in row-major (CSR) order 296 | sparse.save_npz(mtp_256_path, allmatrix_sp) # save the sparse matrix 297 | 298 | # allmatrix_sp = sparse.load_npz('allmatrix_sparse.npz') 299 | # allmatrix = allmatrix_sp.toarray().reshape(mtp_96.shape) 300 | 301 | fuse_path = save_path[:-4] + "_fuseweight.png" 302 | cv2.imwrite(fuse_path, fuse_img.cpu().numpy()) 303 | 304 | depth_map_path = save_path[:-4] + "_depthmap.png" 305 | cv2.imwrite(depth_map_path, depth_map[..., 0].cpu().numpy()) 306 | 307 | 308 | def run_inference(rank, world_size, pred_results, input_datas): 309 | if rank is not None: 310 | # dist.init_process_group("gloo", rank=rank, world_size=world_size) 311 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 312 | else: 313 | rank = 0 314 | print(rank) 315 | 316 | torch.set_default_device(rank) 317 | 318 | all_list = input_datas[rank] # [::6] 319 | 320 | for i in tqdm.tqdm(all_list): 321 | process_func(i, rank) 322 | 323 | 324 | if __name__ == "__main__": 325 | os.environ["NCCL_SOCKET_IFNAME"] = "eth1" # os.system("export ...") would not affect this process 326 | 327 | from torch.multiprocessing import Manager 328 | 329 | world_size = 8 330 | 331 | all_len = len(nus_pkl["infos"]) 332 | val_len = all_len // 8 * 8 333 | print(all_len, val_len) 334 | 335 | all_list = torch.arange(val_len).cpu().numpy() 336 | # all_list = torch.arange(16).cpu().numpy() 337 | 338 | all_list = np.split(all_list, 8) 339 | 340 | input_datas = {} 341 | for i in range(world_size): 342 | input_datas[i] = list(all_list[i]) 343 | print(len(input_datas[i])) 344 | 345 | input_datas[0] += list(np.arange(val_len, all_len)) # hand the leftover samples to rank 0 346 | 347 | for i in range(world_size): 348 | print(len(input_datas[i])) 349 | 350 | # run_inference(0, 1, None, input_datas) 351 | 352 | with Manager() as manager: 353 | pred_results = manager.list() 354 | mp.spawn( 355 | run_inference, 356 | nprocs=world_size, 357 | args=( 358 | world_size, 359 | pred_results, 360 | input_datas, 361 | ), 362 | join=True, 363 | ) 364 | --------------------------------------------------------------------------------
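As a usage reference for the files written above, here is a minimal sketch of loading one generated `_mtp96.npz` file back into a dense tensor, mirroring the loading logic in `utils/dataset_nusmtv.py`. The file path is a placeholder, and the `(96, 56, 100)` shape assumes the default `mtp_num`, `img_h`, and `img_w` in `gen_mtp.py`; adjust both if you generate MPIs at a different resolution.

```python
# Minimal sketch (not part of the training code): reload a saved 3D semantic MPI.
# The path is a placeholder; (96, 56, 100) = (mtp_num, img_h, img_w) assumes the
# default resolution used in gen_mtp.py.
import torch
from scipy import sparse

mtp_path = "data/nuscenes/samples_syntheocc_surocc/CAM_FRONT/example_mtp96.npz"  # placeholder

mtp = sparse.load_npz(mtp_path).toarray()          # (96 * 56, 100) dense array
mtp = torch.tensor(mtp.reshape((96, 56, 100)))     # per-plane semantic class ids
print(mtp.shape, mtp.dtype)
```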